Training & Inference

You’ve seen data flow through a transformer. Now let’s understand how it learns and how it generates text in practice.

1. The Training Objective: Next-Token Prediction

The training objective for decoder-only models (GPT, LLaMA) is beautifully simple:

Given all previous tokens, predict the next one.

P(x_1, x_2, \ldots, x_n) = \prod_{i=1}^{n} P(x_i \mid x_1, \ldots, x_{i-1})

How It Works in Practice

Given the training sentence "The cat sat on the mat":

| Position | Input (context) | Target (predict) |
|---|---|---|
| 1 | "The" | cat |
| 2 | "The cat" | sat |
| 3 | "The cat sat" | on |
| 4 | "The cat sat on" | the |
| 5 | "The cat sat on the" | mat |

Every position produces a prediction simultaneously. This is the transformer’s key training advantage over RNNs — all positions are trained in parallel.
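The shifted input/target pairs above are simple to build in code. A minimal sketch, with made-up token IDs standing in for a real tokenizer's output:

```python
# Build next-token training pairs from one tokenized sentence.
# Token IDs are illustrative, not from a real tokenizer.
tokens = [101, 415, 892, 77, 101, 530]  # "The cat sat on the mat"

inputs = tokens[:-1]   # context available at each position
targets = tokens[1:]   # the token each position must predict

for pos in range(len(inputs)):
    print(f"position {pos + 1}: context = {tokens[:pos + 1]}, target = {targets[pos]}")
```

One forward pass scores every position at once; the causal attention mask is what prevents position *i* from peeking at its own target.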

The Loss Function

At each position, we compute cross-entropy loss:

L_t = -\log P(x_{t+1} \mid x_1, \ldots, x_t)

Total loss is the average across all positions:

L = \frac{1}{T} \sum_{t=1}^{T} L_t
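These two formulas can be checked numerically. A sketch with hypothetical probabilities that the model assigns to the correct next token at each position:

```python
import math

# Hypothetical model probabilities for the correct target at each of T = 5 positions
probs_of_target = [0.5, 0.25, 0.8, 0.9, 0.6]

per_position = [-math.log(p) for p in probs_of_target]   # L_t at each position
total = sum(per_position) / len(per_position)            # average over all T positions

print(f"mean cross-entropy loss = {total:.4f}")
```

Note how the position with probability 0.25 contributes the largest term: lower probability on the correct token means higher loss.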

🤔 Quick Check
Why is transformer training faster than RNN training?

2. Teacher Forcing

During training, at each position we use the ground truth previous tokens, not the model’s own predictions.

| Approach | Context used | Risk |
|---|---|---|
| Teacher forcing | Always real tokens | None (clean signal) |
| Free-running | Model’s own predictions | Errors compound at every step |

Why? If the model makes a mistake at position 2, all subsequent predictions would be based on a wrong context. Teacher forcing ensures every position sees the correct context during training.


3. The Training Loop

Training follows the standard deep learning cycle:

| Step | What Happens |
|---|---|
| Forward pass | Run all tokens through the model in parallel |
| Compute loss | Cross-entropy between predictions and targets |
| Backward pass | Compute gradients via backpropagation |
| Update weights | Optimizer adjusts parameters |

Optimization: AdamW

Modern transformers use AdamW optimizer:

  • Adaptive learning rates per parameter
  • Momentum (smooths noisy gradients)
  • Weight decay (regularization)

Learning Rate Schedule

| Phase | What Happens | Why |
|---|---|---|
| Warmup (~2000 steps) | Gradually increase LR from 0 | Prevents early instability |
| Peak | Maximum learning rate | Fastest learning |
| Cosine decay | Smoothly decrease to near 0 | Fine-grained convergence |

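The warmup-then-cosine schedule can be written as a small function. The step counts and peak learning rate below are illustrative, not values from any particular model:

```python
import math

def lr_at(step, peak_lr=3e-4, warmup_steps=2000, total_steps=100_000, min_lr=0.0):
    """Linear warmup from 0 to peak_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Warmup phase: learning rate rises linearly with the step count
        return peak_lr * step / warmup_steps
    # Decay phase: cosine curve from peak_lr down to min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(0), lr_at(2000), lr_at(100_000))  # zero, peak, back to ~zero
```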
# Simplified training loop (PyTorch-style)
for epoch in range(num_epochs):
    for batch in data_loader:
        # 1. Forward pass: predictions for every position in parallel
        logits = model(batch.input_ids)

        # 2. Compute loss between predictions and shifted targets
        loss = cross_entropy(logits, batch.target_ids)

        # 3. Backward pass: compute gradients via backpropagation
        loss.backward()

        # 4. Update parameters, then reset gradients for the next step
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch}: loss = {loss.item():.4f}")

4. Inference: Autoregressive Generation

At inference time, there’s no ground truth. The model generates one token at a time, feeding each prediction back as input:

| Step | Action |
|---|---|
| 1 | Run forward pass on current sequence |
| 2 | Get probability distribution for next token |
| 3 | Sample a token from the distribution |
| 4 | Append to sequence |
| 5 | Repeat until done (or max length) |

Sampling Strategies

| Strategy | How It Works | Tradeoff |
|---|---|---|
| Greedy | Always pick highest probability | Deterministic but repetitive |
| Temperature | Divide logits by T before softmax | Low T = focused, high T = diverse |
| Top-k | Only consider top k tokens | Removes low-probability nonsense |
| Top-p (nucleus) | Consider tokens until cumulative prob ≥ p | Adapts to distribution shape |

🤔 Quick Check
What happens to sampling as temperature approaches 0?
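To see why low temperature concentrates the distribution, apply the divide-by-T step to some illustrative logits:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))  # subtract max for numerical stability
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, 0.1])

for T in (0.1, 1.0, 2.0):
    probs = softmax(logits / T)
    print(f"T={T}: {np.round(probs, 3)}")
```

At T = 0.1 nearly all the probability mass lands on the top logit (greedy decoding in the limit T → 0), while large T flattens the distribution toward uniform.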

def generate(model, prompt_ids, max_tokens=100, temperature=1.0):
    """Autoregressive generation: sample one token at a time, feed it back in."""
    generated = list(prompt_ids)

    for _ in range(max_tokens):
        logits = model.forward(generated)
        next_logits = logits[-1]  # only the last position predicts the next token
        probs = softmax(next_logits / temperature)  # temperature must be > 0

        next_token = np.random.choice(len(probs), p=probs)
        generated.append(next_token)

        if next_token == EOS_TOKEN:  # stop early at end-of-sequence
            break

    return generated

def sample_with_top_p(probs, p=0.9):
    """Nucleus sampling: only consider tokens in top p cumulative probability."""
    sorted_indices = np.argsort(probs)[::-1]
    cumulative = 0
    
    for i, idx in enumerate(sorted_indices):
        cumulative += probs[idx]
        if cumulative >= p:
            allowed = sorted_indices[:i + 1]
            break
    
    filtered = np.zeros_like(probs)
    filtered[allowed] = probs[allowed]
    filtered = filtered / filtered.sum()
    
    return np.random.choice(len(filtered), p=filtered)
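Top-k filtering, listed in the table above, follows the same zero-out-and-renormalize pattern as the nucleus sampler. A sketch, assuming `probs` is a NumPy array that sums to 1:

```python
import numpy as np

def sample_with_top_k(probs, k=50):
    """Keep only the k most probable tokens, renormalize, then sample."""
    top_indices = np.argsort(probs)[::-1][:k]  # indices of the k largest probabilities

    filtered = np.zeros_like(probs)
    filtered[top_indices] = probs[top_indices]
    filtered = filtered / filtered.sum()       # renormalize over the surviving tokens

    return np.random.choice(len(filtered), p=filtered)
```

Unlike top-p, the cutoff here is a fixed count: k tokens survive whether the distribution is sharply peaked or nearly flat.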

5. Putting It All Together

Training Pipeline

| Step | Action |
|---|---|
| 1 | Collect massive text corpus (books, web, code…) |
| 2 | Tokenize entire corpus (BPE with ~50K vocabulary) |
| 3 | Split into batches of sequences (e.g., 2048 tokens each) |
| 4 | For each batch: forward → loss → backward → update |
| 5 | Repeat for millions of steps |
| 6 | Result: a model that can predict next tokens |

Inference Pipeline

| Step | Action |
|---|---|
| 1 | User types a prompt |
| 2 | Tokenize the prompt |
| 3 | Forward pass through all transformer layers |
| 4 | Sample next token from output distribution |
| 5 | Append token to sequence |
| 6 | Repeat steps 3–5 until done |
| 7 | Detokenize back to text |
| 8 | Return response to user |

The Price of Intelligence

| Model | Training Compute | Estimated Cost | Training Time |
|---|---|---|---|
| GPT-2 | ~40 petaflop-days | ~$50K | Days |
| GPT-3 | ~3,640 petaflop-days | ~$4.6M | Weeks |
| LLaMA 65B | ~6,300 petaflop-days | ~$2.4M | 21 days on 2048 A100s |
| GPT-4 | Unknown | ~$100M+ (estimated) | Months |

Note: These cost estimates are approximate and shift rapidly as hardware and techniques improve.

The trend: exponentially more compute, but also exponentially more capable models.

Chinchilla Scaling Laws

DeepMind’s Chinchilla paper (2022) showed that for a given compute budget, you should:

  • Train a smaller model on more data than previously thought
  • Optimal: ~20 tokens per parameter

This shifted the field toward smaller, better-trained models (LLaMA) rather than just making models bigger.
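The 20-tokens-per-parameter heuristic makes the arithmetic easy to sanity-check:

```python
def chinchilla_optimal_tokens(n_params, tokens_per_param=20):
    """Approximate compute-optimal training tokens per the Chinchilla heuristic."""
    return n_params * tokens_per_param

# A 70B-parameter model would want roughly 1.4 trillion training tokens,
# which matches the ~1.4T tokens Chinchilla (70B) was actually trained on.
print(f"{chinchilla_optimal_tokens(70e9):.2e}")
```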


Summary

| Aspect | Training | Inference |
|---|---|---|
| Direction | Forward + backward pass | Forward only |
| Parallelism | All positions in parallel | One token at a time |
| Input | Ground truth tokens (teacher forcing) | Model’s own predictions |
| Output | Loss & gradients | Next token probability |
| Speed | Slow (backprop is expensive) | Fast per token (with KV cache) |
| Hardware | Many GPUs for weeks | Single GPU, real-time |

Congratulations!

You now understand the complete transformer architecture, from individual components to training and inference. The next modules explore two major transformer variants: BERT (encoder-only, for understanding) and GPT (decoder-only, for generation).