Implement Multi-Head Attention

Hard
MLE Interview Prep

Implement the multi-head attention mechanism, a critical component of transformer models. You need to implement three functions:

  1. compute_qkv(X, W_q, W_k, W_v): Compute Query, Key, and Value matrices by multiplying input X with weight matrices. Returns a tuple (Q, K, V) where each has the same shape as X.

  2. self_attention(Q, K, V): Compute scaled dot-product attention for a single head. Returns the attention output with the same shape as V.

  3. multi_head_attention(Q, K, V, n_heads): Split Q, K, V into multiple heads along the feature dimension, compute self-attention for each head independently, and concatenate results. Returns output with the same shape as Q.

Inputs:

  • X: Input matrix of shape (seq_len, d_model) representing a sequence of token embeddings
  • W_q, W_k, W_v: Weight matrices of shape (d_model, d_model)
  • n_heads: Number of attention heads (must evenly divide d_model)

Workflow: First call compute_qkv to get Q, K, V matrices, then pass them to multi_head_attention.

Important: Use numerically stable softmax (subtract max before exponentiating) to prevent overflow.
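The stability trick above can be sketched as follows (the standalone helper name `softmax` is my own; the shift by the row max cancels in the ratio, so the result is unchanged while `np.exp` never overflows):

```python
import numpy as np

def softmax(scores: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max before exponentiating; exp(x - max) <= 1,
    # so no overflow even for very large scores.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)
```

Without the subtraction, `np.exp(1000.0)` already overflows to inf and the division produces NaNs.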

Examples

Example 1:
Input:

  X = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])   # shape (2, 4): 2 tokens, 4 features
  W_q = W_k = W_v = np.eye(4)                  # Identity weights for simplicity
  Q, K, V = compute_qkv(X, W_q, W_k, W_v)      # Step 1: Get Q, K, V
  result = multi_head_attention(Q, K, V, n_heads=2)  # Step 2: Apply multi-head attention
Output: result.shape = (2, 4) — Two sequence positions, 4 features (2 heads × 2 features per head)
Explanation: First, compute_qkv projects the input X into Q, K, V matrices. Then multi_head_attention splits each into 2 heads (each with 2 features), computes self-attention on each head independently, and concatenates the results back to the original dimension.
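The per-head split described above can be done with a single reshape rather than manual slicing; a small sketch for d_model = 4 and n_heads = 2 (the variable names are illustrative):

```python
import numpy as np

Q = np.arange(8.0).reshape(2, 4)  # (seq_len=2, d_model=4)
n_heads = 2
d_k = Q.shape[1] // n_heads       # features per head = 2

# (seq_len, d_model) -> (seq_len, n_heads, d_k) -> (n_heads, seq_len, d_k)
heads = Q.reshape(2, n_heads, d_k).transpose(1, 0, 2)
# heads[0] is columns 0:2 of Q, heads[1] is columns 2:4 --
# each head sees a contiguous slice of the feature dimension.
```

The inverse trip (`transpose(1, 0, 2).reshape(seq_len, d_model)`) concatenates the heads back to the original width.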

Starter Code

import numpy as np
from typing import Tuple

def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute Query, Key, and Value matrices.
    
    Args:
        X: Input matrix of shape (seq_len, d_model)
        W_q, W_k, W_v: Weight matrices of shape (d_model, d_model)
    
    Returns:
        Q, K, V matrices each of shape (seq_len, d_model)
    """
    # Your code here
    pass

def self_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """
    Compute scaled dot-product self-attention.
    
    Args:
        Q: Query matrix of shape (seq_len, d_k)
        K: Key matrix of shape (seq_len, d_k)
        V: Value matrix of shape (seq_len, d_k)
    
    Returns:
        Attention output of shape (seq_len, d_k)
    """
    # Your code here
    pass

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, n_heads: int) -> np.ndarray:
    """
    Compute multi-head attention.
    
    Args:
        Q, K, V: Matrices of shape (seq_len, d_model)
        n_heads: Number of attention heads
    
    Returns:
        Attention output of shape (seq_len, d_model)
    """
    # Your code here
    pass
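One possible reference solution to check your work against (a sketch, not the only valid approach; the loop-and-slice style below trades speed for readability):

```python
import numpy as np
from typing import Tuple

def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray,
                W_v: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Linear projections; each output keeps X's (seq_len, d_model) shape.
    return X @ W_q, X @ W_k, X @ W_v

def self_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)              # (seq_len, seq_len)
    # Numerically stable softmax over each row of scores.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V                           # (seq_len, d_k)

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                         n_heads: int) -> np.ndarray:
    seq_len, d_model = Q.shape
    d_k = d_model // n_heads                     # n_heads must divide d_model
    outputs = []
    for h in range(n_heads):
        s = slice(h * d_k, (h + 1) * d_k)        # this head's feature columns
        outputs.append(self_attention(Q[:, s], K[:, s], V[:, s]))
    # Concatenate head outputs back to (seq_len, d_model).
    return np.concatenate(outputs, axis=-1)
```

Running Example 1 through this sketch returns a (2, 4) array, matching the expected output shape.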
The AI Interview - Master AI/ML Interviews