MMLU Log-Probability Scoring

Medium
LLM

Implement a function that scores multiple-choice question answering with log-probabilities, the approach commonly used when evaluating models on the MMLU (Massive Multitask Language Understanding) benchmark.

Given a set of multiple-choice questions in which a language model has assigned a log-probability to every answer choice, your task is to:

  1. Determine the predicted answer for each question (the choice with the highest log-probability)
  2. Calculate the overall accuracy (proportion of questions where the prediction matches the correct answer)
  3. Compute the average probability assigned to the correct answer across all questions (first convert each question's log-probabilities to probabilities with a softmax over its answer choices, so that they sum to 1)
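
A minimal NumPy sketch of steps 1 and 2, assuming every question has the same number of choices so the log-probabilities fit in one rectangular array (the variable names are illustrative):

```python
import numpy as np

# Example 1's inputs as a (questions x choices) array.
log_probs = np.array([[-1.0, -2.0, -3.0, -4.0],
                      [-2.0, -1.0, -3.0, -4.0]])
correct_answers = np.array([0, 1])

# Step 1: the predicted choice is the column with the largest log-probability.
# np.argmax returns the first such index if several choices tie.
predictions = np.argmax(log_probs, axis=1)

# Step 2: accuracy is the fraction of questions whose prediction matches the label.
accuracy = float(np.mean(predictions == correct_answers))

print(predictions.tolist(), accuracy)  # [0, 1] 1.0
```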

The function should take:

  • log_probs: A list of lists; the i-th inner list holds the log-probability of each answer choice for question i
  • correct_answers: A list of integers representing the correct answer index (0-indexed) for each question
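
The task does not say how malformed input should be handled. One possible approach, sketched here with a hypothetical helper `check_inputs`, is to reject mismatched lengths and out-of-range answer indices before scoring. Rejecting an empty question list is an extra assumption made to avoid dividing by zero when averaging.

```python
def check_inputs(log_probs: list, correct_answers: list) -> None:
    """Hypothetical validation helper; raises ValueError on malformed input."""
    if len(log_probs) != len(correct_answers):
        raise ValueError("log_probs and correct_answers must have the same length")
    if len(log_probs) == 0:
        # Assumption: at least one question is required.
        raise ValueError("at least one question is required")
    for q, (row, answer) in enumerate(zip(log_probs, correct_answers)):
        if len(row) == 0:
            raise ValueError(f"question {q} has no answer choices")
        if not 0 <= answer < len(row):
            raise ValueError(f"question {q}: correct index {answer} is out of range")


check_inputs([[-1.0, -2.0, -3.0, -4.0]], [0])  # passes silently
```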

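Return a dictionary with the keys 'accuracy' (float), 'predictions' (list of predicted choice indices, one per question), and 'avg_correct_prob' (float; Example 1 shows it rounded to four decimal places).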

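Note: When converting log-probabilities to probabilities, keep the computation numerically stable: subtract each question's maximum log-probability before exponentiating (the log-sum-exp trick), so that very negative values do not underflow to zero.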
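
A sketch of the shift-by-max softmax for a single question; `stable_softmax` is an illustrative name, not part of the starter code. Subtracting the maximum does not change the result, because the same factor cancels in numerator and denominator.

```python
import numpy as np


def stable_softmax(log_probs_row) -> np.ndarray:
    """Convert one question's log-probabilities into probabilities that sum to 1."""
    x = np.asarray(log_probs_row, dtype=float)
    # After the shift the largest exponent is exp(0) = 1, so the denominator
    # is at least 1 and cannot underflow to 0.
    shifted = x - x.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


# With log-probs this negative, a naive np.exp(x) / np.exp(x).sum() gives 0/0 = nan.
row = [-1000.0, -1001.0, -1002.0, -1003.0]
print(stable_softmax(row))  # approximately 0.6439, 0.2369, 0.0871, 0.0321
```

If SciPy is available, `scipy.special.softmax` provides an equivalent routine that is already stabilized.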

Examples

Example 1:
Input: log_probs = [[-1.0, -2.0, -3.0, -4.0], [-2.0, -1.0, -3.0, -4.0]], correct_answers = [0, 1]
Output: {'accuracy': 1.0, 'predictions': [0, 1], 'avg_correct_prob': 0.6439}
Explanation: For question 1, choice 0 has the highest log-prob (-1.0), matching the correct answer 0. For question 2, choice 1 has the highest log-prob (-1.0), matching correct answer 1. Both predictions are correct, so accuracy is 1.0. Converting log-probs to probabilities via softmax and averaging the probability of correct answers gives 0.6439.
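
Because softmax is unchanged by adding the same constant to every log-probability, the probability of question 1's correct choice works out as:

```latex
p_{1,0} = \frac{e^{-1}}{e^{-1} + e^{-2} + e^{-3} + e^{-4}}
        = \frac{1}{1 + e^{-1} + e^{-2} + e^{-3}}
        \approx \frac{1}{1.5530}
        \approx 0.6439
```

Question 2 contains the same four values with its correct choice again holding the largest one, so it receives the same probability, and the average is approximately 0.6439.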

Starter Code

import numpy as np

def mmlu_log_prob_score(log_probs: list, correct_answers: list) -> dict:
    """
    Compute MMLU-style log-probability scoring metrics.
    
    Args:
        log_probs: List of lists, where each inner list contains 
                   log-probabilities for each answer choice
        correct_answers: List of correct answer indices (0-indexed)
    
    Returns:
        Dictionary with 'accuracy', 'predictions', and 'avg_correct_prob'
    """
    pass
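
One possible way to fill in the starter code, sketched under two assumptions the task does not state: at least one question is provided, and 'avg_correct_prob' is rounded to four decimal places to match Example 1. It loops over questions instead of building one array, so inner lists of different lengths also work.

```python
import numpy as np


def mmlu_log_prob_score(log_probs: list, correct_answers: list) -> dict:
    """Sketch: MMLU-style scoring from per-choice log-probabilities."""
    predictions = []
    correct_probs = []
    for row, answer in zip(log_probs, correct_answers):
        x = np.asarray(row, dtype=float)
        # Prediction: index of the highest log-probability (first index on ties).
        predictions.append(int(np.argmax(x)))
        # Numerically stable softmax over this question's choices.
        exp = np.exp(x - x.max())
        probs = exp / exp.sum()
        correct_probs.append(float(probs[answer]))

    # Assumes len(correct_answers) > 0.
    n = len(correct_answers)
    accuracy = sum(p == a for p, a in zip(predictions, correct_answers)) / n
    return {
        "accuracy": float(accuracy),
        "predictions": predictions,
        # Rounding to four decimals mirrors Example 1's output (assumption).
        "avg_correct_prob": round(float(np.mean(correct_probs)), 4),
    }
```

Calling it on Example 1's inputs should reproduce the output shown there.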
The AI Interview - Master AI/ML Interviews