Top-p (Nucleus) Sampling

Medium
LLM

Implement a function that applies top-p (nucleus) sampling to a set of logits. Top-p sampling is a popular decoding strategy in language models that dynamically selects the smallest possible set of tokens whose cumulative probability mass meets or exceeds a given threshold p.

Given a list of raw logits and a probability threshold p with 0 < p <= 1, your function should:

  1. Convert the logits into a valid probability distribution using softmax
  2. Identify the nucleus: the minimal set of highest-probability tokens whose combined probability is at least p
  3. Set the probabilities of all tokens outside the nucleus to zero
  4. Renormalize the remaining probabilities so they sum to 1

Return the resulting filtered and renormalized probability distribution as a list of floats, each rounded to 4 decimal places.
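Step 1 can be sketched as a numerically stable softmax. This is a minimal illustration using only the standard library; the helper name `softmax` is an assumption and is not part of the starter code. Subtracting the maximum logit before exponentiating leaves the result unchanged but avoids overflow when logits are large.

```python
import math

def softmax(logits: list[float]) -> list[float]:
    """Hypothetical helper: convert raw logits into probabilities."""
    # Shift by the max logit so every exponent is <= 0 and exp() cannot overflow.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([1.0, 2.0, 3.0])
print([round(q, 4) for q in probs])  # [0.09, 0.2447, 0.6652]
```

The printed values match the softmax probabilities used in Example 1.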

Note: When multiple tokens have identical probabilities, break ties by preferring the token with the smaller index, so that it is added to the nucleus first.
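One way to read the tie-breaking rule is to sort token indices by descending probability and, among equal probabilities, by ascending index. The sketch below assumes that reading; `nucleus_indices` is a hypothetical helper name, and it takes probabilities (not logits) as input.

```python
def nucleus_indices(probs: list[float], p: float) -> list[int]:
    """Hypothetical helper: indices in the nucleus, in the order they were added."""
    # Sort by probability (descending), then by index (ascending) to break ties.
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    nucleus, cumulative = [], 0.0
    for i in order:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= p:  # stop as soon as the threshold is met
            break
    return nucleus

# All four tokens are tied, so the smaller indices are chosen first.
print(nucleus_indices([0.25, 0.25, 0.25, 0.25], 0.5))  # [0, 1]
```

If rounding error keeps the cumulative sum just below p = 1, the loop simply runs to the end and keeps every token, which is the intended behavior for that threshold.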

Examples

Example 1:
Input: logits = [1.0, 2.0, 3.0], p = 0.8
Output: [0.0, 0.2689, 0.7311]
Explanation: First, compute softmax probabilities: [0.0900, 0.2447, 0.6652]. Sort by descending probability: token 2 (0.6652), token 1 (0.2447), token 0 (0.0900). Compute cumulative sums: [0.6652, 0.9099, 1.0]. The cumulative probability first reaches p=0.8 at the second token (0.9099 >= 0.8), so the nucleus contains tokens {2, 1}. Zero out token 0 and renormalize: [0.0, 0.2447/0.9099, 0.6652/0.9099] = [0.0, 0.2689, 0.7311].

Starter Code

import numpy as np

def top_p_sampling(logits: list[float], p: float) -> list[float]:
    """
    Apply top-p (nucleus) sampling to filter a probability distribution.
    
    Args:
        logits: Raw unnormalized scores for each token
        p: Cumulative probability threshold (0 < p <= 1)
    
    Returns:
        Filtered and renormalized probability distribution as a list of floats,
        each rounded to 4 decimal places
    """
    pass
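
One possible way to complete the starter is sketched below. It is not an official solution: it uses only the standard library instead of NumPy so it runs without extra dependencies, and it assumes the tie-breaking rule means "smaller index enters the nucleus first".

```python
import math

def top_p_sampling(logits: list[float], p: float) -> list[float]:
    # Step 1: numerically stable softmax.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Step 2: add tokens in descending probability (ties: smaller index first)
    # until the cumulative mass reaches p.
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    keep, mass = set(), 0.0
    for i in order:
        keep.add(i)
        mass += probs[i]
        if mass >= p:
            break

    # Steps 3 and 4: zero out everything outside the nucleus, renormalize, round.
    return [round(probs[i] / mass, 4) if i in keep else 0.0
            for i in range(len(probs))]

print(top_p_sampling([1.0, 2.0, 3.0], 0.8))  # [0.0, 0.2689, 0.7311]
```

The call reproduces Example 1. With p = 1.0 the loop keeps every token, so the output is just the rounded softmax distribution.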