Gradient Clipping by Global Norm

Medium
MLE Interview Prep

Implement gradient clipping by global norm. Given a list of gradient arrays (one per parameter) and a maximum norm threshold, compute the global L2 norm across all gradients, i.e. the square root of the sum of the squares of every element in every array. If this global norm exceeds the threshold, scale all gradients by the same factor so that the global norm equals the threshold; otherwise leave them unchanged. Return the clipped gradients in the original structure.
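The rule described above can be written compactly as follows. The symbols are notation introduced here for illustration, not part of the original problem: $g_{i,j}$ is element $j$ of gradient array $i$, $c$ stands for `max_norm`, and $\hat{g}_i$ is the clipped array.

```latex
\|g\|_{\text{global}} = \sqrt{\sum_{i} \sum_{j} g_{i,j}^{2}},
\qquad
\hat{g}_i =
\begin{cases}
g_i \cdot \dfrac{c}{\|g\|_{\text{global}}} & \text{if } \|g\|_{\text{global}} > c, \\[6pt]
g_i & \text{otherwise.}
\end{cases}
```

Because every array is multiplied by the same factor, the direction of the overall gradient is preserved and only its length changes.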

Examples

Example 1:
Input: gradients=[[3.0, 4.0], [0.0, 0.0]], max_norm=1.0
Output: [[0.6, 0.8], [0.0, 0.0]]
Explanation: The global norm is $\sqrt{3^2 + 4^2 + 0^2 + 0^2} = \sqrt{25} = 5.0$. Since $5.0 > 1.0$, we need to clip. The scaling factor is $\frac{1.0}{5.0} = 0.2$. Each gradient is multiplied by 0.2: $[3.0 \times 0.2, 4.0 \times 0.2] = [0.6, 0.8]$ and $[0.0 \times 0.2, 0.0 \times 0.2] = [0.0, 0.0]$.

Starter Code

def clip_gradients_by_global_norm(gradients: list[list[float]], max_norm: float) -> list[list[float]]:
	"""
	Clip gradients by global norm.
	
	Args:
		gradients: List of gradient arrays
		max_norm: Maximum allowed global norm
	
	Returns:
		List of clipped gradient arrays
	"""
	# Your code here
	pass
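One possible way to fill in the starter code is sketched below. It is a minimal pure-Python version that assumes `max_norm` is non-negative and that the gradients are plain nested lists, as in the signature above. The explicit zero-norm guard is an addition made here to avoid division by zero, not something the problem statement requires.

```python
import math


def clip_gradients_by_global_norm(
    gradients: list[list[float]], max_norm: float
) -> list[list[float]]:
    """Sketch: clip a list of gradient arrays by their global L2 norm."""
    # Global L2 norm: square root of the sum of squares over every element.
    global_norm = math.sqrt(sum(g * g for grad in gradients for g in grad))

    # Nothing to clip if the norm is within the threshold; the zero check
    # also avoids dividing by zero below (an assumption of this sketch).
    if global_norm <= max_norm or global_norm == 0.0:
        return [list(grad) for grad in gradients]

    # A single shared factor keeps the relative sizes of all gradients intact.
    scale = max_norm / global_norm
    return [[g * scale for g in grad] for grad in gradients]


# Example 1 from above: the result is approximately [[0.6, 0.8], [0.0, 0.0]],
# up to floating-point rounding.
print(clip_gradients_by_global_norm([[3.0, 4.0], [0.0, 0.0]], 1.0))
```

The function returns new lists rather than modifying its input, which keeps the original gradients available to the caller. In a real training loop the same idea would usually be applied to framework tensors rather than Python lists.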