Implement the unbiased KL divergence estimator used in GRPO (Group Relative Policy Optimization). This estimator computes the KL divergence between the current policy and a reference policy for each sample, which is then used as a regularization term to prevent the policy from deviating too far from the reference.
Examples
Example 1:
Input:
pi_theta = np.array([0.8]), pi_ref = np.array([0.4])
Output:
np.array([0.1931])
Explanation: ratio = pi_ref / pi_theta = 0.4 / 0.8 = 0.5. KL = 0.5 - log(0.5) - 1 = 0.5 - (-0.693) - 1 = 0.193. This penalizes the policy for assigning higher probability than the reference.
Starter Code
import numpy as np
def kl_divergence_estimator(pi_theta: np.ndarray, pi_ref: np.ndarray) -> np.ndarray:
"""
Compute the unbiased KL divergence estimator used in GRPO.
Formula: D_KL = (pi_ref / pi_theta) - log(pi_ref / pi_theta) - 1
Args:
pi_theta: Current policy probabilities for each sample
pi_ref: Reference policy probabilities for each sample
Returns:
Array of KL divergence estimates (one per sample)
"""
# Your code here
passPython3
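One way to fill in the starter code is a direct elementwise translation of the formula above; this is a sketch, not the official solution:

```python
import numpy as np

def kl_divergence_estimator(pi_theta: np.ndarray, pi_ref: np.ndarray) -> np.ndarray:
    # Per-sample probability ratio of the reference policy to the current policy
    ratio = pi_ref / pi_theta
    # Unbiased estimator: r - log(r) - 1, which is nonnegative
    # and equals zero exactly when the two policies agree
    return ratio - np.log(ratio) - 1

# Example 1 from above:
result = kl_divergence_estimator(np.array([0.8]), np.array([0.4]))
print(np.round(result, 4))  # [0.1931]
```

Because the operations are NumPy ufuncs, the same code handles arrays of any matching shape, producing one KL estimate per sample.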