Implement a function to monitor changes in model prediction distributions between a reference (baseline) period and a current period. This is a critical MLOps task for detecting model drift in production.
Given two lists of prediction scores (probabilities between 0 and 1), compute the following monitoring metrics:
- Mean Shift: The difference between the mean of current predictions and the mean of reference predictions
- Standard Deviation Ratio: The ratio of current standard deviation to reference standard deviation
- Jensen-Shannon Divergence: A symmetric measure of how much two distributions differ, computed here from their histograms
- Drift Detected: A boolean flag indicating if JS divergence exceeds 0.1 (significant drift threshold)
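The two moment-based metrics above can be sketched as follows. This is a minimal illustration, not a reference solution: the helper name summary_shift is hypothetical, the population standard deviation (ddof=0) is assumed because it matches the std = 0.1414 quoted in Example 1, and the handling of a zero-spread reference is an assumption the statement does not cover.

```python
import numpy as np

def summary_shift(reference_preds, current_preds):
    """Return (mean_shift, std_ratio) for two lists of prediction scores."""
    ref = np.asarray(reference_preds, dtype=float)
    cur = np.asarray(current_preds, dtype=float)

    # Mean shift: current mean minus reference mean.
    mean_shift = cur.mean() - ref.mean()

    # Assumption: population standard deviation (NumPy's default, ddof=0).
    ref_std = ref.std()
    cur_std = cur.std()

    # Assumption: the statement does not define the ratio when the reference
    # has zero spread. This sketch returns 1.0 if both are constant, else inf.
    if ref_std == 0:
        std_ratio = 1.0 if cur_std == 0 else float("inf")
    else:
        std_ratio = cur_std / ref_std

    return float(mean_shift), float(std_ratio)
```

A positive mean shift means the current scores are higher on average than the reference scores; a ratio above 1 means they are more spread out.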
For the Jensen-Shannon divergence calculation:
- Create histograms using n_bins equally spaced bins between 0 and 1
- Apply Laplace smoothing to handle empty bins: P(bin) = (count + 1) / (total + n_bins)
- Compute JS divergence as the average of the two KL divergences KL(P || M) and KL(Q || M), where M = (P + Q) / 2 is the mixture of the two distributions
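The three steps above can be sketched as below. The helper names are hypothetical, and two details are assumptions because the statement does not fix them: np.histogram is used for binning (values that fall exactly on a bin edge may land in either neighbouring bin because of floating-point rounding), and the natural logarithm is used for the KL terms. Absolute values depend on both choices, so this sketch is not guaranteed to reproduce the figures shown in the examples.

```python
import numpy as np

def smoothed_histogram(preds, n_bins):
    """Laplace-smoothed bin probabilities over [0, 1]."""
    counts, _ = np.histogram(preds, bins=n_bins, range=(0.0, 1.0))
    # P(bin) = (count + 1) / (total + n_bins). Using counts.sum() as the
    # total keeps the result normalised even if some values fall outside
    # [0, 1], since np.histogram ignores those.
    return (counts + 1) / (counts.sum() + n_bins)

def kl_divergence(p, q):
    """KL(p || q) in nats; smoothing guarantees no zero entries."""
    return float(np.sum(p * np.log(p / q)))

def js_divergence(reference_preds, current_preds, n_bins=10):
    """Average of the KL divergences from each distribution to their mixture."""
    p = smoothed_histogram(reference_preds, n_bins)
    q = smoothed_histogram(current_preds, n_bins)
    m = 0.5 * (p + q)  # mixture distribution
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
```

With natural logarithms the result lies between 0 (identical histograms) and ln 2, and swapping the two inputs does not change it.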
Return a dictionary with keys 'mean_shift', 'std_ratio', 'js_divergence', and 'drift_detected'.
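One way to assemble the returned dictionary is sketched below. The helper build_report and its ndigits parameter are hypothetical; rounding to four decimal places is an assumption inferred from the precision shown in the example output, and the casts to float and bool are there only so that NumPy scalar types do not leak into the result.

```python
def build_report(mean_shift, std_ratio, js_divergence, threshold=0.1, ndigits=4):
    """Package already-computed metrics into the expected dictionary."""
    return {
        "mean_shift": round(float(mean_shift), ndigits),
        "std_ratio": round(float(std_ratio), ndigits),
        "js_divergence": round(float(js_divergence), ndigits),
        # Compare the unrounded value so rounding cannot flip the flag.
        # Drift is flagged only when the divergence strictly exceeds 0.1.
        "drift_detected": bool(js_divergence > threshold),
    }

# Using the metric values quoted in Example 1:
print(build_report(0.4, 1.0, 0.2939))
# → {'mean_shift': 0.4, 'std_ratio': 1.0, 'js_divergence': 0.2939, 'drift_detected': True}
```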
Examples
Example 1:
Input:
reference_preds = [0.1, 0.2, 0.3, 0.4, 0.5]
current_preds = [0.5, 0.6, 0.7, 0.8, 0.9]
n_bins = 5
Output:
{'mean_shift': 0.4, 'std_ratio': 1.0, 'js_divergence': 0.2939, 'drift_detected': True}
Explanation: The reference predictions have mean 0.3 and the current predictions have mean 0.7, giving a mean_shift of 0.4. Both have the same spread (std = 0.1414), so std_ratio = 1.0. The JS divergence of 0.2939 exceeds the 0.1 threshold, indicating significant distribution drift.
Starter Code
import numpy as np

def monitor_prediction_distribution(reference_preds: list, current_preds: list, n_bins: int = 10) -> dict:
    """
    Monitor prediction distribution changes between reference and current predictions.

    Args:
        reference_preds: List of reference prediction scores (floats between 0 and 1)
        current_preds: List of current prediction scores (floats between 0 and 1)
        n_bins: Number of bins for histogram comparison

    Returns:
        Dictionary with keys: 'mean_shift', 'std_ratio', 'js_divergence', 'drift_detected'
    """
    # Your code here
    pass