Implement a function that performs Batch Normalization on a 4D NumPy array representing a batch of feature maps in the BCHW format (batch, channels, height, width).
Training mode (training=True):
- Compute mean and variance from the current batch across batch and spatial dimensions for each channel
- Normalize using these batch statistics
- Update running statistics using momentum:
running_stat = momentum * running_stat + (1 - momentum) * batch_stat
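The training-mode steps above can be sketched as follows. This is a minimal illustration, not a reference solution: the helper name `batch_stats_and_update` is hypothetical, and it assumes the biased (population) variance, which is NumPy's default (`ddof=0`), since the problem statement does not say which estimator to use.

```python
import numpy as np

def batch_stats_and_update(X, running_mean, running_var, momentum=0.1):
    """Hypothetical helper: per-channel batch statistics plus running update."""
    # Reduce over batch (0) and spatial (2, 3) axes; keepdims gives shape (1, C, 1, 1).
    batch_mean = X.mean(axis=(0, 2, 3), keepdims=True)
    # Assumption: biased variance (ddof=0), NumPy's default.
    batch_var = X.var(axis=(0, 2, 3), keepdims=True)
    # Update rule exactly as stated in the problem:
    # running_stat = momentum * running_stat + (1 - momentum) * batch_stat
    new_running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    new_running_var = momentum * running_var + (1 - momentum) * batch_var
    return batch_mean, batch_var, new_running_mean, new_running_var
```

Note that under this rule `momentum` weights the old running value, so the default of 0.1 lets the running statistics follow the current batch closely.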
Inference mode (training=False):
- Use the provided running mean and variance for normalization (do not compute from batch)
After normalization, apply scale (gamma) and shift (beta) parameters. Use the provided epsilon value to ensure numerical stability.
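For inference mode, normalization followed by scale and shift might look like the sketch below. The function name `normalize_with_running_stats` is hypothetical; it assumes all statistics and parameters have shape (1, C, 1, 1) so that they broadcast against the (B, C, H, W) input.

```python
import numpy as np

def normalize_with_running_stats(X, gamma, beta, running_mean, running_var, epsilon=1e-5):
    """Hypothetical helper: inference-mode normalization with stored statistics."""
    # epsilon sits inside the square root, guarding against division by zero
    # when a channel's variance is (near) zero.
    X_hat = (X - running_mean) / np.sqrt(running_var + epsilon)
    # Per-channel scale and shift, broadcast over batch and spatial dimensions.
    return gamma * X_hat + beta
```

The running statistics are only read here, so in inference mode they can be returned unchanged.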
Return a tuple of (output, running_mean, running_var).
Examples
Example 1:
Input:
X = np.random.randn(2, 2, 2, 2); gamma = np.ones((1, 2, 1, 1)); beta = np.zeros((1, 2, 1, 1)); training = True
Output:
(normalized_output, running_mean, running_var)
Explanation: In training mode, compute mean and variance from the batch across dimensions (B, H, W) for each channel. Normalize using (X - mean) / sqrt(var + epsilon), then scale by gamma and shift by beta. Update running statistics using momentum for future inference. In inference mode, use the stored running statistics instead of computing from the batch.
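One way to sanity-check the training-mode behaviour in this example: with gamma set to ones and beta to zeros, each output channel should have mean close to 0 and variance close to 1. The sketch below uses a seeded generator instead of `np.random.randn` purely for reproducibility, and assumes the biased variance (`ddof=0`).

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen only so the check is repeatable
X = rng.standard_normal((2, 2, 2, 2))
gamma = np.ones((1, 2, 1, 1))
beta = np.zeros((1, 2, 1, 1))
epsilon = 1e-5

# Per-channel batch statistics over (B, H, W).
mean = X.mean(axis=(0, 2, 3), keepdims=True)
var = X.var(axis=(0, 2, 3), keepdims=True)  # assumption: ddof=0

out = gamma * (X - mean) / np.sqrt(var + epsilon) + beta

# Each channel of the output should be approximately standardized.
channel_means = out.mean(axis=(0, 2, 3))
channel_vars = out.var(axis=(0, 2, 3))
```

The output variance comes out slightly below 1 because of the epsilon term in the denominator.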
Starter Code
import numpy as np

def batch_normalization(
    X: np.ndarray,
    gamma: np.ndarray,
    beta: np.ndarray,
    running_mean: np.ndarray = None,
    running_var: np.ndarray = None,
    momentum: float = 0.1,
    epsilon: float = 1e-5,
    training: bool = True
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Perform Batch Normalization on BCHW input.

    Args:
        X: Input array of shape (B, C, H, W)
        gamma: Scale parameter of shape (1, C, 1, 1)
        beta: Shift parameter of shape (1, C, 1, 1)
        running_mean: Running mean for inference, shape (1, C, 1, 1)
        running_var: Running variance for inference, shape (1, C, 1, 1)
        momentum: Momentum for updating running statistics
        epsilon: Small constant for numerical stability
        training: If True, use batch statistics; if False, use running statistics

    Returns:
        Tuple of (normalized_output, updated_running_mean, updated_running_var)
    """
    # Your code here
    pass