Training & Inference
You’ve seen data flow through a transformer. Now let’s understand how it learns and how it generates text in practice.
1. The Training Objective: Next-Token Prediction
The training objective for decoder-only models (GPT, LLaMA) is beautifully simple:
Given all previous tokens, predict the next one.
How It Works in Practice
Given the training sentence "The cat sat on the mat":
| Position | Input (context) | Target (predict) |
|---|---|---|
| 1 | “The” | cat |
| 2 | “The cat” | sat |
| 3 | “The cat sat” | on |
| 4 | “The cat sat on” | the |
| 5 | “The cat sat on the” | mat |
Every position produces a prediction simultaneously. This is the transformer’s key training advantage over RNNs — all positions are trained in parallel.
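In code, building these training pairs is a single shift: the inputs are every token but the last, and the targets are the same sequence offset by one. A minimal sketch (the token ids are made up for illustration):

```python
# Toy token ids standing in for "The cat sat on the mat"
# (the ids are invented for illustration)
tokens = [464, 3797, 3332, 319, 262, 2603]

# One sequence yields len(tokens) - 1 training predictions at once
input_ids = tokens[:-1]   # context available at each position
target_ids = tokens[1:]   # the token each position must predict

print(input_ids)   # [464, 3797, 3332, 319, 262]
print(target_ids)  # [3797, 3332, 319, 262, 2603]
```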
The Loss Function
At each position t, we compute the cross-entropy loss: the negative log-probability the model assigned to the true next token,

loss_t = −log P(token_{t+1} | token_1, …, token_t)

The total loss is the average across all T positions:

loss = (1/T) · (loss_1 + loss_2 + … + loss_T)
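A minimal NumPy sketch of this loss, assuming logits of shape (seq_len, vocab_size) and one target id per position:

```python
import numpy as np

def cross_entropy_loss(logits, targets):
    """Average negative log-probability of the correct next token.

    logits:  (seq_len, vocab_size) raw model scores
    targets: (seq_len,) correct next-token ids
    """
    # Numerically stable softmax over the vocabulary
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Probability the model assigned to each correct token
    target_probs = probs[np.arange(len(targets)), targets]
    return -np.log(target_probs).mean()

# 3 positions, vocab of 5; confident correct predictions give a small loss
logits = np.array([[5.0, 0, 0, 0, 0],
                   [0, 5.0, 0, 0, 0],
                   [0, 0, 5.0, 0, 0]])
print(cross_entropy_loss(logits, np.array([0, 1, 2])))  # ≈ 0.027
```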
2. Teacher Forcing
During training, at each position we use the ground truth previous tokens, not the model’s own predictions.
| Approach | Context used | Risk |
|---|---|---|
| Teacher forcing ✅ | Always real tokens | None — clean signal |
| Free-running ❌ | Model’s own predictions | Errors compound at every step |
Why? If the model makes a mistake at position 2, all subsequent predictions would be based on a wrong context. Teacher forcing ensures every position sees the correct context during training.
3. The Training Loop
Training follows the standard deep learning cycle:
| Step | What Happens |
|---|---|
| Forward pass | Run all tokens through the model in parallel |
| Compute loss | Cross-entropy between predictions and targets |
| Backward pass | Compute gradients via backpropagation |
| Update weights | Optimizer adjusts parameters |
Optimization: AdamW
Modern transformers are typically trained with the AdamW optimizer:
- Adaptive learning rates per parameter
- Momentum (smooths noisy gradients)
- Weight decay (regularization)
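The three ingredients above combine into a single update rule. A minimal NumPy sketch of one AdamW step (hyperparameter defaults are illustrative, not the values used by any particular model):

```python
import numpy as np

def adamw_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update; m and v are running moment estimates, t counts steps from 1."""
    m = beta1 * m + (1 - beta1) * grad        # momentum: smoothed gradient
    v = beta2 * v + (1 - beta2) * grad ** 2   # per-parameter adaptive scale
    m_hat = m / (1 - beta1 ** t)              # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: shrink the weights directly, not via the gradient
    param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
    return param, m, v

# Minimize f(w) = w^2 (gradient 2w): the updates drive w toward 0
w, m, v = 5.0, 0.0, 0.0
for t in range(1, 1001):
    w, m, v = adamw_step(w, 2 * w, m, v, t, lr=0.1)
```

Real implementations apply this element-wise to every parameter tensor; the scalar version above is just the update rule made visible.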
Learning Rate Schedule
| Phase | What Happens | Why |
|---|---|---|
| Warmup (~2000 steps) | Gradually increase LR from 0 | Prevents early instability |
| Peak | Maximum learning rate | Fastest learning |
| Cosine decay | Smoothly decrease to near 0 | Fine-grained convergence |
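The whole schedule fits in a few lines. A sketch assuming linear warmup and cosine decay to zero, with step counts taken from the table (the peak rate and total steps are illustrative):

```python
import math

def lr_at_step(step, peak_lr=3e-4, warmup_steps=2000, total_steps=100_000):
    """Linear warmup to peak_lr, then cosine decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps                   # warmup phase
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))  # cosine decay

print(lr_at_step(1000))     # halfway through warmup: 1.5e-4
print(lr_at_step(2000))     # peak: 3e-4
print(lr_at_step(100_000))  # end of training: ~0
```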
```python
# Simplified training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        # 1. Forward pass
        logits = model.forward(batch.input_ids)
        # 2. Compute loss
        loss = cross_entropy(logits, batch.target_ids)
        # 3. Backward pass (compute gradients)
        loss.backward()
        # 4. Update parameters
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch}: loss = {loss:.4f}")
```
4. Inference: Autoregressive Generation
At inference time, there’s no ground truth. The model generates one token at a time, feeding each prediction back as input:
| Step | Action |
|---|---|
| 1 | Run forward pass on current sequence |
| 2 | Get probability distribution for next token |
| 3 | Sample a token from the distribution |
| 4 | Append to sequence |
| 5 | Repeat until done (or max length) |
Sampling Strategies
| Strategy | How It Works | Tradeoff |
|---|---|---|
| Greedy | Always pick highest probability | Deterministic but repetitive |
| Temperature | Divide logits by T before softmax | Low T = focused, high T = diverse |
| Top-k | Only consider top k tokens | Removes low-probability nonsense |
| Top-p (nucleus) | Consider tokens until cumulative prob ≥ p | Adapts to distribution shape |
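Temperature and top-k from the table compose naturally into one sampler. A hedged sketch (the function name and defaults are our own, not from any particular library):

```python
import numpy as np

def sample_top_k(logits, k=50, temperature=1.0):
    """Apply temperature, keep the k most likely tokens, sample among them."""
    logits = np.asarray(logits, dtype=float) / temperature
    top_k_indices = np.argsort(logits)[-k:]   # indices of the k highest logits
    masked = np.full_like(logits, -np.inf)    # every other token gets -inf...
    masked[top_k_indices] = logits[top_k_indices]
    probs = np.exp(masked - masked.max())     # ...so its probability becomes 0
    probs = probs / probs.sum()
    return np.random.choice(len(probs), p=probs)
```

With k=1 this reduces to greedy decoding; with k equal to the vocabulary size it reduces to plain temperature sampling.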
```python
def generate(model, prompt_ids, max_tokens=100, temperature=1.0):
    generated = list(prompt_ids)
    for _ in range(max_tokens):
        logits = model.forward(generated)
        next_logits = logits[-1]  # Distribution comes from the last position
        probs = softmax(next_logits / temperature)
        next_token = sample(probs)
        generated.append(next_token)
        if next_token == EOS_TOKEN:
            break
    return generated
```
```python
import numpy as np

def sample_with_top_p(probs, p=0.9):
    """Nucleus sampling: only consider tokens in the top-p cumulative probability."""
    sorted_indices = np.argsort(probs)[::-1]  # Most likely tokens first
    cumulative = 0.0
    for i, idx in enumerate(sorted_indices):
        cumulative += probs[idx]
        if cumulative >= p:
            allowed = sorted_indices[:i + 1]
            break
    # Zero out everything outside the nucleus, then renormalize
    filtered = np.zeros_like(probs)
    filtered[allowed] = probs[allowed]
    filtered = filtered / filtered.sum()
    return np.random.choice(len(filtered), p=filtered)
```
5. Putting It All Together
Training Pipeline
| Step | Action |
|---|---|
| 1 | Collect massive text corpus (books, web, code…) |
| 2 | Tokenize entire corpus (BPE with ~50K vocabulary) |
| 3 | Split into batches of sequences (e.g., 2048 tokens each) |
| 4 | For each batch: forward → loss → backward → update |
| 5 | Repeat for millions of steps |
| 6 | Result: a model that can predict next tokens |
Inference Pipeline
| Step | Action |
|---|---|
| 1 | User types a prompt |
| 2 | Tokenize the prompt |
| 3 | Forward pass through all transformer layers |
| 4 | Sample next token from output distribution |
| 5 | Append token to sequence |
| 6 | Repeat steps 3–5 until done |
| 7 | Detokenize back to text |
| 8 | Return response to user |
The Price of Intelligence
| Model | Training Compute | Estimated Cost | Training Time |
|---|---|---|---|
| GPT-2 | ~40 petaflop-days | ~$50K | Days |
| GPT-3 | ~3,640 petaflop-days | ~$4.6M | Weeks |
| LLaMA 65B | ~6,300 petaflop-days | ~$2.4M | 21 days on 2048 A100s |
| GPT-4 | Unknown | $100M+ (estimated) | Months |
Note: These cost estimates are approximate and shift rapidly as hardware and techniques improve.
The trend: each generation uses orders of magnitude more compute, and so far capability has kept climbing with it.
Chinchilla Scaling Laws
DeepMind’s Chinchilla paper (2022) showed that for a given compute budget, you should:
- Train a smaller model on more data than previously thought
- Optimal: ~20 tokens per parameter
This shifted the field toward smaller, better-trained models (LLaMA) rather than just making models bigger.
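As a back-of-the-envelope check, the common approximation C ≈ 6·N·D (training compute ≈ 6 × parameters × tokens) combined with the Chinchilla rule D ≈ 20·N lets you solve for the compute-optimal model size. A sketch (the constants are approximations, not exact values):

```python
import math

def chinchilla_optimal(compute_flops, tokens_per_param=20):
    """Solve C = 6*N*D with D = tokens_per_param * N for N (params) and D (tokens)."""
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

# Roughly Chinchilla's own training budget
n, d = chinchilla_optimal(5.8e23)
print(f"{n / 1e9:.0f}B params, {d / 1e12:.1f}T tokens")  # ≈ 70B params, 1.4T tokens
```

Plugging in approximately Chinchilla's own budget recovers its published configuration (70B parameters, 1.4T tokens), a reassuring sanity check on the rule of thumb.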
Summary
| Aspect | Training | Inference |
|---|---|---|
| Direction | Forward + backward pass | Forward only |
| Parallelism | All positions in parallel | One token at a time |
| Input | Ground truth tokens (teacher forcing) | Model’s own predictions |
| Output | Loss & gradients | Next token probability |
| Speed | Slow (backprop is expensive) | Fast per token (with KV cache) |
| Hardware | Many GPUs for weeks | Single GPU, real-time |
Congratulations!
You now understand the complete transformer architecture — from individual components to training and inference. The next modules explore two major transformer variants: BERT (encoder-only, for understanding) and GPT (decoder-only, for generation). →