In production ML systems, detecting when an input feature's distribution shifts between training and production (feature drift) is crucial for maintaining model performance. The Population Stability Index (PSI) is a widely used MLOps metric for quantifying such shifts.
Write a function detect_feature_drift(reference_data, production_data, num_bins) that:
- Takes a reference distribution (e.g., training data feature values) and a production distribution (current incoming data)
- Computes the PSI to measure how much the production distribution has shifted from the reference
- Returns a dictionary with the PSI value and drift assessment
The function should return a dictionary containing:
- psi: The calculated Population Stability Index (rounded to 4 decimal places)
- drift_detected: Boolean indicating if drift is detected (PSI >= 0.1)
- drift_level: One of 'none' (PSI < 0.1), 'moderate' (0.1 <= PSI < 0.25), or 'significant' (PSI >= 0.25)
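The threshold rules above map a PSI value to the two assessment fields. A minimal sketch, using a hypothetical helper name `classify_drift` (not part of the starter code), with each boundary value assigned to the higher band as the inclusive `>=` conditions require:

```python
from typing import Tuple


def classify_drift(psi: float) -> Tuple[bool, str]:
    """Hypothetical helper: map a PSI value to (drift_detected, drift_level)."""
    # Thresholds follow the problem statement; each boundary belongs to the higher band.
    if psi < 0.1:
        return False, "none"
    if psi < 0.25:
        return True, "moderate"
    return True, "significant"


for value in (0.05, 0.1, 0.1871, 0.25):
    print(value, classify_drift(value))
```

Note that `drift_detected` is simply `drift_level != 'none'`, so the two fields can never disagree.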
If either input list is empty, return an empty dictionary.
Note: When computing bin proportions, replace any zero proportion with a small epsilon (0.0001) so that the division and logarithm in the PSI formula stay finite.
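The epsilon substitution can be done in one vectorized step. This sketch assumes NumPy and uses a hypothetical helper name, `safe_proportions`; it divides by the total count, which equals the sample size only when every value falls inside some bin.

```python
import numpy as np

EPSILON = 0.0001  # value given in the problem statement


def safe_proportions(counts: np.ndarray) -> np.ndarray:
    """Hypothetical helper: bin counts -> proportions, with zeros replaced by EPSILON."""
    proportions = counts / counts.sum()
    # np.where keeps non-zero proportions untouched and substitutes EPSILON elsewhere,
    # so a later ln(prod / ref) never sees a zero numerator or denominator.
    return np.where(proportions == 0, EPSILON, proportions)


print(safe_proportions(np.array([4, 0, 6])))
```

After substitution the proportions no longer sum exactly to 1; the statement only asks for replacement, so this sketch does not renormalize.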
Examples
Example 1:
Input:
reference_data = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], production_data = [3, 3, 4, 4, 5, 5, 6, 6, 7, 7], num_bins = 5
Output:
{'psi': 0.1871, 'drift_detected': True, 'drift_level': 'moderate'}
Explanation: The reference data is concentrated in the range [1, 5], while the production data has shifted to [3, 7]. Using 5 bins, we compute the proportion of data in each bin for both distributions. The PSI formula sums (prod_pct - ref_pct) * ln(prod_pct / ref_pct) across all bins. The resulting PSI of 0.1871 falls in the moderate drift range (0.1-0.25), indicating that the distribution has shifted enough to warrant monitoring.
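The per-bin formula in the explanation can be written compactly. Here B = num_bins, and r_i and p_i (notation introduced for illustration) are the reference and production proportions in bin i; the worked term shows the contribution of a single bin holding 20% of reference values and 40% of production values:

```latex
\mathrm{PSI} = \sum_{i=1}^{B} \left(p_i - r_i\right)\ln\frac{p_i}{r_i},
\qquad
r_i = 0.2,\ p_i = 0.4:\quad (0.4 - 0.2)\ln 2 \approx 0.2 \times 0.6931 \approx 0.1386
```

Because p_i - r_i and ln(p_i / r_i) always share the same sign, every term is non-negative, so the PSI is zero only when the two sets of proportions match bin for bin.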
Starter Code
import numpy as np

def detect_feature_drift(reference_data: list, production_data: list, num_bins: int = 10) -> dict:
    """
    Detect feature drift using Population Stability Index (PSI).

    Args:
        reference_data: List of feature values from reference distribution (e.g., training)
        production_data: List of feature values from production distribution
        num_bins: Number of bins for histogram comparison

    Returns:
        dict with 'psi', 'drift_detected', and 'drift_level'
    """
    pass
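A minimal sketch of one possible way to fill in the starter function. The statement does not specify the binning strategy, so this version assumes equal-width bins spanning the combined minimum and maximum of both samples, with counts divided by each sample's length; the constant-data widening step is also an assumption added here to keep the bin edges strictly increasing.

```python
import numpy as np

EPSILON = 0.0001  # replaces zero bin proportions, as the problem statement specifies


def detect_feature_drift(reference_data: list, production_data: list, num_bins: int = 10) -> dict:
    """Detect feature drift using the Population Stability Index (PSI).

    Assumption: equal-width bins spanning the combined min/max of both samples.
    """
    if len(reference_data) == 0 or len(production_data) == 0:
        return {}

    ref = np.asarray(reference_data, dtype=float)
    prod = np.asarray(production_data, dtype=float)

    # Shared bin edges so both samples are compared bin for bin.
    low = min(ref.min(), prod.min())
    high = max(ref.max(), prod.max())
    if low == high:
        # Assumption: widen a zero-width range so the edges strictly increase.
        low, high = low - 0.5, high + 0.5
    edges = np.linspace(low, high, num_bins + 1)

    ref_counts, _ = np.histogram(ref, bins=edges)
    prod_counts, _ = np.histogram(prod, bins=edges)

    # Proportion of each sample per bin; zero proportions become EPSILON.
    ref_pct = np.where(ref_counts == 0, EPSILON, ref_counts / len(ref))
    prod_pct = np.where(prod_counts == 0, EPSILON, prod_counts / len(prod))

    psi = round(float(np.sum((prod_pct - ref_pct) * np.log(prod_pct / ref_pct))), 4)

    if psi < 0.1:
        level = "none"
    elif psi < 0.25:
        level = "moderate"
    else:
        level = "significant"

    return {"psi": psi, "drift_detected": psi >= 0.1, "drift_level": level}


print(detect_feature_drift([1, 1, 2, 2], [1, 2, 2, 2], num_bins=2))
```

The PSI value is sensitive to the binning choice. Under this equal-width assumption, Example 1 leaves the first bin empty in production and the last bin empty in reference, so the epsilon terms dominate and the result lands far above 0.1871. Treat the sketch as illustrating the mechanics (shared edges, epsilon substitution, per-bin sum, thresholding) rather than as reproducing that exact output.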