Feature Drift Detection using Population Stability Index

Medium
MLE Interview Prep

In production ML systems, detecting when the distribution of an input feature shifts between training and production (feature drift) is crucial for maintaining model performance. The Population Stability Index (PSI) is a widely used MLOps metric for quantifying such distribution shifts.
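For reference, the summation described in the example explanation corresponds to the standard PSI definition over B bins. Here q_i is the reference proportion and p_i is the production proportion in bin i:

```latex
\mathrm{PSI} = \sum_{i=1}^{B} \left(p_i - q_i\right) \ln\frac{p_i}{q_i}
```

Each term is non-negative, because p_i - q_i and ln(p_i / q_i) always share a sign. So PSI is zero only when every bin matches. Algebraically the sum equals KL(p || q) + KL(q || p), a symmetrised KL divergence.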

Write a function detect_feature_drift(reference_data, production_data, num_bins) that:

  1. Takes a reference distribution (e.g., training data feature values) and a production distribution (current incoming data)
  2. Computes the PSI to measure how much the production distribution has shifted from the reference
  3. Returns a dictionary with the PSI value and drift assessment
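Computing the PSI (step 2) requires binning both samples identically, so that bin i means the same value range for each. The statement does not say how bin edges are chosen. This sketch assumes one common choice: num_bins equal-width bins spanning the combined minimum and maximum of both samples. `bin_proportions` is a hypothetical helper name.

```python
import numpy as np


def bin_proportions(reference_data, production_data, num_bins):
    """Return (ref_pct, prod_pct) over shared equal-width bins.

    Assumption: edges span the combined min/max of both samples;
    the problem statement leaves the binning scheme unspecified.
    """
    ref = np.asarray(reference_data, dtype=float)
    prod = np.asarray(production_data, dtype=float)
    lo = min(ref.min(), prod.min())
    hi = max(ref.max(), prod.max())
    if hi == lo:
        hi = lo + 1.0  # all values identical: widen the range to avoid zero-width bins
    # num_bins + 1 edges give num_bins bins; np.histogram closes the last bin on the right
    edges = np.linspace(lo, hi, num_bins + 1)
    ref_counts, _ = np.histogram(ref, bins=edges)
    prod_counts, _ = np.histogram(prod, bins=edges)
    # Divide by sample size so each sample's proportions sum to 1
    return ref_counts / ref.size, prod_counts / prod.size
```

Another widely used scheme derives the edges from reference-data quantiles. The choice of scheme can change the resulting PSI noticeably.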

The function should return a dictionary containing:

  • psi: The calculated Population Stability Index (rounded to 4 decimal places)
  • drift_detected: Boolean indicating if drift is detected (PSI >= 0.1)
  • drift_level: One of 'none' (PSI < 0.1), 'moderate' (0.1 <= PSI < 0.25), or 'significant' (PSI >= 0.25)
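All three output fields can be derived from a single PSI value. The sketch below uses a hypothetical helper, `classify_psi`. It applies the thresholds to the rounded PSI so that the reported value and the label always agree. That is an assumption, since the statement does not say whether to compare the rounded or unrounded value.

```python
def classify_psi(psi: float) -> dict:
    """Map a PSI value to the psi / drift_detected / drift_level fields."""
    # Assumption: thresholds are applied to the value rounded to 4 decimals
    psi = round(float(psi), 4)
    if psi < 0.1:
        level = "none"
    elif psi < 0.25:
        level = "moderate"
    else:
        level = "significant"
    return {"psi": psi, "drift_detected": psi >= 0.1, "drift_level": level}
```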

If either input list is empty, return an empty dictionary.

Note: When computing bin proportions, replace any zero proportion with a small epsilon (0.0001). Otherwise the ratio prod_pct / ref_pct could divide by zero, or its logarithm could be taken of zero.
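A minimal sketch of the epsilon replacement, using NumPy. `smooth_zeros` and `EPSILON` are hypothetical names introduced here:

```python
import numpy as np

EPSILON = 0.0001  # value given in the problem statement


def smooth_zeros(proportions, eps=EPSILON):
    """Replace exact-zero bin proportions with eps so ln(p / q) stays finite."""
    pct = np.asarray(proportions, dtype=float)
    return np.where(pct == 0, eps, pct)
```

For example, `smooth_zeros([0.4, 0.0, 0.6])` yields `[0.4, 0.0001, 0.6]`. The statement asks only for replacement, so the sketch does not renormalise; the smoothed proportions may therefore sum to slightly more than 1.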

Examples

Example 1:
Input: reference_data = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], production_data = [3, 3, 4, 4, 5, 5, 6, 6, 7, 7], num_bins = 5
Output: {'psi': 0.1871, 'drift_detected': True, 'drift_level': 'moderate'}
Explanation: The reference data is concentrated in the range [1, 5], while the production data has shifted to [3, 7]. Using 5 bins, we compute the proportion of data in each bin for both distributions. The PSI formula sums (prod_pct - ref_pct) * ln(prod_pct / ref_pct) across all bins. The resulting PSI of 0.1871 falls in the moderate drift range (0.1 <= PSI < 0.25), indicating that the distribution has shifted enough to warrant monitoring.
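To make the summation concrete, the sketch below applies the same formula to made-up four-bin proportions. These are illustrative values, not the bins of Example 1. `psi_from_proportions` is a hypothetical helper that expects strictly positive inputs, as they would be after epsilon replacement.

```python
import math


def psi_from_proportions(ref_pct, prod_pct):
    """Sum (prod - ref) * ln(prod / ref) over bins; inputs must be > 0."""
    return sum((p - r) * math.log(p / r) for r, p in zip(ref_pct, prod_pct))


# Hypothetical 4-bin proportions, not taken from Example 1
ref_pct = [0.25, 0.25, 0.25, 0.25]
prod_pct = [0.10, 0.20, 0.30, 0.40]
print(round(psi_from_proportions(ref_pct, prod_pct), 4))  # 0.2282
```

The outer bins contribute most: about 0.137 from the first bin and 0.071 from the last. The total of 0.2282 would be classified as 'moderate'.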

Starter Code

import numpy as np

def detect_feature_drift(reference_data: list, production_data: list, num_bins: int = 10) -> dict:
    """
    Detect feature drift using Population Stability Index (PSI).
    
    Args:
        reference_data: List of feature values from reference distribution (e.g., training)
        production_data: List of feature values from production distribution
        num_bins: Number of bins for histogram comparison
    
    Returns:
        dict with 'psi', 'drift_detected', and 'drift_level'
    """
    pass
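Putting the pieces together, here is one possible completion of the starter code. It is a sketch, not a reference solution, and it rests on two assumptions: the equal-width, combined-range binning, and classifying on the rounded PSI.

Under these assumptions, Example 1 evaluates to about 6.6336 ('significant'), not 0.1871. The reason is that values 1 and 2 appear only in the reference, and 6 and 7 only in production. Each of those two bins holds 40% of one sample and none of the other. With epsilon = 0.0001, such a bin contributes 0.3999 * ln(4000) ≈ 3.32 on its own. The quoted 0.1871 therefore does not follow from this binning, and the scheme that produced it is not stated.

```python
import numpy as np

EPSILON = 0.0001  # replacement for empty-bin proportions, per the statement


def detect_feature_drift(reference_data: list, production_data: list, num_bins: int = 10) -> dict:
    """Detect feature drift using Population Stability Index (PSI).

    Sketch only: assumes num_bins equal-width bins over the combined
    min/max of both samples, since the statement leaves binning open.
    """
    if len(reference_data) == 0 or len(production_data) == 0:
        return {}

    ref = np.asarray(reference_data, dtype=float)
    prod = np.asarray(production_data, dtype=float)

    # Shared bin edges so both samples are compared bin-for-bin
    lo = min(ref.min(), prod.min())
    hi = max(ref.max(), prod.max())
    if hi == lo:
        hi = lo + 1.0  # all values identical: widen the range to avoid zero-width bins
    edges = np.linspace(lo, hi, num_bins + 1)

    ref_pct = np.histogram(ref, bins=edges)[0] / ref.size
    prod_pct = np.histogram(prod, bins=edges)[0] / prod.size

    # Replace empty-bin proportions with epsilon so the log ratio stays finite
    ref_pct = np.where(ref_pct == 0, EPSILON, ref_pct)
    prod_pct = np.where(prod_pct == 0, EPSILON, prod_pct)

    psi = float(np.sum((prod_pct - ref_pct) * np.log(prod_pct / ref_pct)))
    psi = round(psi, 4)

    # Assumption: thresholds are applied to the rounded PSI
    if psi < 0.1:
        level = "none"
    elif psi < 0.25:
        level = "moderate"
    else:
        level = "significant"
    return {"psi": psi, "drift_detected": psi >= 0.1, "drift_level": level}
```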
The AI Interview - Master AI/ML Interviews