Multi-Head Attention

Why One Head Isn’t Enough

In the last module, we saw how self-attention lets tokens look at each other. But a single attention head computes one set of attention weights — one “perspective” on the relationships.

Language, however, is multi-faceted. Consider:

“The tired cat sat on the warm mat”

A reader simultaneously tracks:

  • Syntax: “cat” is the subject of “sat” (grammar)
  • Semantics: “tired” modifies “cat” (meaning)
  • Proximity: “on” connects “sat” to “mat” (structure)
  • Co-reference: in a longer passage, “cat” and a later “it” would refer to the same thing

One attention head can only capture one of these patterns at a time. Multi-head attention runs multiple attention heads in parallel, each free to learn a different relationship pattern.

The Analogy

Think of it like a team of analysts reading the same document:

  • Analyst 1 (grammarian): focuses on subject-verb relationships
  • Analyst 2 (semanticist): focuses on meaning connections
  • Analyst 3 (editor): focuses on nearby words
  • Analyst 4 (detective): focuses on long-range references

Each analyst writes a summary. Then a manager combines all summaries into a final report.


The Formula

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

where each head is a standard attention operation:

head_i = Attention(X W_i^Q, X W_i^K, X W_i^V)

The Dimension Split

Here’s the clever part: we don’t create h independent attention mechanisms. Instead, we split the model dimension across heads:

d_k = d_v = d_model / h

| Config | d_model | Heads (h) | Per-head d_k | Total params |
|---|---|---|---|---|
| Original Transformer | 512 | 8 | 64 | 4 × 512² |
| GPT-2 Small | 768 | 12 | 64 | 4 × 768² |
| GPT-3 | 12288 | 96 | 128 | 4 × 12288² |

The key insight: adding heads doesn’t add parameters; it divides the same capacity into more specialized perspectives.
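The split is easy to sanity-check in a few lines of Python, using the configurations from the table above:

```python
# Per-head dimension d_k = d_model / h for each config in the table above.
configs = {
    "Original Transformer": (512, 8),
    "GPT-2 Small": (768, 12),
    "GPT-3": (12288, 96),
}
for name, (d_model, h) in configs.items():
    assert d_model % h == 0, "d_model must divide evenly across heads"
    print(f"{name}: d_k = {d_model // h}")
# → d_k = 64, 64, and 128 respectively
```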

✍️ Fill in the Blanks
With d_model = 512 and 8 attention heads, each head has d_k = ___ dimensions.

Interactive: See Multiple Heads in Action

Explore how different heads learn different patterns. Click on individual heads to see their attention matrices, or view all heads together:

[Interactive visualizer: four attention heads (H1–H4), viewable individually or all together, with outputs concatenated and linearly projected back to the model dimension. Each head can learn a different pattern: positional (word order and proximity), syntactic (subjects and verbs), or semantic (related concepts).]

What to look for:

  • H1 (Positional): Attends to nearby tokens — captures local structure
  • H2 (Self): Strong self-attention — each token focuses on its own features
  • H3 (Uniform): Broad attention — gathers context from everywhere
  • H4 (Backward): Attends to earlier positions — captures left context

In real trained models, heads develop even more specialized patterns!


Parameter Count

Let’s count exactly how many parameters multi-head attention uses:

| Weight Matrix | Shape | Parameters |
|---|---|---|
| W^Q | (d_model, d_model) | d_model² |
| W^K | (d_model, d_model) | d_model² |
| W^V | (d_model, d_model) | d_model² |
| W^O | (d_model, d_model) | d_model² |
| Total | | 4 × d_model² |

For d_model = 512: 4 × 512² = 4 × 262,144 = 1,048,576 parameters ≈ 1M.

Note: The number of heads h doesn’t change the total parameters — it only changes how they’re organized internally.
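A quick check of this independence, counting only the four projection matrices:

```python
def mha_params(d_model: int, n_heads: int) -> int:
    """Parameters in W_Q, W_K, W_V, W_O; n_heads never enters the count."""
    assert d_model % n_heads == 0
    return 4 * d_model * d_model

# Same d_model, different head counts: identical parameter totals.
assert mha_params(512, 8) == mha_params(512, 16) == 1_048_576
```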

🤔 Quick Check
If we increase the number of heads from 8 to 16 (keeping d_model = 512), what happens to the total parameter count?

What Do Trained Heads Actually Learn?

Research on real transformer models reveals remarkable specialization:

| Head Pattern | What It Learns | Example |
|---|---|---|
| Positional | Attend to adjacent/fixed-offset positions | Token at pos i always attends to pos i-1 |
| Syntactic | Subject-verb agreement | “The dogs … are barking” |
| Semantic | Word meaning relationships | “bank” attends to “river” or “money” |
| Delimiter | Attend to punctuation/special tokens | Focus on periods, commas, [SEP] |
| Rare token | Attend to infrequent words | Proper nouns get extra attention |

Some heads can even be pruned (removed) without hurting performance — they learn redundant or less useful patterns.

Implementation

Here is the full mechanism in NumPy (with the imports and the softmax helper the code relies on defined up front):

import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class MultiHeadAttention:
    def __init__(self, d_model, n_heads):
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # One big projection per Q, K, V (split into heads after)
        self.W_Q = np.random.randn(d_model, d_model) * 0.01
        self.W_K = np.random.randn(d_model, d_model) * 0.01
        self.W_V = np.random.randn(d_model, d_model) * 0.01
        self.W_O = np.random.randn(d_model, d_model) * 0.01
    
    def split_heads(self, x):
        """(seq_len, d_model) → (n_heads, seq_len, d_k)"""
        seq_len = x.shape[0]
        x = x.reshape(seq_len, self.n_heads, self.d_k)
        return x.transpose(1, 0, 2)
    
    def combine_heads(self, x):
        """(n_heads, seq_len, d_k) → (seq_len, d_model)"""
        x = x.transpose(1, 0, 2)  # (seq_len, n_heads, d_k)
        return x.reshape(x.shape[0], -1)
    
    def forward(self, X, mask=None):
        # 1. Project to Q, K, V (full d_model)
        Q = X @ self.W_Q
        K = X @ self.W_K
        V = X @ self.W_V
        
        # 2. Split into heads
        Q = self.split_heads(Q)  # (n_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # 3. Scaled dot-product attention PER HEAD
        scores = np.einsum('hid,hjd->hij', Q, K) / np.sqrt(self.d_k)
        if mask is not None:
            scores = scores + mask
        weights = softmax(scores, axis=-1)
        
        # 4. Weighted values per head
        attn_output = np.einsum('hij,hjd->hid', weights, V)
        
        # 5. Combine heads + output projection
        combined = self.combine_heads(attn_output)
        return combined @ self.W_O
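One non-obvious point in the code above is step 2: splitting one big (d_model, d_model) projection after the fact is mathematically identical to giving each head its own (d_model, d_k) matrix W_i^Q, namely the i-th column block of W_Q. A standalone check of that equivalence (small illustrative dimensions):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, h = 5, 16, 4
d_k = d_model // h

X = rng.normal(size=(seq_len, d_model))
W_Q = rng.normal(size=(d_model, d_model))

# One big projection, split into heads afterwards (as in the class above)
Q_split = (X @ W_Q).reshape(seq_len, h, d_k).transpose(1, 0, 2)

# h separate projections, where W_i^Q is the i-th column block of W_Q
for i in range(h):
    W_i = W_Q[:, i * d_k:(i + 1) * d_k]   # (d_model, d_k)
    assert np.allclose(Q_split[i], X @ W_i)
```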

Multi-Query Attention (MQA)

All heads share the same K and V projections. Only Q differs per head:

  • Standard: h × (Q, K, V) = 3h projections
  • MQA: h × Q + 1 × K + 1 × V = h + 2 projections

Benefit: Much faster inference (less memory for KV cache). Used in PaLM, Falcon.
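A minimal shape sketch of the MQA idea (hypothetical small dimensions, not from any real model): queries keep one slice per head, while a single K and a single V are broadcast across all heads.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, h = 6, 32, 8
d_k = d_model // h

X = rng.normal(size=(seq_len, d_model))
W_Q = rng.normal(size=(d_model, d_model))  # h query heads
W_K = rng.normal(size=(d_model, d_k))      # ONE shared key projection
W_V = rng.normal(size=(d_model, d_k))      # ONE shared value projection

Q = (X @ W_Q).reshape(seq_len, h, d_k).transpose(1, 0, 2)  # (h, seq, d_k)
K = X @ W_K                                                # (seq, d_k), shared
V = X @ W_V

scores = Q @ K.T / np.sqrt(d_k)            # (h, seq, seq); K broadcasts over heads
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)  # softmax over keys
out = weights @ V                          # (h, seq, d_k); V shared too
print(out.shape)                           # (8, 6, 4)
```

Only one (seq, d_k) K and V need to be cached per layer, instead of h of them, which is where the inference savings come from.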

Grouped Query Attention (GQA)

A compromise: heads are grouped, and each group shares K, V:

  • 8 heads with 2 groups: heads 1-4 share K₁,V₁; heads 5-8 share K₂,V₂
  • Quality close to standard MHA, with much faster inference

Used in: LLaMA 2, Mistral, Gemma
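The grouping itself is just a repeat along the head axis. A sketch of the 8-head, 2-group example above (illustrative shapes only):

```python
import numpy as np

n_heads, n_groups, seq_len, d_k = 8, 2, 6, 4
heads_per_group = n_heads // n_groups  # 4 heads share each K, V

K_groups = np.random.randn(n_groups, seq_len, d_k)         # one K per group
K_per_head = np.repeat(K_groups, heads_per_group, axis=0)  # (n_heads, seq, d_k)

assert np.allclose(K_per_head[0], K_per_head[3])  # heads 1-4 share K₁
assert np.allclose(K_per_head[4], K_per_head[7])  # heads 5-8 share K₂
```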

Why This Matters

During generation, the model caches K and V for all previous tokens. At GPT-3 scale (96 layers × 96 heads × d_k = 128):

| Variant | KV Cache (4K context) |
|---|---|
| Standard MHA | ~24 GB |
| GQA (8 groups) | ~2 GB |
| MQA (1 group) | ~0.25 GB |

This is why efficient attention is critical for deploying large models.
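A rough back-of-envelope for the table above, assuming fp16 (2 bytes per value) and ignoring batch size and implementation overhead, which is why the table's figures run a little higher:

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, bytes_per_val=2):
    # 2x for storing both K and V at every layer, for every cached position
    return 2 * n_layers * n_kv_heads * d_k * seq_len * bytes_per_val

mha = kv_cache_bytes(96, 96, 128, 4096)  # every head caches its own K, V
gqa = kv_cache_bytes(96, 8, 128, 4096)   # only 8 KV groups cached
mqa = kv_cache_bytes(96, 1, 128, 4096)   # one shared KV head cached

print(f"MHA {mha / 1e9:.1f} GB | GQA {gqa / 1e9:.2f} GB | MQA {mqa / 1e9:.2f} GB")
# → MHA 19.3 GB | GQA 1.61 GB | MQA 0.20 GB
```

The ratios are exact: GQA with 8 groups uses 1/12 of the MHA cache, and MQA uses 1/96.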


Key Equations

Multi-Head Attention

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

Per-Head Attention

head_i = Attention(X W_i^Q, X W_i^K, X W_i^V)

Per-Head Dimensions

d_k = d_v = d_model / h


Summary

| Concept | Single Head | Multi-Head |
|---|---|---|
| Perspectives | 1 | h (typically 8-96) |
| Per-head dim | d_model | d_model / h |
| Total params | 4 × d_model² | 4 × d_model² (same!) |
| What it captures | One relationship type | Many relationship types |

Next: Positional Encoding

Self-attention treats all positions equally — it’s permutation-invariant. We need to add position information. That’s Positional Encoding →