Implement the multi-head attention mechanism, a critical component of transformer models. You need to implement three functions:
- compute_qkv(X, W_q, W_k, W_v): Compute the Query, Key, and Value matrices by multiplying the input X with the corresponding weight matrices. Returns a tuple (Q, K, V) where each matrix has the same shape as X.
- self_attention(Q, K, V): Compute scaled dot-product attention for a single head. Returns the attention output with the same shape as V.
- multi_head_attention(Q, K, V, n_heads): Split Q, K, V into n_heads heads along the feature dimension, compute self-attention for each head independently, and concatenate the results. Returns output with the same shape as Q.
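To make "split along the feature dimension" concrete, here is a small illustration (the variable names `head0`, `head1`, and `d_head` are ours, not part of the required API): each head receives a contiguous slice of the columns.

```python
import numpy as np

Q = np.arange(8.0).reshape(2, 4)  # (seq_len=2, d_model=4)
n_heads = 2
d_head = Q.shape[1] // n_heads    # 2 features per head

# Each head sees a contiguous slice of the feature dimension
head0 = Q[:, :d_head]             # columns 0..1
head1 = Q[:, d_head:]             # columns 2..3
```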
Inputs:
- X: Input matrix of shape (seq_len, d_model) representing a sequence of token embeddings
- W_q, W_k, W_v: Weight matrices of shape (d_model, d_model)
- n_heads: Number of attention heads (must evenly divide d_model)
Workflow: First call compute_qkv to get Q, K, V matrices, then pass them to multi_head_attention.
Important: Use numerically stable softmax (subtract max before exponentiating) to prevent overflow.
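The stability trick works because softmax is shift-invariant: subtracting the row-wise maximum leaves the result unchanged but keeps every exponent at or below zero, so `np.exp` cannot overflow. A minimal sketch (the helper name `stable_softmax` is ours):

```python
import numpy as np

def stable_softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max before exponentiating; softmax(x) == softmax(x - c)
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)
```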
Examples
X = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) # shape (2, 4): 2 tokens, 4 features
W_q = W_k = W_v = np.eye(4) # Identity weights for simplicity
Q, K, V = compute_qkv(X, W_q, W_k, W_v) # Step 1: Get Q, K, V
result = multi_head_attention(Q, K, V, n_heads=2)  # Step 2: Apply multi-head attention
# result.shape == (2, 4): two sequence positions, 4 features (2 heads × 2 features per head)
Starter Code
import numpy as np
from typing import Tuple
def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Compute Query, Key, and Value matrices.
Args:
X: Input matrix of shape (seq_len, d_model)
W_q, W_k, W_v: Weight matrices of shape (d_model, d_model)
Returns:
Q, K, V matrices each of shape (seq_len, d_model)
"""
# Your code here
pass
def self_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
"""
Compute scaled dot-product self-attention.
Args:
Q: Query matrix of shape (seq_len, d_k)
K: Key matrix of shape (seq_len, d_k)
V: Value matrix of shape (seq_len, d_k)
Returns:
Attention output of shape (seq_len, d_k)
"""
# Your code here
pass
def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, n_heads: int) -> np.ndarray:
"""
Compute multi-head attention.
Args:
Q, K, V: Matrices of shape (seq_len, d_model)
n_heads: Number of attention heads
Returns:
Attention output of shape (seq_len, d_model)
"""
# Your code here
pass
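For reference, one possible solution sketch (ours, not an official solution; the per-head column slicing shown here is one of several valid layouts):

```python
import numpy as np
from typing import Tuple

def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray,
                W_v: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Plain matrix products; with square weights, shapes match X
    return X @ W_q, X @ W_k, X @ W_v

def self_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)               # (seq_len, seq_len)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V                            # (seq_len, d_k)

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                         n_heads: int) -> np.ndarray:
    d_model = Q.shape[-1]
    assert d_model % n_heads == 0, "n_heads must evenly divide d_model"
    d_head = d_model // n_heads
    heads = []
    for h in range(n_heads):
        s = slice(h * d_head, (h + 1) * d_head)   # this head's feature slice
        heads.append(self_attention(Q[:, s], K[:, s], V[:, s]))
    return np.concatenate(heads, axis=-1)         # (seq_len, d_model)
```

With identity weights, compute_qkv simply returns copies of X, so the worked example above reduces to attention over the raw embeddings.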