Implement the RAFT (Reward rAnked FineTuning) alignment loop that iterates over T stages to progressively improve a generative model.
At each stage t, RAFT performs three steps:
Step 1 — Data Collection: A batch of b prompts is sampled, and K candidate responses are generated per prompt from the current model. (These are provided as input.)
Step 2 — Data Ranking: For each prompt, use the reward model to score all K candidates, then select the response with the highest reward: y* = argmax_j r(x, y_j). Collect these best responses into a batch B of size b.
Step 3 — Model Fine-Tuning: Compute the SFT (supervised fine-tuning) loss on batch B, defined as the mean negative log-likelihood over all tokens in the selected responses.
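For a single stage, steps 2 and 3 reduce to an argmax over rewards followed by a mean negative log-likelihood. A minimal sketch under the shapes defined below (the helper name `stage_sft_loss` is illustrative, not part of the required interface):

```python
import numpy as np

def stage_sft_loss(log_probs, rewards):
    """SFT loss for one stage.

    log_probs: [num_prompts][K][num_tokens] token log probabilities.
    rewards:   [num_prompts][K] reward scores.
    Selects the highest-reward response per prompt (Step 2), then
    averages the negative log-likelihood over all tokens of the
    selected responses (Step 3).
    """
    token_nlls = []  # NLLs of every token in the selected batch B
    for prompt_lps, prompt_rewards in zip(log_probs, rewards):
        best = int(np.argmax(prompt_rewards))  # y* = argmax_j r(x, y_j)
        token_nlls.extend(-lp for lp in prompt_lps[best])
    return float(np.mean(token_nlls))
```

On the stage-0 data from the example below, the reward 0.4 beats 0.2, so the second response is kept and the loss is the mean of 1.5 and 1.6, i.e. 1.55.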
Given:
stages_log_probs: Token-level log probabilities across T stages. Shape: [T][num_prompts][K][num_tokens]. stages_log_probs[t][i][j][k] is the log prob of token k in response j for prompt i at stage t.
stages_rewards: Reward scores across T stages. Shape: [T][num_prompts][K]. stages_rewards[t][i][j] is the reward for response j of prompt i at stage t.
Return a list of T SFT losses (one per stage), each rounded to 4 decimal places. As the model improves across stages, the losses should generally decrease.
Example
stages_log_probs = [
[[[-2.0, -1.8], [-1.5, -1.6]]], # Stage 0: weak model
[[[-0.6, -0.5], [-0.7, -0.8]]], # Stage 1: improving
[[[-0.2, -0.1], [-0.3, -0.4]]], # Stage 2: strong model
]
stages_rewards = [
[[0.2, 0.4]], # Stage 0
[[0.75, 0.7]], # Stage 1
[[0.95, 0.8]], # Stage 2
]

Expected output: [1.55, 0.55, 0.15]

Starter Code
import numpy as np
def raft_alignment(
stages_log_probs: list[list[list[list[float]]]],
stages_rewards: list[list[list[float]]],
) -> list[float]:
"""
Simulate the RAFT alignment loop over T stages.
At each stage:
Step 1: Candidates are provided (K responses per prompt).
Step 2: For each prompt, select the response with the highest reward.
Step 3: Compute SFT loss (mean NLL) on the selected responses.
Args:
stages_log_probs: Log probs across T stages.
Shape: [T][num_prompts][K][num_tokens]
stages_rewards: Rewards across T stages.
Shape: [T][num_prompts][K]
Returns:
List of T SFT losses, each rounded to 4 decimal places.
"""
pass
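One way to fill in the starter code, applying the per-stage selection and mean-NLL computation described above (a sketch, not the only valid solution):

```python
import numpy as np

def raft_alignment(
    stages_log_probs: list[list[list[list[float]]]],
    stages_rewards: list[list[list[float]]],
) -> list[float]:
    losses = []
    for stage_lps, stage_rewards in zip(stages_log_probs, stages_rewards):
        token_nlls = []  # NLLs of all tokens in the selected batch B
        for prompt_lps, prompt_rewards in zip(stage_lps, stage_rewards):
            # Step 2: keep the highest-reward response for this prompt.
            best = int(np.argmax(prompt_rewards))
            # Step 3: accumulate token-level negative log-likelihoods.
            token_nlls.extend(-lp for lp in prompt_lps[best])
        losses.append(round(float(np.mean(token_nlls)), 4))
    return losses
```

Running this on the example data yields [1.55, 0.55, 0.15], decreasing across stages as the problem statement predicts.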