Implement Request Batching for Inference

Medium
MLE Interview Prep

Implement a request batching function for ML model inference. In production ML systems, batching multiple inference requests together improves throughput by leveraging parallel processing capabilities of hardware accelerators.

Your function should group incoming requests into batches based on two constraints:

  1. Maximum batch size: A batch should not exceed this number of requests
  2. Maximum wait time: The elapsed time between the first request in the current batch and an incoming request (the difference of their timestamps) should not exceed this threshold

When accepting an incoming request would violate either constraint, the current batch is finalized and processed, and a new batch starts with that incoming request.

Function Inputs:

  • requests: A list of dictionaries, each containing:
    • 'id': Integer identifier for the request
    • 'timestamp': Float representing arrival time in seconds
    • 'features': List of float values representing input features
  • max_batch_size: Integer, maximum number of requests per batch
  • max_wait_time: Float, maximum seconds to wait before processing a batch

Function Output:

A list of tuples, where each tuple represents a processed batch containing:

  • List of request IDs in the batch
  • List of feature lists (batched features)
  • Processing timestamp (the timestamp of the last request in the batch, rounded to 4 decimals)

Note: Requests should be processed in order of their timestamps, regardless of their order in the input list.

Examples

Example 1:
Input: requests = [{'id': 1, 'timestamp': 0.0, 'features': [1.0, 2.0]}, {'id': 2, 'timestamp': 0.1, 'features': [3.0, 4.0]}, {'id': 3, 'timestamp': 0.2, 'features': [5.0, 6.0]}, {'id': 4, 'timestamp': 0.3, 'features': [7.0, 8.0]}], max_batch_size = 2, max_wait_time = 10.0
Output: [([1, 2], [[1.0, 2.0], [3.0, 4.0]], 0.1), ([3, 4], [[5.0, 6.0], [7.0, 8.0]], 0.3)]
Explanation: Requests arrive at times 0.0, 0.1, 0.2, and 0.3 seconds. With max_batch_size=2, the first two requests (ids 1 and 2) fill the first batch, which is processed at time 0.1. Similarly, requests 3 and 4 form the second batch, processed at time 0.3. The wait time constraint (10.0s) is never triggered since all requests arrive within 0.3 seconds.

Starter Code

def batch_requests(requests: list, max_batch_size: int, max_wait_time: float) -> list:
    """
    Group inference requests into batches based on size and time constraints.
    
    Args:
        requests: List of dicts with 'id', 'timestamp', 'features'
        max_batch_size: Maximum number of requests per batch
        max_wait_time: Maximum time to wait before processing a batch
    
    Returns:
        List of tuples: (request_ids, batched_features, process_time)
    """
    # Your code here
    pass
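
One possible reference solution is sketched below. It assumes the wait-time check compares each incoming request's timestamp against the first timestamp of the current batch, as described above; the `_finalize` helper is an illustrative name, not part of the required interface.

```python
def batch_requests(requests: list, max_batch_size: int, max_wait_time: float) -> list:
    """Group inference requests into batches by size and wait-time constraints."""

    def _finalize(batch: list) -> tuple:
        # Helper (illustrative name): turn a list of request dicts into the
        # (ids, features, process_time) tuple the problem asks for.
        ids = [r['id'] for r in batch]
        features = [r['features'] for r in batch]
        return (ids, features, round(batch[-1]['timestamp'], 4))

    # Process requests in timestamp order, regardless of input order.
    ordered = sorted(requests, key=lambda r: r['timestamp'])

    batches = []
    current = []
    for req in ordered:
        # Finalize the current batch if adding this request would exceed
        # the size limit or the wait-time window of the batch's first request.
        if current and (
            len(current) == max_batch_size
            or req['timestamp'] - current[0]['timestamp'] > max_wait_time
        ):
            batches.append(_finalize(current))
            current = []
        current.append(req)

    # Flush the final partial batch, if any.
    if current:
        batches.append(_finalize(current))
    return batches
```

On Example 1 this produces `[([1, 2], [[1.0, 2.0], [3.0, 4.0]], 0.1), ([3, 4], [[5.0, 6.0], [7.0, 8.0]], 0.3)]`: the size limit of 2 splits the four requests into two batches, and the 10.0s wait limit never triggers.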
The AI Interview - Master AI/ML Interviews