Samin Chandeepa
20 min read

Training MedSLM-SFT: Supervised Fine-Tuning for Medical Instruction Following

In this post, we fine-tune MedSLM using LoRA and QLoRA to transform it from a text completion engine into a medical question-answering assistant. We use Unsloth for 2-5x faster training with 70% less memory.

In the previous post, we curated a high-quality medical SFT dataset. Now we use it to instruction-tune MedSLM — teaching it to follow a structured question-answering format using parameter-efficient LoRA adapters.

#Why LoRA Instead of Full Fine-Tuning?

Full fine-tuning updates all ~330M parameters, requiring significant GPU memory and risking catastrophic forgetting of the medical knowledge learned during pre-training. LoRA (Low-Rank Adaptation) freezes the pre-trained weights and injects small, trainable low-rank matrices into specific layers. QLoRA goes further by quantizing the frozen base weights to 4-bit precision, reducing memory even more.

| Aspect | Full Fine-Tuning | LoRA | QLoRA |
|---|---|---|---|
| Trainable parameters | 100% (~330M) | ~1-5% (~5-15M) | ~1-5% (~5-15M) |
| GPU memory | ~8-12 GB | ~4-6 GB | ~2-4 GB |
| Training speed | Baseline | 2-3x faster | 2-5x faster |
| Catastrophic forgetting risk | High | Low | Low |
| Storage per checkpoint | ~1.3 GB | ~30-80 MB | ~30-80 MB |

#Pipeline Overview

  1. Model Conversion — Convert custom MedSLM weights to HuggingFace LLaMA-compatible format
  2. Load Model with Unsloth — Load the converted model with optimized CUDA kernels
  3. Load SFT Dataset — Load the instruction dataset from HuggingFace
  4. LoRA Adapter Configuration — Configure and attach LoRA adapters to target modules
  5. Training with SFTTrainer — Fine-tune using TRL's SFTTrainer
  6. Evaluation — Perplexity, sample generation, and quality assessment
  7. Merge LoRA Adapters — Merge trained adapters back into the base model
  8. Save & Upload to HuggingFace — Push the fine-tuned model to the Hub

#Step 1: Weight Conversion to HuggingFace Format

Our custom MedSLM model is architecturally identical to LLaMA — it uses the same building blocks (RMSNorm, RoPE, SwiGLU, GQA). The only difference is the layer naming convention. We map MedSLM weight names to the standard HuggingFace LlamaForCausalLM naming so the model can be used with Unsloth, PEFT, and the entire HF ecosystem.

| MedSLM Layer | HuggingFace LLaMA Layer | Description |
|---|---|---|
| tok_emb.weight | model.embed_tokens.weight | Token embeddings |
| blocks.{i}.attn.wq.weight | model.layers.{i}.self_attn.q_proj.weight | Query projection |
| blocks.{i}.attn.wk.weight | model.layers.{i}.self_attn.k_proj.weight | Key projection |
| blocks.{i}.attn.wv.weight | model.layers.{i}.self_attn.v_proj.weight | Value projection |
| blocks.{i}.attn.wo.weight | model.layers.{i}.self_attn.o_proj.weight | Output projection |
| blocks.{i}.ffn.w_gate.weight | model.layers.{i}.mlp.gate_proj.weight | SwiGLU gate |
| blocks.{i}.ffn.w_up.weight | model.layers.{i}.mlp.up_proj.weight | SwiGLU up |
| blocks.{i}.ffn.w_down.weight | model.layers.{i}.mlp.down_proj.weight | SwiGLU down |
| norm_f.weight | model.norm.weight | Final RMSNorm |
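
The conversion function below relies on two lookup tables: WEIGHT_MAP for the top-level weights and LAYER_WEIGHT_MAP for the per-layer suffixes. A minimal sketch of what they look like, built directly from the naming table above:

# Top-level (non-layer) weights, from the table above
WEIGHT_MAP = {
    "tok_emb.weight": "model.embed_tokens.weight",
    "norm_f.weight": "model.norm.weight",
}

# Per-layer weight suffixes: MedSLM suffix -> LLaMA suffix
LAYER_WEIGHT_MAP = {
    "attn.wq.weight": "self_attn.q_proj.weight",
    "attn.wk.weight": "self_attn.k_proj.weight",
    "attn.wv.weight": "self_attn.v_proj.weight",
    "attn.wo.weight": "self_attn.o_proj.weight",
    "ffn.w_gate.weight": "mlp.gate_proj.weight",
    "ffn.w_up.weight": "mlp.up_proj.weight",
    "ffn.w_down.weight": "mlp.down_proj.weight",
}
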
def convert_medslm_to_llama(medslm_state_dict: dict) -> dict:
    """Rename MedSLM weight tensors to HuggingFace LlamaForCausalLM names."""
    llama_state_dict = {}
    for medslm_name, weight in medslm_state_dict.items():
        # Top-level weights (embeddings, final norm) map directly
        if medslm_name in WEIGHT_MAP:
            llama_state_dict[WEIGHT_MAP[medslm_name]] = weight
            continue
        # Per-layer weights: match on the suffix, keep the layer index
        for medslm_suffix, llama_suffix in LAYER_WEIGHT_MAP.items():
            if medslm_suffix in medslm_name:
                layer_idx = medslm_name.split(".")[1]  # "blocks.{i}.<suffix>" -> i
                llama_name = f"model.layers.{layer_idx}.{llama_suffix}"
                llama_state_dict[llama_name] = weight
                break
    return llama_state_dict

The conversion maps all 218 weight tensors (329,910,272 parameters) with zero skipped. We verify the conversion by running a forward pass through the instantiated LlamaForCausalLM, confirming valid output logits in the range [-10.40, 11.36].
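
A minimal sketch of that sanity check, assuming the converted weights and a matching config are saved in the medslm_hf_converted directory used in the next step:

import torch
from transformers import LlamaForCausalLM

# Load the converted checkpoint and run a dummy forward pass
model = LlamaForCausalLM.from_pretrained("medslm_hf_converted")
model.eval()

dummy_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    logits = model(dummy_ids).logits

# Shape and value sanity checks
assert logits.shape == (1, 16, model.config.vocab_size)
assert torch.isfinite(logits).all()
print(f"Logit range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")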

#Step 2: Load with Unsloth

We load the converted model using Unsloth's FastLanguageModel, which provides 2-5x faster training via fused CUDA kernels for attention, RoPE, and cross-entropy, 70% less memory through optimized gradient checkpointing, and automatic 4-bit quantization via bitsandbytes. We set load_in_4bit=True for QLoRA, quantizing the base model weights to 4-bit NF4 format while keeping the LoRA adapters in full precision.

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="medslm_hf_converted",
    max_seq_length=1024,
    dtype=None,           # Auto-detect (bf16 if supported)
    load_in_4bit=True,    # QLoRA: 4-bit quantized base
)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

| Metric | Value |
|---|---|
| Parameters (4-bit) | 190.7M |
| Vocab size | 50,257 |
| Precision | BF16 |
| GPU | RTX 4090 (24 GB) |

#Step 3: LoRA Adapter Configuration

LoRA works by freezing the pre-trained model weights and injecting small, trainable low-rank matrices into specific layers. Instead of updating a full weight matrix W (shape d x d), LoRA learns two small matrices B (shape d x r) and A (shape r x d), where r << d. The update becomes W_new = W_frozen + (alpha/r) * B * A. We attach LoRA adapters to all linear layers in the transformer, both the attention projections and the SwiGLU FFN projections, giving LoRA maximum expressiveness while remaining parameter-efficient.
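
To illustrate the math (this is a sketch, not the PEFT implementation), a LoRA-adapted linear layer can be written in a few lines of PyTorch:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update (illustrative only)."""
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)              # freeze W
        d_out, d_in = base.weight.shape
        self.A = nn.Parameter(torch.randn(r, d_in) * 0.01)  # A: r x d_in
        self.B = nn.Parameter(torch.zeros(d_out, r))        # B: d_out x r, zero-init so training starts at W_frozen
        self.scale = alpha / r

    def forward(self, x):
        # y = x W^T + (alpha/r) * x A^T B^T  ==  x (W + (alpha/r) B A)^T
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)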

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
    use_gradient_checkpointing="unsloth",
)
| Parameter | Value | Description |
|---|---|---|
| Rank (r) | 16 | Rank of the low-rank matrices; higher means more capacity |
| Alpha | 32 | Scaling factor; effective scale = alpha/r = 2.0 |
| Dropout | 0.0 | Unsloth-optimized: 0 enables fused kernels |
| Target modules | q, k, v, o, gate, up, down | All linear layers get LoRA adapters |
| Bias | none | No bias terms in LoRA layers |

| Metric | Value |
|---|---|
| Total parameters | 197.8M |
| Trainable (LoRA) | 7.1M (3.59%) |
| Frozen (base) | 190.7M (96.41%) |
| Memory savings | ~96% |

#Step 4: Training Configuration

We use TRL's SFTTrainer — a specialized trainer for supervised fine-tuning that handles tokenization, sequence packing (combining multiple short examples into single sequences for efficiency), and loss computation. The learning rate of 2e-4 is standard for LoRA SFT — higher than pre-training since we're only updating adapters. Sequence packing is enabled to maximize GPU utilization given our average example length of ~180 tokens against a 1,024-token context window.

| Hyperparameter | Value | Reasoning |
|---|---|---|
| Learning rate | 2e-4 | Standard for LoRA SFT |
| LR scheduler | Cosine | Smooth decay for stable convergence |
| Warmup ratio | 5% | Gradual ramp-up for stability |
| Batch size (per device) | 4 | Kept small for memory efficiency |
| Gradient accumulation | 8 | Effective batch size = 32 |
| Epochs | 3 | Enough to learn the format without overfitting |
| Weight decay | 0.01 | Light regularization |
| Max gradient norm | 1.0 | Gradient clipping for stability |
| Optimizer | AdamW (8-bit) | Memory-efficient optimizer |
| Sequence packing | Enabled | Pack short examples for efficiency |
| Max sequence length | 1,024 | Matches pre-training context window |
| Precision | BF16 | FP16 fallback if not supported |
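
The trainer below expects the instruction dataset curated in the previous post, loaded from the Hub with a "text" field containing the formatted prompt/response pairs. A loading sketch (the repository id here is a placeholder, not the real dataset name):

from datasets import load_dataset

# Placeholder repo id -- substitute the SFT dataset curated in the previous post
sft_dataset = load_dataset("your-username/medslm-sft-dataset")
print(sft_dataset)  # expects "train" and "validation" splits
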
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset["train"],
    eval_dataset=sft_dataset["validation"],
    args=TrainingArguments(
        output_dir="medslm_sft_output",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,   # effective batch size = 32
        num_train_epochs=3,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        weight_decay=0.01,
        bf16=True,
        optim="adamw_8bit",
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=500,
        load_best_model_at_end=True,
    ),
    dataset_text_field="text",
    max_seq_length=1024,
    packing=True,   # pack short examples into full 1,024-token sequences
)

#Training Results

Training completed in approximately 43 minutes on a single NVIDIA RTX 4090. The training loss decreased steadily from 3.06 to 2.47, with validation loss converging to 2.21 — no significant overfitting gap. The base model weights remained frozen throughout; only the 7.1M LoRA adapter parameters were updated.

| Metric | Value |
|---|---|
| Total steps | 4,329 |
| Final train loss | 2.4678 |
| Final val loss | 2.2084 |
| Runtime | ~43 min |
| Throughput | 53.4 samples/sec |
| GPU | 1x RTX 4090 |
| Step | Training Loss | Validation Loss |
|---|---|---|
| 250 | 3.0606 | 2.9208 |
| 500 | 2.7814 | 2.6572 |
| 1,000 | 2.6146 | 2.4686 |
| 1,500 | 2.3677 | 2.3664 |
| 2,000 | 2.2951 | 2.3084 |
| 2,500 | 2.3200 | 2.2627 |
| 3,000 | 2.1968 | 2.2314 |
| 3,500 | 2.2834 | 2.2160 |
| 4,000 | 2.2565 | 2.2089 |
| 4,329 | 2.1976 | 2.2084 |
[Figure] SFT training and validation loss curves over 4,329 steps, showing steady convergence with no significant overfitting.

#Evaluation

We evaluate the fine-tuned model from multiple angles: test-set perplexity, sample generation quality, and response length statistics. The model generates coherent medical responses that follow the instruction format, though as a 330M parameter model it is still prone to occasional hallucinations and repetition.
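
Perplexity is simply the exponential of the mean token-level cross-entropy loss. A minimal sketch of the check, assuming the dataset includes a held-out test split:

import math

# Evaluate on the held-out split and convert loss to perplexity
test_metrics = trainer.evaluate(eval_dataset=sft_dataset["test"])
perplexity = math.exp(test_metrics["eval_loss"])
print(f"Test loss: {test_metrics['eval_loss']:.4f} | Perplexity: {perplexity:.2f}")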

Generation statistics on 50 test examples show a mean response length of 205 words, median of 209 words, with a minimum of 156 words and maximum of 234 words — indicating the model has learned to produce consistently substantive answers rather than short or degenerate outputs.

Sample Generations

Here are example responses from the fine-tuned model to medical questions. While the model demonstrates instruction-following ability and medical vocabulary, the small parameter count means responses sometimes contain factual inaccuracies — reinforcing that this is a research model, not suitable for clinical use.

Q: How is hypertension diagnosed and treated?
A: The first-line treatment for hypertension is a combination of
   medication and lifestyle changes. This means that the patient's
   blood pressure can be improved by taking medications such as
   benzodiazepines or other medications. In some cases, medications
   such as diuretics may also be used to treat hypertension...

Q: What causes acute myocardial infarction?
A: The exact cause of acute myocardial infarction is not fully
   understood, but it is thought to be related to the severity of
   the symptoms. Treatment for acute myocardial infarction may
   involve medications to suppress the immune system and reduce
   the risk of developing myocardial infarction...

#Merging & Export

After training, the LoRA adapters exist as separate lightweight matrices (~17.8 MB). For deployment, we provide two options: keeping adapters separate (for flexibility and easy swapping) and merging adapters into the base model (for standalone deployment without PEFT dependency). The merge operation computes W_merged = W_base + (alpha/r) * B @ A for each adapted layer, producing a single 1.32 GB model file.

# Save LoRA adapters separately (~17.8 MB)
model.save_pretrained("medslm_sft_output/lora_adapters")
tokenizer.save_pretrained("medslm_sft_output/lora_adapters")

# Merge adapters into base model (~1.32 GB)
model.save_pretrained_merged(
    "medslm_sft_output/merged_model",
    tokenizer,
    save_method="merged_16bit",
)
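
If you prefer plain PEFT over Unsloth for the merge, merge_and_unload() produces an equivalent result when applied to a full-precision base; a sketch under that assumption:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the full-precision base, attach the trained adapters, then fold them in
base = AutoModelForCausalLM.from_pretrained("medslm_hf_converted")
peft_model = PeftModel.from_pretrained(base, "medslm_sft_output/lora_adapters")
merged = peft_model.merge_and_unload()   # computes W_base + (alpha/r) * B @ A per layer
merged.save_pretrained("medslm_sft_output/merged_model")
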
| Output | Size | Repository |
|---|---|---|
| LoRA adapters | 17.8 MB | Saminx22/MedSLM-SFT-LoRA |
| Merged model | 1.32 GB | Saminx22/MedSLM-SFT |

#Inference

The fine-tuned model can be loaded in two ways: directly as the merged model (simplest approach), or as the base model plus LoRA adapters via PEFT (more flexible, allows swapping adapters). Both methods produce identical results. The prompt must follow the exact instruction template used during training.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Method 1: Load merged model
model = AutoModelForCausalLM.from_pretrained(
    "Saminx22/MedSLM-SFT",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Method 2: Load base + LoRA adapters
base_model = AutoModelForCausalLM.from_pretrained(
    "Saminx22/MedSLM",
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "Saminx22/MedSLM-SFT-LoRA")

# The tokenizer is shared by both methods
tokenizer = AutoTokenizer.from_pretrained("Saminx22/MedSLM-SFT")
SYSTEM_PROMPT = (
    "You are a medical AI assistant. "
    "Provide accurate, evidence-based answers to medical questions."
)

def ask(question: str, max_new_tokens: int = 300) -> str:
    prompt = (
        f"### System:\n{SYSTEM_PROMPT}\n\n"
        f"### User:\n{question}\n\n"
        f"### Assistant:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )
    response = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(response, skip_special_tokens=True).strip()
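
A quick smoke test (the question is arbitrary):

print(ask("What are the common symptoms of type 2 diabetes?"))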

#Key Takeaways

  1. LoRA is extremely parameter-efficient. Training only 3.59% of parameters (7.1M out of 197.8M) was sufficient to teach instruction-following behavior.
  2. QLoRA enables fine-tuning on consumer GPUs. 4-bit quantization reduced memory to under 4 GB, making the 330M model trainable on a single RTX 4090.
  3. Unsloth delivers on its promise. The entire fine-tuning run completed in ~43 minutes with 53.4 samples/sec throughput — 2-5x faster than vanilla HuggingFace.
  4. Format conversion is straightforward. Because MedSLM uses identical architecture to LLaMA, converting to HuggingFace format required only a layer-name mapping.
  5. Small models have limitations. Despite learning instruction-following format, the 330M parameter model still hallucinates and produces factual errors — highlighting the trade-off between model size and accuracy.
  6. Dual export is best practice. Saving both LoRA adapters (17.8 MB) and the merged model (1.32 GB) provides flexibility for different deployment scenarios.
