Implement a function that performs a single RLHF (Reinforcement Learning from Human Feedback) weight update using the policy gradient method.
Given:
- weights: A 1D numpy array of current model weights
- rewards: A 1D numpy array of rewards from a reward model for each sample
- policy_log_probs: A 1D numpy array of log probabilities from the current policy
- ref_log_probs: A 1D numpy array of log probabilities from the reference model
- log_prob_grads: A 2D numpy array of shape (batch_size, num_weights) where each row is the gradient of that sample's log probability with respect to the weights
- beta: KL penalty coefficient
- lr: Learning rate
In RLHF, the goal is to update the model weights so that the policy maximizes reward from human feedback while staying close to a reference model. The KL divergence between the policy and reference can be estimated per sample from log probability differences. This KL estimate is used to penalize the reward, producing an adjusted reward signal. The policy gradient is then computed by weighting each sample's gradient by its adjusted reward, and the weights are updated via gradient ascent.
Return the updated weights as a 1D numpy array.
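The steps above can be sketched as a minimal NumPy implementation (a reference sketch of the described update, not necessarily the only valid solution; the function name `rlhf_update_sketch` is illustrative):

```python
import numpy as np

def rlhf_update_sketch(weights, rewards, policy_log_probs,
                       ref_log_probs, log_prob_grads, beta, lr):
    # Per-sample KL estimate: difference of policy and reference log probs.
    kl = policy_log_probs - ref_log_probs
    # Penalize each sample's reward by its KL estimate.
    adjusted = rewards - beta * kl
    # Policy gradient: mean over the batch of reward-weighted log-prob gradients.
    grad = (adjusted[:, None] * log_prob_grads).mean(axis=0)
    # Gradient ascent step on the weights.
    return weights + lr * grad
```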
Examples
Example 1:
Input:
weights = np.array([1.0, 1.0])
rewards = np.array([1.0, 1.0])
policy_log_probs = np.array([-1.5, -1.5])
ref_log_probs = np.array([-1.5, -1.5])
log_prob_grads = np.array([[0.5, -0.5], [0.5, -0.5]])
beta = 1.0
lr = 0.1
Output:
[1.05, 0.95]
Explanation: When the policy and reference log probs are identical, the KL divergence is zero, so the adjusted reward equals the raw reward (1.0 for both samples). The policy gradient becomes the mean of reward-weighted gradients: mean([1.0 * [0.5, -0.5], 1.0 * [0.5, -0.5]]) = [0.5, -0.5]. The weight update is then weights + 0.1 * [0.5, -0.5] = [1.05, 0.95].
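The arithmetic in Example 1 can be checked step by step, using the per-sample KL estimate `policy_log_probs - ref_log_probs` described above:

```python
import numpy as np

weights = np.array([1.0, 1.0])
rewards = np.array([1.0, 1.0])
policy_log_probs = np.array([-1.5, -1.5])
ref_log_probs = np.array([-1.5, -1.5])
log_prob_grads = np.array([[0.5, -0.5], [0.5, -0.5]])
beta, lr = 1.0, 0.1

kl = policy_log_probs - ref_log_probs                     # [0.0, 0.0]
adjusted = rewards - beta * kl                            # [1.0, 1.0]
grad = (adjusted[:, None] * log_prob_grads).mean(axis=0)  # [0.5, -0.5]
updated = weights + lr * grad                             # [1.05, 0.95]
```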
Starter Code
import numpy as np
def rlhf_weight_update(
weights: np.ndarray,
rewards: np.ndarray,
policy_log_probs: np.ndarray,
ref_log_probs: np.ndarray,
log_prob_grads: np.ndarray,
beta: float,
lr: float
) -> np.ndarray:
"""
Perform a single RLHF policy gradient weight update.
Args:
weights: Current model weights, shape (num_weights,)
rewards: Rewards from reward model, shape (batch_size,)
policy_log_probs: Log probs from current policy, shape (batch_size,)
ref_log_probs: Log probs from reference model, shape (batch_size,)
log_prob_grads: Gradient of log probs w.r.t. weights, shape (batch_size, num_weights)
beta: KL penalty coefficient
lr: Learning rate
Returns:
Updated weights as a numpy array, shape (num_weights,)
"""
    pass