Implement gradient clipping by global norm. You are given a list of gradient arrays (one per parameter) and a maximum norm threshold, max_norm. Compute the global L2 norm across all gradients: the square root of the sum of the squares of every element in every array. If this global norm exceeds max_norm, scale all gradients by the same factor so that the global norm equals max_norm; otherwise, return them unchanged. Return the clipped gradients with the original structure (same number of arrays, same lengths).
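Written as formulas, the rule above is the following. Here $g_{i,j}$ (element $j$ of gradient array $i$) and $c$ (standing for max_norm) are notation introduced only for this summary:

$$\lVert g \rVert_{\text{global}} = \sqrt{\sum_{i} \sum_{j} g_{i,j}^{2}}, \qquad g_{i,j}' = \begin{cases} g_{i,j} \cdot \dfrac{c}{\lVert g \rVert_{\text{global}}} & \text{if } \lVert g \rVert_{\text{global}} > c \\ g_{i,j} & \text{otherwise} \end{cases}$$

Because every element is multiplied by the same positive factor, the clipped norm is $\lVert g' \rVert_{\text{global}} = \frac{c}{\lVert g \rVert_{\text{global}}} \cdot \lVert g \rVert_{\text{global}} = c$. The relative direction of the combined gradient is preserved.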
Examples
Example 1:
Input:
gradients=[[3.0, 4.0], [0.0, 0.0]], max_norm=1.0
Output:
[[0.6, 0.8], [0.0, 0.0]]
Explanation: The global norm is $\sqrt{3^2 + 4^2 + 0^2 + 0^2} = \sqrt{25} = 5.0$. Since $5.0 > 1.0$, we need to clip. The scaling factor is $\frac{1.0}{5.0} = 0.2$. Each gradient is multiplied by 0.2: $[3.0 \times 0.2, 4.0 \times 0.2] = [0.6, 0.8]$ and $[0.0 \times 0.2, 0.0 \times 0.2] = [0.0, 0.0]$.
Starter Code
def clip_gradients_by_global_norm(gradients: list[list[float]], max_norm: float) -> list[list[float]]:
    """
    Clip gradients by global norm.
    Args:
        gradients: List of gradient arrays
        max_norm: Maximum allowed global norm
    Returns:
        List of clipped gradient arrays
    """
    # Your code here
    pass
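One possible way to fill in the starter code is sketched below. It is a minimal pure-Python sketch that assumes each gradient array is a plain list of floats and that max_norm is non-negative. Raising ValueError for a negative threshold is an added assumption and is not part of the problem statement.

```python
import math


def clip_gradients_by_global_norm(gradients: list[list[float]], max_norm: float) -> list[list[float]]:
    """
    Clip gradients by global norm.
    Args:
        gradients: List of gradient arrays
        max_norm: Maximum allowed global norm
    Returns:
        List of clipped gradient arrays
    """
    # Assumption (not in the problem statement): a negative threshold is invalid.
    if max_norm < 0:
        raise ValueError("max_norm must be non-negative")

    # Global L2 norm: treat all arrays as one concatenated vector and take
    # the square root of the sum of squares of every element.
    global_norm = math.sqrt(sum(x * x for grad in gradients for x in grad))

    # Within the threshold: return the values unchanged, copied into new
    # lists so the caller's input is never aliased or mutated.
    if global_norm <= max_norm:
        return [list(grad) for grad in gradients]

    # Over the threshold (so global_norm > 0): scale every element by
    # max_norm / global_norm. Multiplying before dividing rounds once per
    # element, so 3.0 * 1.0 / 5.0 gives 0.6 rather than the
    # 0.6000000000000001 produced by 3.0 * 0.2.
    return [[x * max_norm / global_norm for x in grad] for grad in gradients]
```

On Example 1 this returns [[0.6, 0.8], [0.0, 0.0]]. Returning new lists even when no clipping happens is a design choice. It keeps the function free of side effects, at the cost of one copy of the gradients.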