Implement the Direct Preference Optimization (DPO) loss function used in aligning large language models with human preferences.
DPO is a method that directly optimizes a language model policy using preference pairs (chosen vs rejected responses) without requiring a separate reward model. It leverages the log-probabilities of responses under both the current policy model and a frozen reference model.
Your function should take:
- log_probs_chosen_policy: Log-probabilities of preferred responses under the policy model
- log_probs_rejected_policy: Log-probabilities of dispreferred responses under the policy model
- log_probs_chosen_ref: Log-probabilities of preferred responses under the reference model
- log_probs_rejected_ref: Log-probabilities of dispreferred responses under the reference model
- beta: A temperature parameter controlling the strength of the KL constraint
Your function should return a dictionary with:
- 'loss': The average DPO loss across the batch, rounded to 4 decimal places
- 'chosen_rewards': A list of implicit reward values for the chosen responses, rounded to 4 decimal places
- 'rejected_rewards': A list of implicit reward values for the rejected responses, rounded to 4 decimal places
The implicit reward for a response is the log-ratio of its probability under the policy to its probability under the reference model, scaled by beta. The DPO loss encourages the policy to assign a higher implicit reward to the chosen response than to the rejected one in each pair.
All inputs are lists of floats with the same length (batch size).
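In symbols, the definitions above can be written as follows, where σ is the logistic sigmoid, y_w is the chosen response, y_l the rejected response, and B the batch size:

```
r(y) = \beta \bigl( \log \pi_\theta(y) - \log \pi_{\mathrm{ref}}(y) \bigr)

\mathcal{L}_{\mathrm{DPO}} = -\frac{1}{B} \sum_{i=1}^{B} \log \sigma\!\bigl( r(y_w^{(i)}) - r(y_l^{(i)}) \bigr)
```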
Example

Input:
log_probs_chosen_policy=[-2.0, -1.5], log_probs_rejected_policy=[-3.0, -2.5], log_probs_chosen_ref=[-2.5, -2.0], log_probs_rejected_ref=[-2.5, -2.0], beta=0.5

Output:
{'loss': 0.4741, 'chosen_rewards': [0.25, 0.25], 'rejected_rewards': [-0.25, -0.25]}

Starter Code
import numpy as np
def dpo_loss(log_probs_chosen_policy: list, log_probs_rejected_policy: list,
log_probs_chosen_ref: list, log_probs_rejected_ref: list,
beta: float) -> dict:
"""
Compute the Direct Preference Optimization (DPO) loss.
Args:
log_probs_chosen_policy: Log-probs of chosen responses under policy
log_probs_rejected_policy: Log-probs of rejected responses under policy
log_probs_chosen_ref: Log-probs of chosen responses under reference model
log_probs_rejected_ref: Log-probs of rejected responses under reference model
beta: Temperature parameter for KL constraint strength
Returns:
Dictionary with 'loss', 'chosen_rewards', and 'rejected_rewards'
"""
# Your code here
pass
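One way to fill in the starter code is sketched below. It computes each implicit reward as beta times the policy-minus-reference log-probability gap, then applies a numerically stable form of -log(sigmoid(x)). This is a reference sketch, not the only valid implementation; the stable identity -log σ(x) = log(1 + e^(-x)) = logaddexp(0, -x) is an implementation choice.

```python
import numpy as np


def dpo_loss(log_probs_chosen_policy: list, log_probs_rejected_policy: list,
             log_probs_chosen_ref: list, log_probs_rejected_ref: list,
             beta: float) -> dict:
    # Implicit rewards: beta-scaled log-ratio of policy to reference probability
    chosen_rewards = beta * (np.array(log_probs_chosen_policy)
                             - np.array(log_probs_chosen_ref))
    rejected_rewards = beta * (np.array(log_probs_rejected_policy)
                               - np.array(log_probs_rejected_ref))

    # Per-pair loss: -log(sigmoid(margin)), computed stably via logaddexp
    margins = chosen_rewards - rejected_rewards
    losses = np.logaddexp(0.0, -margins)

    return {
        'loss': round(float(losses.mean()), 4),
        'chosen_rewards': [round(float(r), 4) for r in chosen_rewards],
        'rejected_rewards': [round(float(r), 4) for r in rejected_rewards],
    }
```

On the example above, each margin is 0.5, so every per-pair loss is log(1 + e^(-0.5)) ≈ 0.4741, matching the expected output.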