Implement a function that evaluates multiple-choice question answering with log-probability scoring, as is common in MMLU (Massive Multitask Language Understanding) benchmark evaluation.
Given a set of multiple-choice questions, where a language model has assigned a log-probability to each answer choice of each question, your task is to:
- Determine the predicted answer for each question (the choice with the highest log-probability)
- Calculate the overall accuracy (proportion of questions where the prediction matches the correct answer)
- Compute the average probability assigned to the correct answer across all questions, first converting each question's log-probabilities to probabilities with a softmax over its answer choices
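The three quantities can be written out as formulas. Here l_ij is the log-probability of choice j in question i, y_i is the correct index, and N is the number of questions. This notation is introduced only for illustration and is not part of the task's API:

```latex
\begin{aligned}
\hat{y}_i &= \arg\max_{j} \ell_{ij} \\
\text{accuracy} &= \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\left[\hat{y}_i = y_i\right] \\
p_{ij} &= \frac{\exp(\ell_{ij})}{\sum_{k} \exp(\ell_{ik})} \\
\text{avg\_correct\_prob} &= \frac{1}{N} \sum_{i=1}^{N} p_{i,\,y_i}
\end{aligned}
```

The softmax in the third line is the conversion Example 1 uses. Adding the same constant to every l_ij of a question leaves p_ij unchanged, so subtracting the row maximum before exponentiating is safe.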
The function should take:
- log_probs: a list of lists, where each inner list contains the log-probabilities of one question's answer choices
- correct_answers: a list of integers giving the correct answer index (0-indexed) for each question
Return a dictionary with keys 'accuracy', 'predictions', and 'avg_correct_prob'.
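Note: Convert log-probabilities to probabilities in a numerically stable way. For example, subtract each question's maximum log-probability before exponentiating, so that very negative values do not underflow to zero.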
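A minimal stable softmax for one question's choices is sketched below. The name `stable_softmax` is hypothetical and not part of the task's API. The function subtracts the maximum log-probability before exponentiating, which is the standard max-shift trick. The result is unchanged, but the largest shifted term becomes exp(0) = 1, so the denominator can never underflow to 0.

```python
import numpy as np


def stable_softmax(log_probs: np.ndarray) -> np.ndarray:
    """Normalize one question's log-probabilities into probabilities summing to 1."""
    # Shift so the largest entry is 0; exp() then cannot overflow, and the
    # denominator is at least 1, so it cannot underflow to zero either.
    shifted = log_probs - np.max(log_probs)
    exp_vals = np.exp(shifted)
    return exp_vals / exp_vals.sum()


# A naive np.exp() on these values underflows to all zeros, giving 0/0 = nan.
scores = np.array([-1000.0, -1001.0, -1002.0])
print(stable_softmax(scores))
```

If SciPy is already a dependency, `scipy.special.softmax` performs an equivalent stable computation.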
Examples
Example 1:
Input:
log_probs = [[-1.0, -2.0, -3.0, -4.0], [-2.0, -1.0, -3.0, -4.0]], correct_answers = [0, 1]
Output:
{'accuracy': 1.0, 'predictions': [0, 1], 'avg_correct_prob': 0.6439}
Explanation: For question 1, choice 0 has the highest log-prob (-1.0), matching the correct answer 0. For question 2, choice 1 has the highest log-prob (-1.0), matching the correct answer 1. Both predictions are correct, so accuracy is 1.0. Applying a softmax to each question's log-probs gives the correct answer a probability of about 0.6439 in both questions, so the average is 0.6439.
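The 0.6439 in Example 1 can be checked by hand. Both questions contain the same four log-probabilities in a different order, and in each question the correct answer holds the largest one (-1.0). Both correct-answer probabilities therefore equal:

```latex
\frac{e^{-1}}{e^{-1} + e^{-2} + e^{-3} + e^{-4}}
= \frac{1}{1 + e^{-1} + e^{-2} + e^{-3}}
\approx \frac{1}{1.5530}
\approx 0.6439
```

Their average is therefore also 0.6439. Exponentiating without normalizing would give e^{-1} ≈ 0.3679 instead, which is why the softmax step matters.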
Starter Code
import numpy as np


def mmlu_log_prob_score(log_probs: list, correct_answers: list) -> dict:
    """
    Compute MMLU-style log-probability scoring metrics.

    Args:
        log_probs: List of lists, where each inner list contains
            log-probabilities for each answer choice
        correct_answers: List of correct answer indices (0-indexed)

    Returns:
        Dictionary with 'accuracy', 'predictions', and 'avg_correct_prob'
    """
    pass
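The starter code can be completed roughly as follows. This is a minimal sketch and not a reference solution. Several choices in it are assumptions that the task does not state: the input-validation errors, per-row processing so questions may have different numbers of choices, lowest-index tie-breaking, and rounding to 4 decimals to match Example 1's output format. Hidden test cases may expect different handling.

```python
import numpy as np


def mmlu_log_prob_score(log_probs: list, correct_answers: list) -> dict:
    """
    Compute MMLU-style log-probability scoring metrics.

    Sketch assumptions (not stated in the task):
      - there is at least one question, and both lists have the same length;
      - rows are processed one at a time, so questions may have different
        numbers of choices;
      - ties go to the lowest index (np.argmax returns the first maximum);
      - both metrics are rounded to 4 decimals, matching Example 1's format.
    """
    if len(log_probs) != len(correct_answers):
        raise ValueError("log_probs and correct_answers must have the same length")
    if not log_probs:
        raise ValueError("at least one question is required")

    predictions = []
    correct_probs = []
    for scores, answer in zip(log_probs, correct_answers):
        row = np.asarray(scores, dtype=float)

        # Predicted answer: index of the highest log-probability.
        # int() turns np.int64 into a plain int so the output prints cleanly.
        predictions.append(int(np.argmax(row)))

        # Numerically stable softmax: shift by the row max before exp().
        exp_row = np.exp(row - row.max())
        probs = exp_row / exp_row.sum()
        correct_probs.append(float(probs[answer]))

    num_correct = sum(p == a for p, a in zip(predictions, correct_answers))
    accuracy = num_correct / len(correct_answers)
    avg_correct_prob = sum(correct_probs) / len(correct_probs)

    return {
        "accuracy": round(accuracy, 4),
        "predictions": predictions,
        "avg_correct_prob": round(avg_correct_prob, 4),
    }


result = mmlu_log_prob_score(
    [[-1.0, -2.0, -3.0, -4.0], [-2.0, -1.0, -3.0, -4.0]], [0, 1]
)
print(result)  # {'accuracy': 1.0, 'predictions': [0, 1], 'avg_correct_prob': 0.6439}
```

Processing one row at a time costs a little speed compared with a single 2-D NumPy call. In exchange, it avoids assuming that every question has the same number of choices.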