Multi-Head Attention
Why One Head Isn’t Enough
In the last module, we saw how self-attention lets tokens look at each other. But a single attention head computes one set of attention weights — one “perspective” on the relationships.
Language, however, is multi-faceted. Consider:
“The tired cat sat on the warm mat”
A reader simultaneously tracks:
- Syntax: “cat” is the subject of “sat” (grammar)
- Semantics: “tired” modifies “cat” (meaning)
- Proximity: “on” connects “sat” to “mat” (structure)
- Co-reference: if a later sentence said “it purred,” “it” would refer back to “cat”
One attention head can only capture one of these patterns at a time. Multi-head attention runs multiple attention heads in parallel, each free to learn a different relationship pattern.
The Analogy
Think of it like a team of analysts reading the same document:
- Analyst 1 (grammarian): focuses on subject-verb relationships
- Analyst 2 (semanticist): focuses on meaning connections
- Analyst 3 (editor): focuses on nearby words
- Analyst 4 (detective): focuses on long-range references
Each analyst writes a summary. Then a manager combines all summaries into a final report.
The Formula

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

where each head is a standard attention operation:

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
The Dimension Split
Here’s the clever part: we don’t create independent attention mechanisms. Instead, we split the model dimension across heads:
| Config | d_model | Heads (h) | Per-head d_k | Total attn params |
|---|---|---|---|---|
| Original Transformer | 512 | 8 | 64 | 4 × 512² ≈ 1.0M |
| GPT-2 Small | 768 | 12 | 64 | 4 × 768² ≈ 2.4M |
| GPT-3 | 12288 | 96 | 128 | 4 × 12288² ≈ 604M |
The key insight: more heads doesn’t mean more parameters — it means the same capacity is divided into more specialized perspectives.
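The split-and-recombine step can be sketched in a few lines of NumPy (an illustration of the reshaping only, not a full attention computation; the shapes are the Original Transformer's):

```python
import numpy as np

d_model, n_heads = 512, 8
d_k = d_model // n_heads  # 64: each head sees a 64-dim slice

seq_len = 10
x = np.random.randn(seq_len, d_model)  # e.g. one projected Q (or K, or V)

# Split: (seq_len, d_model) -> (n_heads, seq_len, d_k)
heads = x.reshape(seq_len, n_heads, d_k).transpose(1, 0, 2)
print(heads.shape)  # (8, 10, 64)

# Combine: (n_heads, seq_len, d_k) -> (seq_len, d_model), exact inverse
combined = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
print(np.allclose(combined, x))  # True
```

No parameters are involved in the split itself: it is pure reshaping, which is why the head count is free to vary without changing capacity.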
Interactive: See Multiple Heads in Action
Explore how different heads learn different patterns. Click on individual heads to see their attention matrices, or view all heads together:
Multi-Head Attention Visualizer
- Positional: Track word order and proximity
- Syntactic: Connect subjects and verbs
- Semantic: Link related concepts
What to look for:
- H1 (Positional): Attends to nearby tokens — captures local structure
- H2 (Self): Strong self-attention — each token focuses on its own features
- H3 (Uniform): Broad attention — gathers context from everywhere
- H4 (Backward): Attends to earlier positions — captures left context
In real trained models, heads develop even more specialized patterns!
Parameter Count
Let’s count exactly how many parameters multi-head attention uses:
| Weight Matrix | Shape | Parameters |
|---|---|---|
| W_Q | (d_model, d_model) | d_model² |
| W_K | (d_model, d_model) | d_model² |
| W_V | (d_model, d_model) | d_model² |
| W_O | (d_model, d_model) | d_model² |
| Total | | 4 × d_model² |

For d_model = 512: 4 × 512² = 1,048,576 ≈ 1M parameters.
Note: The number of heads doesn’t change the total parameters — it only changes how they’re organized internally.
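A quick sanity check of the count (bias terms omitted, as in the table):

```python
def mha_params(d_model: int) -> int:
    """Four (d_model, d_model) projections: W_Q, W_K, W_V, W_O.
    Note the head count never appears -- it only reorganizes these weights."""
    return 4 * d_model * d_model

for name, d in [("Original Transformer", 512),
                ("GPT-2 Small", 768),
                ("GPT-3", 12288)]:
    print(f"{name}: {mha_params(d):,}")
# Original Transformer: 1,048,576
# GPT-2 Small: 2,359,296
# GPT-3: 603,979,776
```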
What Do Trained Heads Actually Learn?
Research on real transformer models reveals remarkable specialization:
| Head Pattern | What It Learns | Example |
|---|---|---|
| Positional | Attend to adjacent/fixed-offset positions | Token at pos i always attends to pos i-1 |
| Syntactic | Subject-verb agreement | “The dogs … are barking” |
| Semantic | Word meaning relationships | “bank” attends to “river” or “money” |
| Delimiter | Attend to punctuation/special tokens | Focus on periods, commas, [SEP] |
| Rare token | Attend to infrequent words | Proper nouns get extra attention |
Some heads can even be pruned (removed) without hurting performance — they learn redundant or less useful patterns.
Putting it all together in NumPy:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class MultiHeadAttention:
    def __init__(self, d_model, n_heads):
        assert d_model % n_heads == 0, "d_model must divide evenly across heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # One big projection per Q, K, V (split into heads after)
        self.W_Q = np.random.randn(d_model, d_model) * 0.01
        self.W_K = np.random.randn(d_model, d_model) * 0.01
        self.W_V = np.random.randn(d_model, d_model) * 0.01
        self.W_O = np.random.randn(d_model, d_model) * 0.01

    def split_heads(self, x):
        """(seq_len, d_model) -> (n_heads, seq_len, d_k)"""
        seq_len = x.shape[0]
        x = x.reshape(seq_len, self.n_heads, self.d_k)
        return x.transpose(1, 0, 2)

    def combine_heads(self, x):
        """(n_heads, seq_len, d_k) -> (seq_len, d_model)"""
        x = x.transpose(1, 0, 2)  # (seq_len, n_heads, d_k)
        return x.reshape(x.shape[0], -1)

    def forward(self, X, mask=None):
        # 1. Project to Q, K, V (full d_model)
        Q = X @ self.W_Q
        K = X @ self.W_K
        V = X @ self.W_V
        # 2. Split into heads
        Q = self.split_heads(Q)  # (n_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        # 3. Scaled dot-product attention PER HEAD
        scores = np.einsum('hid,hjd->hij', Q, K) / np.sqrt(self.d_k)
        if mask is not None:
            scores = scores + mask
        weights = softmax(scores, axis=-1)
        # 4. Weighted values per head
        attn_output = np.einsum('hij,hjd->hid', weights, V)
        # 5. Combine heads + output projection
        combined = self.combine_heads(attn_output)
        return combined @ self.W_O
```

Multi-Query Attention (MQA)
All heads share the same K and V projections. Only Q differs per head:
- Standard: h × (Q, K, V) = 3h projections
- MQA: h × Q + 1 × K + 1 × V = h + 2 projections
Benefit: Much faster inference (less memory for KV cache). Used in PaLM, Falcon.
Grouped Query Attention (GQA)
A compromise: heads are grouped, and each group shares K, V:
- 8 heads with 2 groups: heads 1-4 share K₁,V₁; heads 5-8 share K₂,V₂
- Standard MHA quality, but much faster inference
Used in: LLaMA 2, Mistral, Gemma
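One common way to implement the sharing is to store only the grouped K/V heads and repeat each one across its group of query heads at attention time. A minimal NumPy sketch (the shapes and group counts here are illustrative, not any particular model's configuration):

```python
import numpy as np

n_heads, n_kv_heads = 8, 2          # 8 query heads, 2 KV groups (GQA)
seq_len, d_k = 10, 64
group_size = n_heads // n_kv_heads  # 4 query heads share each KV head

Q = np.random.randn(n_heads, seq_len, d_k)
K = np.random.randn(n_kv_heads, seq_len, d_k)  # only 2 K heads are stored/cached
V = np.random.randn(n_kv_heads, seq_len, d_k)

# Expand each stored KV head to cover its group of query heads
K_full = np.repeat(K, group_size, axis=0)  # (8, seq_len, d_k)
V_full = np.repeat(V, group_size, axis=0)

# From here, attention proceeds exactly as in standard MHA
scores = np.einsum('hid,hjd->hij', Q, K_full) / np.sqrt(d_k)
```

Setting `n_kv_heads = 1` recovers MQA, and `n_kv_heads = n_heads` recovers standard MHA, which is why GQA is described as a compromise between the two.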
Why This Matters
During generation, the model caches K and V for every previous token. With a GPT-3-scale configuration (96 layers, 96 heads, d_k = 128):
| Variant | KV Cache (4K context) |
|---|---|
| Standard MHA | ~24 GB |
| GQA (8 groups) | ~2 GB |
| MQA (1 group) | ~0.25 GB |
This is why efficient attention is critical for deploying large models.
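A back-of-the-envelope cache-size calculator (assuming fp16, i.e. 2 bytes per value, and a batch size of 1; the table's figures may use slightly different bookkeeping, but the ratios between variants are the point):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, bytes_per_value=2):
    """Both K and V are cached, per layer, per KV head, per position."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * bytes_per_value

# GPT-3-scale config: 96 layers, d_k = 128, 4K context
for name, kv_heads in [("MHA", 96), ("GQA-8", 8), ("MQA", 1)]:
    gb = kv_cache_bytes(96, kv_heads, 128, 4096) / 1e9
    print(f"{name}: {gb:.2f} GB")
```

At fp16 this gives roughly 19.3 GB for full MHA, 1.6 GB for 8 KV groups, and 0.2 GB for MQA: the same order of magnitude as the table, with the cache shrinking in direct proportion to the number of KV heads.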
Key Equations

Multi-Head Attention

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

Per-Head Attention

$$\text{head}_i = \text{softmax}\!\left(\frac{Q W_i^Q \,(K W_i^K)^\top}{\sqrt{d_k}}\right) V W_i^V$$

Per-Head Dimensions

$$d_k = d_v = \frac{d_{\text{model}}}{h}$$
Summary
| Concept | Single Head | Multi-Head |
|---|---|---|
| Perspectives | 1 | h (typically 8-96) |
| Per-head dim | d_model | d_model / h |
| Total params | 4 × d_model² | 4 × d_model² (same!) |
| What it captures | One relationship type | Many relationship types |
Next: Positional Encoding
Self-attention treats all positions equally — it’s permutation-invariant. We need to add position information. That’s Positional Encoding →