Fine-Tune Model Weights with RLHF Policy Gradient

Medium
Reinforcement Learning

Implement a function that performs a single RLHF (Reinforcement Learning from Human Feedback) weight update using the policy gradient method.

Given:

  • weights: A 1D numpy array of current model weights
  • rewards: A 1D numpy array of rewards from a reward model for each sample
  • policy_log_probs: A 1D numpy array of log probabilities from the current policy
  • ref_log_probs: A 1D numpy array of log probabilities from the reference model
  • log_prob_grads: A 2D numpy array of shape (batch_size, num_weights) where each row is the gradient of that sample's log probability with respect to the weights
  • beta: KL penalty coefficient
  • lr: Learning rate

In RLHF, the goal is to update the model weights so that the policy maximizes reward from human feedback while staying close to a reference model. The KL divergence between the policy and the reference can be estimated per sample from the log-probability difference: kl_i = policy_log_probs[i] - ref_log_probs[i]. This estimate penalizes the reward, producing an adjusted reward adjusted_i = rewards[i] - beta * kl_i. The policy gradient is the mean over the batch of each sample's log-probability gradient weighted by its adjusted reward, and the weights are updated via gradient ascent: weights + lr * gradient.

Return the updated weights as a 1D numpy array.
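The steps above can be sketched directly in numpy. This is a minimal sketch, not the reference solution: it assumes the simple per-sample KL estimate kl_i = policy_log_probs[i] - ref_log_probs[i] and the adjusted reward rewards[i] - beta * kl_i described above, and the function name rlhf_update_sketch is illustrative.

```python
import numpy as np

def rlhf_update_sketch(weights, rewards, policy_log_probs,
                       ref_log_probs, log_prob_grads, beta, lr):
    # Per-sample KL estimate from the log-probability difference.
    kl = policy_log_probs - ref_log_probs
    # KL-penalized (adjusted) reward for each sample.
    adjusted = rewards - beta * kl
    # Weight each sample's log-prob gradient by its adjusted reward,
    # then average over the batch to get the policy gradient.
    grad = (adjusted[:, None] * log_prob_grads).mean(axis=0)
    # Gradient ascent step.
    return weights + lr * grad
```

On Example 1 below, kl is zero for both samples, so this sketch reduces to weights + lr * mean(rewards-weighted gradients) and returns [1.05, 0.95].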

Examples

Example 1:
Input:
  weights = np.array([1.0, 1.0])
  rewards = np.array([1.0, 1.0])
  policy_log_probs = np.array([-1.5, -1.5])
  ref_log_probs = np.array([-1.5, -1.5])
  log_prob_grads = np.array([[0.5, -0.5], [0.5, -0.5]])
  beta = 1.0
  lr = 0.1
Output: [1.05, 0.95]
Explanation: When the policy and reference log probs are identical, the KL divergence is zero, so the adjusted reward equals the raw reward (1.0 for both samples). The policy gradient becomes the mean of reward-weighted gradients: mean([1.0 * [0.5, -0.5], 1.0 * [0.5, -0.5]]) = [0.5, -0.5]. The weight update is then weights + 0.1 * [0.5, -0.5] = [1.05, 0.95].

Starter Code

import numpy as np

def rlhf_weight_update(
    weights: np.ndarray,
    rewards: np.ndarray,
    policy_log_probs: np.ndarray,
    ref_log_probs: np.ndarray,
    log_prob_grads: np.ndarray,
    beta: float,
    lr: float
) -> np.ndarray:
    """
    Perform a single RLHF policy gradient weight update.
    
    Args:
        weights: Current model weights, shape (num_weights,)
        rewards: Rewards from reward model, shape (batch_size,)
        policy_log_probs: Log probs from current policy, shape (batch_size,)
        ref_log_probs: Log probs from reference model, shape (batch_size,)
        log_prob_grads: Gradient of log probs w.r.t. weights, shape (batch_size, num_weights)
        beta: KL penalty coefficient
        lr: Learning rate
    
    Returns:
        Updated weights as a numpy array, shape (num_weights,)
    """
    pass
The AI Interview - Master AI/ML Interviews