Training MedSLM-SFT: Supervised Fine-Tuning for Medical Instruction Following
In this post, we fine-tune MedSLM using LoRA and QLoRA to transform it from a text completion engine into a medical question-answering assistant. We use Unsloth for 2-5x faster training with 70% less memory.
In the previous post, we curated a high-quality medical SFT dataset. Now we use it to instruction-tune MedSLM — teaching it to follow a structured question-answering format using parameter-efficient LoRA adapters.
#Why LoRA Instead of Full Fine-Tuning?
Full fine-tuning updates all ~330M parameters, requiring significant GPU memory and risking catastrophic forgetting of the medical knowledge learned during pre-training. LoRA (Low-Rank Adaptation) freezes the pre-trained weights and injects small, trainable low-rank matrices into specific layers. QLoRA goes further by quantizing the frozen base weights to 4-bit precision, reducing memory even more.
| Aspect | Full Fine-Tuning | LoRA | QLoRA |
|---|---|---|---|
| Trainable parameters | 100% (~330M) | ~1-5% (~5-15M) | ~1-5% (~5-15M) |
| GPU memory | ~8-12 GB | ~4-6 GB | ~2-4 GB |
| Training speed | Baseline | 2-3x faster | 2-5x faster |
| Catastrophic forgetting risk | High | Low | Low |
| Storage per checkpoint | ~1.3 GB | ~30-80 MB | ~30-80 MB |
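To make the trainable-parameter row concrete, here is a quick back-of-the-envelope count for a single d x d projection matrix, using the rank of 16 we configure later in this post. The hidden size of 768 is an assumption for illustration only; MedSLM's exact dimensions are covered in the architecture post.

```python
# Rough LoRA vs. full fine-tuning parameter count for one d x d projection.
# d = 768 is an assumed hidden size for illustration; r = 16 matches Step 3.
d, r = 768, 16
full_params = d * d              # full fine-tuning updates the entire matrix
lora_params = d * r + r * d      # LoRA trains only B (d x r) and A (r x d)
print(full_params, lora_params, f"{lora_params / full_params:.1%}")
# 589824 24576 4.2% -- in line with the ~1-5% range in the table
```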
#Pipeline Overview
- Model Conversion — Convert custom MedSLM weights to HuggingFace LLaMA-compatible format
- Load Model with Unsloth — Load the converted model with optimized CUDA kernels
- Load SFT Dataset — Load the instruction dataset from HuggingFace
- LoRA Adapter Configuration — Configure and attach LoRA adapters to target modules
- Training with SFTTrainer — Fine-tune using TRL's SFTTrainer
- Evaluation — Perplexity, sample generation, and quality assessment
- Merge LoRA Adapters — Merge trained adapters back into the base model
- Save & Upload to HuggingFace — Push the fine-tuned model to the Hub
#Step 1: Weight Conversion to HuggingFace Format
Our custom MedSLM model is architecturally identical to LLaMA — it uses the same building blocks (RMSNorm, RoPE, SwiGLU, GQA). The only difference is the layer naming convention. We map MedSLM weight names to the standard HuggingFace LlamaForCausalLM naming so the model can be used with Unsloth, PEFT, and the entire HF ecosystem.
| MedSLM Layer | HuggingFace LLaMA Layer | Description |
|---|---|---|
| tok_emb.weight | model.embed_tokens.weight | Token embeddings |
| blocks.{i}.attn.wq.weight | model.layers.{i}.self_attn.q_proj.weight | Query projection |
| blocks.{i}.attn.wk.weight | model.layers.{i}.self_attn.k_proj.weight | Key projection |
| blocks.{i}.attn.wv.weight | model.layers.{i}.self_attn.v_proj.weight | Value projection |
| blocks.{i}.attn.wo.weight | model.layers.{i}.self_attn.o_proj.weight | Output projection |
| blocks.{i}.ffn.w_gate.weight | model.layers.{i}.mlp.gate_proj.weight | SwiGLU gate |
| blocks.{i}.ffn.w_up.weight | model.layers.{i}.mlp.up_proj.weight | SwiGLU up |
| blocks.{i}.ffn.w_down.weight | model.layers.{i}.mlp.down_proj.weight | SwiGLU down |
| norm_f.weight | model.norm.weight | Final RMSNorm |
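The conversion function below relies on two lookup dictionaries that the original listing does not show. Here is a sketch reconstructed directly from the mapping table above; the per-block RMSNorm weights (attention and FFN norms), which the table omits, would need analogous suffix entries.

```python
# Exact-name mapping for weights that live outside the transformer blocks.
WEIGHT_MAP = {
    "tok_emb.weight": "model.embed_tokens.weight",
    "norm_f.weight": "model.norm.weight",
}

# Suffix mapping for per-layer weights; the layer index is filled in at runtime.
LAYER_WEIGHT_MAP = {
    "attn.wq.weight": "self_attn.q_proj.weight",
    "attn.wk.weight": "self_attn.k_proj.weight",
    "attn.wv.weight": "self_attn.v_proj.weight",
    "attn.wo.weight": "self_attn.o_proj.weight",
    "ffn.w_gate.weight": "mlp.gate_proj.weight",
    "ffn.w_up.weight": "mlp.up_proj.weight",
    "ffn.w_down.weight": "mlp.down_proj.weight",
}
```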
def convert_medslm_to_llama(medslm_state_dict: dict) -> dict:
    """Rename MedSLM weight tensors to HuggingFace LlamaForCausalLM conventions."""
    llama_state_dict = {}
    for medslm_name, weight in medslm_state_dict.items():
        # Weights outside the transformer blocks map by exact name.
        if medslm_name in WEIGHT_MAP:
            llama_state_dict[WEIGHT_MAP[medslm_name]] = weight
            continue
        # Per-layer weights: match the suffix and carry the layer index over.
        for medslm_suffix, llama_suffix in LAYER_WEIGHT_MAP.items():
            if medslm_suffix in medslm_name:
                layer_idx = medslm_name.split(".")[1]  # "blocks.{i}. ..." -> i
                llama_name = f"model.layers.{layer_idx}.{llama_suffix}"
                llama_state_dict[llama_name] = weight
                break
    return llama_state_dict

The conversion maps all 218 weight tensors (329,910,272 parameters) with zero skipped. We verify the conversion by running a forward pass through the instantiated LlamaForCausalLM, confirming valid output logits in the range [-10.40, 11.36].
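A minimal version of that sanity check might look like the sketch below. The LlamaConfig values are placeholders rather than MedSLM's real dimensions (those are in the architecture post), and tie_word_embeddings=True is an assumption, since the mapping table has no separate lm_head entry.

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Placeholder config -- substitute MedSLM's actual dimensions here.
config = LlamaConfig(
    vocab_size=32_000,
    hidden_size=768,
    intermediate_size=2048,
    num_hidden_layers=24,
    num_attention_heads=12,
    num_key_value_heads=4,         # GQA
    max_position_embeddings=1024,
    tie_word_embeddings=True,      # assumed: no separate lm_head in the weight map
)

hf_model = LlamaForCausalLM(config)
# medslm_state_dict is the raw checkpoint loaded earlier with torch.load().
missing, unexpected = hf_model.load_state_dict(
    convert_medslm_to_llama(medslm_state_dict), strict=False
)
print("missing:", missing, "unexpected:", unexpected)

# Forward pass on dummy tokens -- logits should be finite and well-ranged.
dummy_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    logits = hf_model(dummy_ids).logits
print(logits.shape, float(logits.min()), float(logits.max()))
```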
#Step 2: Load with Unsloth
We load the converted model using Unsloth's FastLanguageModel, which provides 2-5x faster training via fused CUDA kernels for attention, RoPE, and cross-entropy; 70% less memory through optimized gradient checkpointing; and automatic 4-bit quantization via bitsandbytes. We set load_in_4bit=True for QLoRA, quantizing the base model weights to 4-bit NF4 format while keeping the LoRA adapters in full precision.
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="medslm_hf_converted",
    max_seq_length=1024,
    dtype=None,          # Auto-detect (bf16 if supported)
    load_in_4bit=True,   # QLoRA: 4-bit quantized base
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
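A quick way to confirm that the base weights really are quantized is to inspect one of the projection layers. This is a small sanity check of our own, assuming bitsandbytes (which Unsloth uses for 4-bit loading) is installed:

```python
import torch
import bitsandbytes as bnb

# Attention/MLP projections should hold 4-bit parameters after loading,
# while norms and embeddings remain in 16-bit.
q_weight = model.model.layers[0].self_attn.q_proj.weight
print(isinstance(q_weight, bnb.nn.Params4bit))                 # expected: True
print(f"{torch.cuda.memory_allocated() / 1e9:.2f} GB allocated on GPU")
```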
#Step 3: LoRA Adapter Configuration
LoRA works by freezing the pre-trained model weights and injecting small, trainable low-rank matrices into specific layers. Instead of updating a full weight matrix W (shape d x d), LoRA learns two small matrices B (shape d x r) and A (shape r x d), where r << d. The update becomes W_new = W_frozen + (alpha/r) * B * A. We attach LoRA adapters to all linear layers in the transformer — both attention projections and SwiGLU FFN projections — giving LoRA maximum expressiveness while remaining parameter-efficient.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
    use_gradient_checkpointing="unsloth",
)

| Parameter | Value | Description |
|---|---|---|
| Rank (r) | 16 | Rank of low-rank matrices — higher = more capacity |
| Alpha | 32 | Scaling factor — effective scale = alpha/r = 2.0 |
| Dropout | 0.0 | Unsloth-optimized: 0 for fused kernels |
| Target modules | q, k, v, o, gate, up, down | All linear layers get LoRA adapters |
| Bias | none | No bias terms in LoRA layers |
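The object returned by get_peft_model is a standard PEFT model, so the trainable fraction can be double-checked with PEFT's built-in helper:

```python
# Prints trainable vs. total parameter counts and the trainable percentage --
# roughly the 7.1M trainable parameters quoted in the takeaways below.
model.print_trainable_parameters()
```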
#Step 4: Training Configuration
We use TRL's SFTTrainer — a specialized trainer for supervised fine-tuning that handles tokenization, sequence packing (combining multiple short examples into single sequences for efficiency), and loss computation. The learning rate of 2e-4 is standard for LoRA SFT — higher than pre-training since we're only updating adapters. Sequence packing is enabled to maximize GPU utilization given our average example length of ~180 tokens against a 1,024-token context window.
| Hyperparameter | Value | Reasoning |
|---|---|---|
| Learning rate | 2e-4 | Standard for LoRA SFT |
| LR scheduler | Cosine | Smooth decay for stable convergence |
| Warmup ratio | 5% | Gradual ramp-up for stability |
| Batch size (per device) | 4 | Kept small for memory efficiency |
| Gradient accumulation | 8 | Effective batch size = 32 |
| Epochs | 3 | Enough to learn the format without overfitting |
| Weight decay | 0.01 | Light regularization |
| Max gradient norm | 1.0 | Gradient clipping for stability |
| Optimizer | AdamW (8-bit) | Memory-efficient optimizer |
| Sequence packing | Enabled | Pack short examples for efficiency |
| Max sequence length | 1,024 | Matches pre-training context window |
| Precision | BF16 | FP16 fallback if not supported |
from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset["train"],
    eval_dataset=sft_dataset["validation"],
    args=TrainingArguments(
        output_dir="medslm_sft_output",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        weight_decay=0.01,
        bf16=True,
        optim="adamw_8bit",
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=500,
        load_best_model_at_end=True,
    ),
    dataset_text_field="text",
    max_seq_length=1024,
    packing=True,
)

#Training Results
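Fine-tuning is launched with a single call; the returned stats object exposes the run summary through the standard HuggingFace Trainer metric keys. A sketch of how the runtime and throughput figures are read out:

```python
trainer_stats = trainer.train()
print(f"Runtime: {trainer_stats.metrics['train_runtime'] / 60:.1f} min")
print(f"Throughput: {trainer_stats.metrics['train_samples_per_second']:.1f} samples/sec")
```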
Training completed in approximately 43 minutes on a single NVIDIA RTX 4090. The training loss decreased steadily from 3.06 to 2.47, with validation loss converging to 2.21 — no significant overfitting gap. The base model weights remained frozen throughout; only the 7.1M LoRA adapter parameters were updated.
| Step | Training Loss | Validation Loss |
|---|---|---|
| 250 | 3.0606 | 2.9208 |
| 500 | 2.7814 | 2.6572 |
| 1,000 | 2.6146 | 2.4686 |
| 1,500 | 2.3677 | 2.3664 |
| 2,000 | 2.2951 | 2.3084 |
| 2,500 | 2.3200 | 2.2627 |
| 3,000 | 2.1968 | 2.2314 |
| 3,500 | 2.2834 | 2.2160 |
| 4,000 | 2.2565 | 2.2089 |
| 4,329 | 2.1976 | 2.2084 |

Training and validation loss over 4,329 steps — steady convergence with no significant overfitting
#Evaluation
We evaluate the fine-tuned model from multiple angles: test-set perplexity, sample generation quality, and response length statistics. The model generates coherent medical responses that follow the instruction format, though as a 330M parameter model it is still prone to occasional hallucinations and repetition.
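Test-set perplexity can be read off the average evaluation loss. A minimal sketch, assuming the SFT dataset also carries a held-out test split:

```python
import math

# Perplexity is exp(mean cross-entropy loss) on the held-out split.
test_metrics = trainer.evaluate(eval_dataset=sft_dataset["test"])
perplexity = math.exp(test_metrics["eval_loss"])
print(f"test loss: {test_metrics['eval_loss']:.4f} | perplexity: {perplexity:.2f}")
```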
Generation statistics on 50 test examples show a mean response length of 205 words, median of 209 words, with a minimum of 156 words and maximum of 234 words — indicating the model has learned to produce consistently substantive answers rather than short or degenerate outputs.
Sample Generations
Here are example responses from the fine-tuned model to medical questions. While the model demonstrates instruction-following ability and medical vocabulary, the small parameter count means responses sometimes contain factual inaccuracies — reinforcing that this is a research model, not suitable for clinical use.
Q: How is hypertension diagnosed and treated?
A: The first-line treatment for hypertension is a combination of
medication and lifestyle changes. This means that the patient's
blood pressure can be improved by taking medications such as
benzodiazepines or other medications. In some cases, medications
such as diuretics may also be used to treat hypertension...
Q: What causes acute myocardial infarction?
A: The exact cause of acute myocardial infarction is not fully
understood, but it is thought to be related to the severity of
the symptoms. Treatment for acute myocardial infarction may
involve medications to suppress the immune system and reduce
the risk of developing myocardial infarction...

#Merging & Export
After training, the LoRA adapters exist as separate lightweight matrices (~17.8 MB). For deployment, we provide two options: keeping adapters separate (for flexibility and easy swapping) and merging adapters into the base model (for standalone deployment without PEFT dependency). The merge operation computes W_merged = W_base + (alpha/r) * B @ A for each adapted layer, producing a single 1.32 GB model file.
# Save LoRA adapters separately (~17.8 MB)
model.save_pretrained("medslm_sft_output/lora_adapters")
tokenizer.save_pretrained("medslm_sft_output/lora_adapters")

# Merge adapters into base model (~1.32 GB)
model.save_pretrained_merged(
    "medslm_sft_output/merged_model",
    tokenizer,
    save_method="merged_16bit",
)

| Output | Size | Repository |
|---|---|---|
| LoRA adapters | 17.8 MB | Saminx22/MedSLM-SFT-LoRA |
| Merged model | 1.32 GB | Saminx22/MedSLM-SFT |
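If you prefer plain PEFT over Unsloth's helper, the same merge can be reproduced with merge_and_unload. A sketch, assuming the adapters were saved to the directory above and the base model is reloaded in 16-bit so the merge happens in full precision (the output path is illustrative):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Reload the converted base model in fp16 (merging into 4-bit weights is avoided here).
base = AutoModelForCausalLM.from_pretrained("medslm_hf_converted", torch_dtype=torch.float16)
peft_model = PeftModel.from_pretrained(base, "medslm_sft_output/lora_adapters")

# Folds (alpha/r) * B @ A into each adapted weight and drops the adapter modules.
merged = peft_model.merge_and_unload()
merged.save_pretrained("medslm_sft_output/merged_model_peft")  # hypothetical output dir
```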
#Inference
The fine-tuned model can be loaded in two ways: directly as the merged model (simplest approach), or as the base model plus LoRA adapters via PEFT (more flexible, allows swapping adapters). Both methods produce identical results. The prompt must follow the exact instruction template used during training.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# Method 1: Load merged model
model = AutoModelForCausalLM.from_pretrained(
    "Saminx22/MedSLM-SFT",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Saminx22/MedSLM-SFT")  # used by ask() below

# Method 2: Load base + LoRA adapters
base_model = AutoModelForCausalLM.from_pretrained(
    "Saminx22/MedSLM",
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "Saminx22/MedSLM-SFT-LoRA")

SYSTEM_PROMPT = (
    "You are a medical AI assistant. "
    "Provide accurate, evidence-based answers to medical questions."
)

def ask(question: str, max_new_tokens: int = 300) -> str:
    prompt = (
        f"### System:\n{SYSTEM_PROMPT}\n\n"
        f"### User:\n{question}\n\n"
        f"### Assistant:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )
    response = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(response, skip_special_tokens=True).strip()

#Key Takeaways
- LoRA is extremely parameter-efficient. Training only 3.59% of parameters (7.1M out of 197.8M) was sufficient to teach instruction-following behavior.
- QLoRA enables fine-tuning on consumer GPUs. 4-bit quantization reduced memory to under 4 GB, making the 330M model trainable on a single RTX 4090.
- Unsloth delivers on its promise. The entire fine-tuning run completed in ~43 minutes with 53.4 samples/sec throughput — 2-5x faster than vanilla HuggingFace.
- Format conversion is straightforward. Because MedSLM uses identical architecture to LLaMA, converting to HuggingFace format required only a layer-name mapping.
- Small models have limitations. Despite learning instruction-following format, the 330M parameter model still hallucinates and produces factual errors — highlighting the trade-off between model size and accuracy.
- Dual export is best practice. Saving both LoRA adapters (17.8 MB) and the merged model (1.32 GB) provides flexibility for different deployment scenarios.
#Resources
Available Blogs
Explore other posts in this series.

Building a High-Quality Medical Pretraining Dataset for Small Language Models
Large language models like GPT-4 or Gemini are trained on trillions of tokens scraped from the open web. But when your goal is a Small Language Model (SLM) with only ~300 million parameters, targeted at the medical domain, quality matters far more than quantity.

Building MedSLM: A 330M Parameter Medical Language Model
In this post, we build MedSLM - a 330M parameter transformer trained from scratch on our curated medical dataset. We implement modern architecture choices like RMSNorm, Rotary Positional Embeddings, SwiGLU activations, and Grouped-Query Attention.

Curating a Medical SFT Dataset: From Raw QA Pairs to Instruction-Ready Data
In this post, we build a high-quality Supervised Fine-Tuning (SFT) dataset for medical question answering. We combine three curated medical QA sources, apply multi-stage quality filtering, perform MinHash near-duplicate removal, and produce a clean 51K-example instruction dataset.