Budget-Constrained RL Loss

Medium
Reinforcement Learning

Implement the budget-constrained reinforcement learning loss function from the Kimi K2 paper. In K2's RL training, a 'Budget Control' mechanism penalizes responses that exceed a token budget to improve inference efficiency. The base RL loss uses a squared advantage form with KL regularization. Your task is to implement the complete loss computation including the budget penalty.

Examples

Example 1:
Input: rewards = [[1.0, 0.5]], log_probs = [[-1.0, -1.5]], old_log_probs = [[-1.2, -1.3]], response_lengths = [[150, 80]], token_budget = 100, kl_coef = 0.1, budget_penalty_coef = 0.01
Output: 0.0004
Explanation: The first response (150 tokens) exceeds the budget (100) by 50, so its reward is reduced by the penalty 0.01 * 50 = 0.5. Adjusted rewards become [0.5, 0.5]. Baseline (per-prompt mean) = 0.5, so advantages = [0.0, 0.0]. KL terms = 0.1 * (log_probs - old_log_probs) = 0.1 * [0.2, -0.2] = [0.02, -0.02]. Loss = mean((0 - 0.02)², (0 - (-0.02))²) = 0.0004
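The arithmetic in the explanation above can be checked directly; this is a hand-computed sketch of the example, not a general solution:

```python
import numpy as np

# Budget penalty for the first response: 0.01 per token over the 100-token budget
penalty = 0.01 * max(150 - 100, 0)                  # 0.5

# Adjusted rewards, baseline, and advantages
adjusted = np.array([1.0 - penalty, 0.5])           # [0.5, 0.5]
advantages = adjusted - adjusted.mean()             # [0.0, 0.0]

# KL terms: tau * (log_probs - old_log_probs)
kl = 0.1 * (np.array([-1.0, -1.5]) - np.array([-1.2, -1.3]))  # [0.02, -0.02]

loss = float(np.mean((advantages - kl) ** 2))
print(round(loss, 4))  # 0.0004
```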

Starter Code

import numpy as np

def rl_budget_loss(
    rewards: np.ndarray,
    log_probs: np.ndarray,
    old_log_probs: np.ndarray,
    response_lengths: np.ndarray,
    token_budget: int,
    kl_coef: float,
    budget_penalty_coef: float
) -> float:
    """
    Compute the budget-constrained RL loss.
    
    The loss combines:
    1. Budget penalty for responses exceeding token_budget
    2. Advantage estimation (adjusted reward - baseline)
    3. KL regularization between current and old policy
    
    Loss formula: E[(advantage - kl_term)^2]
    
    Args:
        rewards: Shape (batch_size, K) - rewards for K samples per prompt
        log_probs: Shape (batch_size, K) - log π_θ(y|x) current policy
        old_log_probs: Shape (batch_size, K) - log π_old(y|x) old policy
        response_lengths: Shape (batch_size, K) - token lengths of responses
        token_budget: Maximum allowed tokens before penalty
        kl_coef: τ coefficient for KL regularization
        budget_penalty_coef: λ coefficient for budget penalty
        
    Returns:
        Scalar loss value (float)
    """
    # Your code here
    pass
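
One way to complete the starter code, following the three steps in the docstring. The per-prompt mean baseline and the `max(length - budget, 0)` penalty form are assumptions, chosen to be consistent with the worked example above:

```python
import numpy as np

def rl_budget_loss(rewards, log_probs, old_log_probs, response_lengths,
                   token_budget, kl_coef, budget_penalty_coef):
    rewards = np.asarray(rewards, dtype=float)
    log_probs = np.asarray(log_probs, dtype=float)
    old_log_probs = np.asarray(old_log_probs, dtype=float)
    response_lengths = np.asarray(response_lengths, dtype=float)

    # 1. Budget penalty: subtract lambda * (tokens over budget) from the reward
    excess = np.maximum(response_lengths - token_budget, 0.0)
    adjusted = rewards - budget_penalty_coef * excess

    # 2. Advantage: adjusted reward minus the per-prompt mean baseline over K samples
    baseline = adjusted.mean(axis=1, keepdims=True)
    advantage = adjusted - baseline

    # 3. KL regularization: tau * (log pi_theta - log pi_old)
    kl_term = kl_coef * (log_probs - old_log_probs)

    # Squared loss, averaged over all (batch, K) samples
    return float(np.mean((advantage - kl_term) ** 2))
```

On the example input this returns 0.0004, matching the expected output.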
The AI Interview - Master AI/ML Interviews