RAFT: Iterative Reward-Ranked Fine-Tuning Loop

Medium
Reinforcement Learning

Implement the RAFT (Reward rAnked FineTuning) alignment loop that iterates over T stages to progressively improve a generative model.

At each stage t, RAFT performs three steps:

Step 1 — Data Collection: A batch of b prompts is sampled, and K candidate responses are generated per prompt from the current model. (These are provided as input.)

Step 2 — Data Ranking: For each prompt, use the reward model to score all K candidates, then select the response with the highest reward: y* = argmax_j r(x, y_j). Collect these best responses into a batch B of size b.

Step 3 — Model Fine-Tuning: Compute the SFT (supervised fine-tuning) loss on batch B, defined as the mean negative log-likelihood over all tokens in the selected responses.
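For a single stage, Steps 2 and 3 reduce to an argmax over rewards followed by a mean of negative log probabilities. A minimal sketch (the variable names and the tiny one-prompt inputs are illustrative, not part of the problem):

```python
import numpy as np

# One stage with a single prompt and K = 2 candidate responses:
#   rewards   has shape [num_prompts][K]
#   log_probs has shape [num_prompts][K][num_tokens]
rewards = [[0.2, 0.4]]
log_probs = [[[-2.0, -1.8], [-1.5, -1.6]]]

selected_nlls = []  # negative log-likelihoods of tokens in the chosen responses
for i, prompt_rewards in enumerate(rewards):
    best_j = int(np.argmax(prompt_rewards))          # Step 2: highest-reward response
    selected_nlls.extend(-lp for lp in log_probs[i][best_j])

loss = float(np.mean(selected_nlls))                 # Step 3: mean NLL over all tokens
# loss == 1.55 here: response 1 wins, (1.5 + 1.6) / 2
```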

Given:

  • stages_log_probs: Token-level log probabilities across T stages. Shape: [T][num_prompts][K][num_tokens]. stages_log_probs[t][i][j][k] is the log prob of token k in response j for prompt i at stage t.
  • stages_rewards: Reward scores across T stages. Shape: [T][num_prompts][K]. stages_rewards[t][i][j] is the reward for response j of prompt i at stage t.

Return a list of T SFT losses (one per stage), each rounded to 4 decimal places. As the model improves across stages, the losses should generally decrease.

Examples

Example 1:
Input:
  stages_log_probs = [
      [[[-2.0, -1.8], [-1.5, -1.6]]],  # Stage 0: weak model
      [[[-0.6, -0.5], [-0.7, -0.8]]],  # Stage 1: improving
      [[[-0.2, -0.1], [-0.3, -0.4]]],  # Stage 2: strong model
  ]
  stages_rewards = [
      [[0.2, 0.4]],   # Stage 0
      [[0.75, 0.7]],  # Stage 1
      [[0.95, 0.8]],  # Stage 2
  ]
Output: [1.55, 0.55, 0.15]
Explanation:
  Stage 0: Response 1 has the higher reward (0.4). Its log probs are [-1.5, -1.6], so NLL = 1.5 + 1.6 = 3.1 and loss = 3.1 / 2 = 1.55.
  Stage 1: Response 0 has the higher reward (0.75). Its log probs are [-0.6, -0.5], so NLL = 0.6 + 0.5 = 1.1 and loss = 1.1 / 2 = 0.55.
  Stage 2: Response 0 has the higher reward (0.95). Its log probs are [-0.2, -0.1], so NLL = 0.2 + 0.1 = 0.3 and loss = 0.3 / 2 = 0.15.
  The decreasing losses [1.55, 0.55, 0.15] show the model improving across RAFT stages.

Starter Code

import numpy as np

def raft_alignment(
    stages_log_probs: list[list[list[list[float]]]],
    stages_rewards: list[list[list[float]]],
) -> list[float]:
    """
    Simulate the RAFT alignment loop over T stages.

    At each stage:
      Step 1: Candidates are provided (K responses per prompt).
      Step 2: For each prompt, select the response with the highest reward.
      Step 3: Compute SFT loss (mean NLL) on the selected responses.

    Args:
        stages_log_probs: Log probs across T stages.
            Shape: [T][num_prompts][K][num_tokens]
        stages_rewards: Rewards across T stages.
            Shape: [T][num_prompts][K]

    Returns:
        List of T SFT losses, each rounded to 4 decimal places.
    """
    pass
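One possible reference implementation of the loop described above (a straightforward sketch, not the only valid approach):

```python
import numpy as np

def raft_alignment(
    stages_log_probs: list[list[list[list[float]]]],
    stages_rewards: list[list[list[float]]],
) -> list[float]:
    losses = []
    for stage_lps, stage_rews in zip(stages_log_probs, stages_rewards):
        token_nlls = []  # NLL of every token in the selected responses for this stage
        for prompt_lps, prompt_rews in zip(stage_lps, stage_rews):
            best_j = int(np.argmax(prompt_rews))          # Step 2: pick y* = argmax_j r(x, y_j)
            token_nlls.extend(-lp for lp in prompt_lps[best_j])
        # Step 3: mean NLL over all tokens of all selected responses in the batch
        losses.append(round(float(np.mean(token_nlls)), 4))
    return losses
```

Note that the mean is taken over all tokens in the stage's batch (matching Step 3), not as a mean of per-prompt means; the two differ when selected responses have different lengths.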