Implement a function that applies top-p (nucleus) sampling to a set of logits. Top-p sampling is a popular decoding strategy in language models that dynamically selects the smallest possible set of tokens whose cumulative probability mass meets or exceeds a given threshold p.
Given a list of raw logits and a probability threshold p (0 < p <= 1), your function should:
- Convert the logits into a valid probability distribution
- Identify the nucleus: the minimal set of highest-probability tokens whose combined probability is at least p
- Set the probabilities of all tokens outside the nucleus to zero
- Renormalize the remaining probabilities so they sum to 1
Return the resulting filtered and renormalized probability distribution as a list of floats, each rounded to 4 decimal places.
Note: When multiple tokens have identical probabilities, break ties by preferring the token with the smaller index.
Examples
Example 1:
Input:
logits = [1.0, 2.0, 3.0], p = 0.8
Output:
[0.0, 0.2689, 0.7311]
Explanation: First, compute softmax probabilities: [0.0900, 0.2447, 0.6652]. Sort by descending probability: token 2 (0.6652), token 1 (0.2447), token 0 (0.0900). Compute cumulative sums: [0.6652, 0.9099, 1.0]. The cumulative probability first reaches p = 0.8 at the second token (0.9099 >= 0.8), so the nucleus contains tokens {2, 1}. Zero out token 0 and renormalize: [0.0, 0.2447/0.9099, 0.6652/0.9099] = [0.0, 0.2689, 0.7311].
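The softmax and renormalization numbers in the explanation above can be checked with a short snippet (a sketch assuming NumPy, as in the starter code; variable names are illustrative):

```python
import numpy as np

# Reproduce the numbers from Example 1.
logits = np.array([1.0, 2.0, 3.0])
shifted = logits - logits.max()              # shift by the max for numerical stability
probs = np.exp(shifted) / np.exp(shifted).sum()
print(np.round(probs, 4))                    # softmax: 0.09, 0.2447, 0.6652

# The nucleus for p = 0.8 is tokens {2, 1}; renormalize over just those two.
nucleus = probs[[1, 2]] / probs[[1, 2]].sum()
print(np.round(nucleus, 4))                  # 0.2689, 0.7311
```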
Example 2:
Input:
logits = [0.0, 0.0], p = 0.5
Output:
[1.0, 0.0]
Explanation: Softmax gives [0.5, 0.5]. The two tokens tie, so the tie-break prefers token 0. The cumulative sum reaches p = 0.5 after token 0 alone, so the nucleus is {0}; zero out token 1 and renormalize to get [1.0, 0.0].
Starter Code
import numpy as np

def top_p_sampling(logits: list[float], p: float) -> list[float]:
    """
    Apply top-p (nucleus) sampling to filter a probability distribution.

    Args:
        logits: Raw unnormalized scores for each token
        p: Cumulative probability threshold (0 < p <= 1)

    Returns:
        Filtered and renormalized probability distribution as a list of
        floats, each rounded to 4 decimal places
    """
    pass
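One way to fill in the starter code is the sketch below (not the official solution). It uses a stable descending argsort so that equal probabilities keep their original order, which implements the smaller-index tie-break, and `np.searchsorted` to find the first position where the cumulative mass reaches p:

```python
import numpy as np

def top_p_sampling(logits: list[float], p: float) -> list[float]:
    # Softmax with the usual max-shift for numerical stability.
    shifted = np.array(logits, dtype=float) - np.max(logits)
    probs = np.exp(shifted)
    probs /= probs.sum()

    # Stable sort on the negated probabilities: descending order, and
    # ties keep their original (smaller-index-first) order.
    order = np.argsort(-probs, kind="stable")
    cumulative = np.cumsum(probs[order])

    # Index of the first cumulative value >= p; include that token.
    # If floating-point error keeps the sum just below p (e.g. p = 1.0),
    # searchsorted returns len(cumulative) and the slice below keeps
    # the whole vocabulary.
    cutoff = int(np.searchsorted(cumulative, p, side="left")) + 1
    nucleus = order[:cutoff]

    # Zero out everything outside the nucleus, then renormalize.
    filtered = np.zeros_like(probs)
    filtered[nucleus] = probs[nucleus]
    filtered /= filtered.sum()
    return [float(round(x, 4)) for x in filtered]
```

For instance, `top_p_sampling([1.0, 2.0, 3.0], 0.8)` reproduces Example 1's output `[0.0, 0.2689, 0.7311]`.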