Layer Norm & Residuals
Why Do We Need These?
Imagine building GPT-3 with 96 layers. Without two critical tricks, your model would be completely untrainable:
- Without normalization: Activations explode to infinity or collapse to zero as data flows through dozens of layers
- Without residual connections: Gradients vanish during backpropagation — early layers receive near-zero gradients and can’t learn
These aren’t fancy optimizations — they’re essential prerequisites for any deep transformer.
Residual Connections: The Gradient Highway
The Problem
In a deep network without skip connections, the gradient must flow backward through every layer's transformation:

$$\frac{\partial L}{\partial x_0} = \frac{\partial L}{\partial x_N} \prod_{i=1}^{N} \frac{\partial F_i}{\partial x_{i-1}}$$

If each factor $\frac{\partial F_i}{\partial x_{i-1}}$ is slightly less than 1, the product vanishes exponentially. Sound familiar? It's the same vanishing gradient problem we saw with RNNs, but across layers instead of across time!
The Solution: Skip Connections
Instead of replacing $x$ with $F(x)$, we add $F(x)$ to the original $x$:

$$y = x + F(x)$$

The gradient through this operation is:

$$\frac{\partial y}{\partial x} = I + \frac{\partial F}{\partial x}$$

Even if $\frac{\partial F}{\partial x}$ is tiny, the gradient is still approximately 1! The identity connection creates a "highway" for gradients to flow back unimpeded.
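The effect is easy to see numerically. Here is a minimal scalar sketch (96 layers as in the GPT-3 example; the 0.01 local gradient is an illustrative assumption, not a measured value):

```python
layers = 96
df_dx = 0.01  # tiny local gradient dF/dx of each sublayer

# Without skips: the backward pass multiplies 96 tiny factors together
plain = df_dx ** layers

# With skips: each factor becomes (1 + dF/dx), anchored near 1 by the identity
with_skip = (1 + df_dx) ** layers

print(plain)      # vanishes to numerically zero (~1e-192)
print(with_skip)  # stays a healthy O(1) number
```

The plain product underflows long before layer 96, while the residual product barely moves away from 1.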
See It In Action
Toggle between “With Residuals” and “Without Residuals” to see how gradient magnitude changes across layers. Try increasing the layer count to see how the problem gets worse:
Gradient Flow Through Layers
Layer Normalization
The Problem: Internal Covariate Shift
As activations flow through layers, their distribution shifts:
| Layer | Mean | Std Dev | Status |
|---|---|---|---|
| 1 | 0.0 | 1.0 | ✅ Healthy |
| 5 | 3.2 | 15.7 | ⚠️ Drifting |
| 10 | 87 | 4312 | ❌ Exploding! |
Later layers must constantly adapt to a moving target. Training becomes slow and unstable.
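The kind of drift shown in the table can be reproduced in miniature by pushing a vector through random linear layers whose weights are scaled slightly too large (the dimension 256 and the 1.5x scale factor are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 256
x = rng.normal(size=d)  # healthy input: mean ~0, std ~1

for layer in range(1, 11):
    # Slightly mis-scaled weights: each layer inflates the std by ~1.5x
    W = rng.normal(scale=1.5 / np.sqrt(d), size=(d, d))
    x = W @ x
    print(f"layer {layer:2d}: mean = {x.mean():+7.2f}   std = {x.std():8.2f}")
```

After ten layers the statistics have drifted far from the healthy (0, 1) starting point, and the growth compounds with depth.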
The Solution: Normalize Each Vector
Layer normalization normalizes each vector independently across its features:

$$\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where:
- $\mu = \frac{1}{d}\sum_{i=1}^{d} x_i$ (mean of the vector)
- $\sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2$ (variance)
- $\gamma, \beta$ are learnable parameters (scale and shift)
- $\epsilon$ is a small constant (e.g., $10^{-6}$) for numerical stability
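A worked example of the formula on a 4-dimensional vector, with $\gamma = 1$ and $\beta = 0$ (their values at initialization):

```python
import numpy as np

x = np.array([2.0, 4.0, 6.0, 8.0])

mu = x.mean()    # 5.0
var = x.var()    # 5.0
x_norm = (x - mu) / np.sqrt(var + 1e-6)

print(x_norm)                        # ≈ [-1.342 -0.447  0.447  1.342]
print(x_norm.mean(), x_norm.std())   # ≈ 0.0 and ≈ 1.0
```

Whatever scale the input arrives at, the output always has mean 0 and standard deviation 1 before the learned scale and shift are applied.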
Interactive: Watch Normalization Step by Step
Click through the tabs to see how layer normalization transforms a vector. Adjust γ and β to see the effect of the learned scale and shift:
Layer Normalization Step by Step
```python
import numpy as np

class LayerNorm:
    def __init__(self, d_model, eps=1e-6):
        self.gamma = np.ones(d_model)   # Learnable scale
        self.beta = np.zeros(d_model)   # Learnable shift
        self.eps = eps

    def forward(self, x):
        # Statistics over the feature dimension: one (mean, var) per vector
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        x_norm = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta
```

Why Not Batch Normalization?
You might know Batch Normalization from CNNs. It normalizes across the batch dimension. But for transformers:
| Aspect | BatchNorm | LayerNorm |
|---|---|---|
| Normalizes across | Batch (examples) | Features (dimensions) |
| Requires | Large batch size | Works with any batch size |
| At inference | Needs running statistics | No extra state needed |
| For sequences | Problematic (variable length) | Works naturally |
LayerNorm is simpler and works better for transformers.
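The first row of the table comes down to a single axis choice. A sketch of the difference for a batch of shape (batch, features), with both norms stripped to their statistics (no learned parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 16))  # batch of 8 vectors

# BatchNorm: per-feature statistics computed ACROSS the batch (axis 0)
bn = (x - x.mean(axis=0)) / x.std(axis=0)

# LayerNorm: per-example statistics computed ACROSS features (axis -1)
ln = (x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)

# LayerNorm needs no batch at all: one example alone gives the same answer
single = (x[0] - x[0].mean()) / x[0].std()
print(np.allclose(single, ln[0]))  # True
```

BatchNorm's statistics change whenever the batch changes, which is why it needs running statistics at inference; LayerNorm's depend only on the vector itself.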
Putting It Together: The Transformer Block
Every transformer block uses both tricks. The standard pattern (Pre-LN):

$$x \leftarrow x + \text{Attention}(\text{LN}_1(x))$$
$$x \leftarrow x + \text{FFN}(\text{LN}_2(x))$$
Pre-LN vs Post-LN
| Pre-LN (modern) | Post-LN (original) | |
|---|---|---|
| Order | LN → Sublayer → Add | Sublayer → Add → LN |
| Stability | More stable | Less stable |
| Warmup needed | Optional | Critical |
| Used by | GPT-2, GPT-3, LLaMA | Original Transformer, BERT |
```python
class TransformerBlock:
    def __init__(self, d_model, n_heads, d_ff):
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.ln1 = LayerNorm(d_model)
        self.ln2 = LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Sub-layer 1: attention with pre-norm and residual
        x = x + self.attn(self.ln1(x), mask)
        # Sub-layer 2: FFN with pre-norm and residual
        x = x + self.ffn(self.ln2(x))
        return x
```

RMSNorm
Formula

$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\text{RMS}(x)}, \qquad \text{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$
RMSNorm skips the mean subtraction — it only normalizes by the root-mean-square. No bias parameter either.
Why use it? About 10% faster than LayerNorm (one fewer operation), and empirically works just as well.
Used by: LLaMA, LLaMA 2, Gemma, Mistral
```python
class RMSNorm:
    def __init__(self, d_model, eps=1e-6):
        self.gamma = np.ones(d_model)  # Learnable scale (no shift/beta)
        self.eps = eps

    def forward(self, x):
        # Root-mean-square over the feature dimension; no mean subtraction
        rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return self.gamma * x / rms
```

Summary
| Technique | What It Does | Why It Matters |
|---|---|---|
| Residual Connection | Adds the input back to each sub-layer's output | Gradient highway that prevents vanishing gradients |
| Layer Normalization | Normalize each vector to mean 0, std 1 | Prevents activation explosion, stabilizes training |
| Together | Used in every transformer block | Makes 96+ layer networks trainable |
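A toy sanity check of that last row: stack 96 of the same mis-scaled random layers from earlier, once plain and once wrapped in a pre-LN residual block. The dimension and weight scale are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 256

def layer_norm(x):
    # LayerNorm with gamma=1, beta=0
    return (x - x.mean()) / np.sqrt(x.var() + 1e-6)

x_plain = rng.normal(size=d)
x_res = x_plain.copy()

for _ in range(96):
    W = rng.normal(scale=1.5 / np.sqrt(d), size=(d, d))
    x_plain = W @ x_plain                  # no residual, no norm
    x_res = x_res + W @ layer_norm(x_res)  # pre-LN residual block

print(f"plain stack:   std = {x_plain.std():.3g}")  # astronomically large
print(f"residual + LN: std = {x_res.std():.3g}")    # modest and finite
```

The plain stack explodes exponentially with depth; the residual-plus-norm stack grows only slowly, which is exactly what makes very deep transformers trainable.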
Key Equations
Residual Connection

$$y = x + F(x)$$

Layer Normalization

$$\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Pre-LN Transformer Block

$$x \leftarrow x + \text{Attention}(\text{LN}_1(x)), \qquad x \leftarrow x + \text{FFN}(\text{LN}_2(x))$$
Next: The Full Transformer
You now know every component. It’s time to put them all together into the complete transformer architecture →