Ch 6 — The Training Recipe

How LLMs learn from raw text — next-token prediction, loss functions, optimizers, and data pipelines
The training pipeline: Objective → Loss → Backprop → Optimizer → Schedule → Data → Stability → Training Loop → Evaluation
The Training Objective: Next-Token Prediction
The deceptively simple task that teaches LLMs everything
The Analogy
Imagine a game of fill-in-the-blank, played trillions of times. Given “The capital of France is ___”, the model must predict “Paris.” Given “def fibonacci(n): return n if n <= 1 else ___”, it must predict the recursive call. By predicting the next word across trillions of sentences from books, code, Wikipedia, and the web, the model implicitly learns grammar, facts, reasoning, coding, math, and even common sense. No labels needed — the text itself is the teacher.
Key insight: Next-token prediction is a form of self-supervised learning: the training signal comes from the data itself. Every token in the training set becomes both an input (context) and a label (target). A single document of 1,000 tokens provides 999 training examples. This is why LLMs can be trained on the entire internet without human annotation — the text supervises itself.
How It Works
# Input:  "The cat sat on the"
# Target: "cat sat on the mat"   (shifted by one position)
#
# At each position, the model predicts a probability
# distribution over the entire vocabulary (~100K-200K tokens)
#
# Position 1: "The"                → predict "cat"
# Position 2: "The cat"            → predict "sat"
# Position 3: "The cat sat"        → predict "on"
# Position 4: "The cat sat on"     → predict "the"
# Position 5: "The cat sat on the" → predict "mat"
#
# The model outputs logits for EVERY token
# in the vocabulary at each position.
# logits shape: (batch, seq_len, vocab_size)
# e.g., (4, 2048, 128256) for Llama 3
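The shift-by-one construction is a one-liner in practice. A minimal sketch on raw token-id lists (the ids are made up; a real pipeline would get them from a tokenizer):

```python
def make_training_pairs(token_ids):
    """Shift by one: each position's input is the prefix, its label the next token."""
    inputs = token_ids[:-1]   # e.g. "The cat sat on the"
    targets = token_ids[1:]   # e.g. "cat sat on the mat"
    return inputs, targets

doc = [101, 202, 303, 404, 505, 606]   # a toy 6-token "document"
inputs, targets = make_training_pairs(doc)
# A document of N tokens yields N-1 (input, label) pairs,
# which is why 1,000 tokens give 999 training examples.
assert len(inputs) == len(targets) == len(doc) - 1
```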
Cross-Entropy Loss: Measuring Wrongness
The single number that tells the model how to improve
The Analogy
Imagine a weather forecaster who says “90% chance of sun” and it rains. They were confidently wrong — that’s a high loss. If they said “50% sun, 50% rain,” the loss is lower (less confident, less penalty). Cross-entropy loss penalizes the model based on how much probability it assigned to the correct answer. Assign 0.01 to the right token? Huge loss. Assign 0.99? Tiny loss. The formula: L = −log(p_correct).
Key insight: Cross-entropy connects directly to information theory (MathForAI Ch 7). The loss, measured in nats, is the average code length the model needs per token; the gap above the entropy of the text measures the model's inefficiency. GPT-3's training loss of ~2.7 nats per token corresponds to about 3.9 bits per token. A loss of 0 would mean perfect prediction. The irreducible loss (the entropy of natural language itself) is estimated at ~1.0-1.5 nats — no model can go below this.
Worked Example
# Model predicts next token after "The cat"
# Vocabulary: [sat, ran, the, dog, ...]
# True next token: "sat" (index 0)

# Model outputs logits: [2.5, 1.0, 0.3, -0.5]
# Softmax → probabilities:
#   P(sat)=0.72, P(ran)=0.16, P(the)=0.08, ...

# Cross-entropy loss:
#   L = -log(P(correct)) = -log(0.72) = 0.33

# If the model were more confident:
#   P(sat)=0.95 → L = -log(0.95) = 0.05 ✓
# If the model were wrong:
#   P(sat)=0.01 → L = -log(0.01) = 4.6 ✗

import torch.nn.functional as F

logits = model(input_ids)   # (B, S, V)
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    target_ids.view(-1)
)
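The softmax and loss arithmetic can be verified in a few lines of plain Python, no model required:

```python
import math

def softmax(logits):
    # Exponentiate each score, then normalize to sum to 1
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.5, 1.0, 0.3, -0.5]   # scores for [sat, ran, the, dog]
probs = softmax(logits)
loss = -math.log(probs[0])       # "sat" is the correct token (index 0)
# probs[0] ≈ 0.72, loss ≈ 0.33
```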
Backpropagation: Tracing Blame
How the model figures out which parameters to adjust
The Analogy
A soccer team loses 3-0. The coach reviews the tape: the goalkeeper let in an easy shot (high blame), the striker missed chances (medium blame), the midfielder played well (low blame). Backpropagation does the same: it traces the loss backward through every layer, computing how much each parameter contributed to the error. Parameters that caused more error get larger adjustments. This uses the chain rule from calculus (MathForAI Ch 4).
Key insight: For an 8B parameter model, backprop computes 8 billion gradients — one for every parameter. This requires ~2× the compute of the forward pass (the “6ND” rule: 2N for forward, 4N for backward per token). The gradients tell us the direction to adjust each parameter to reduce loss. The optimizer decides how much to adjust.
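The 6ND rule turns into a one-line compute estimate. A back-of-envelope sketch, using the 8B-parameter and 15T-token figures quoted in this chapter:

```python
def training_flops(n_params, n_tokens):
    # 2N FLOPs/token for the forward pass + 4N for the backward = 6N per token
    return 6 * n_params * n_tokens

flops = training_flops(8e9, 15e12)   # Llama 3 8B on 15T tokens
# ≈ 7.2e23 FLOPs for the full pretraining run
```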
The Chain Rule in Action
# Forward pass: compute predictions
logits = model(input_ids)
loss = F.cross_entropy(
    logits.view(-1, vocab_size), targets.view(-1)
)

# Backward pass: compute all gradients
loss.backward()

# PyTorch automatically computes
# ∂L/∂w for every parameter w in the model.

# Chain rule example (simplified):
#   L = f(g(h(x)))          (nested layers)
#   ∂L/∂x = ∂f/∂g · ∂g/∂h · ∂h/∂x

# For Llama 3 8B:
# - 8 billion parameters
# - 8 billion gradients computed
# - Flows backward through 32 layers
# - Residual connections help gradients
#   flow without vanishing (Ch 4)

# Memory: activations from the forward pass must be
# stored for the backward computation
# → training uses ~3× more memory than inference
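PyTorch's autograd applies the chain rule for you, but the rule itself is easy to check by hand. An illustrative sketch with made-up nested functions (h, g, f are stand-ins for layers), comparing the hand-derived gradient to a finite-difference estimate:

```python
# L = f(g(h(x))) with h(x) = 2x, g(u) = u², f(v) = 3v
h = lambda x: 2 * x
g = lambda u: u ** 2
f = lambda v: 3 * v
L = lambda x: f(g(h(x)))

def grad_chain(x):
    # Chain rule: ∂L/∂x = ∂f/∂g · ∂g/∂h · ∂h/∂x = 3 · 2·h(x) · 2
    return 3 * (2 * h(x)) * 2

x = 1.5
eps = 1e-6
numeric = (L(x + eps) - L(x - eps)) / (2 * eps)   # finite difference
# grad_chain(1.5) = 36, and numeric ≈ 36 — the two agree
```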
AdamW: The Optimizer That Trains Every LLM
Adaptive learning rates with weight decay
The Analogy
Imagine navigating a hilly landscape in fog. Basic gradient descent is like always taking the same-sized step downhill. Adam is smarter: it remembers which direction you’ve been going (momentum) and how steep the terrain has been (adaptive step size). Flat areas? Take bigger steps. Steep, noisy areas? Take smaller, careful steps. AdamW adds “weight decay” — a gentle pull toward zero that prevents parameters from growing too large (regularization).
Key insight: AdamW stores two extra values per parameter: the first moment (mean of gradients) and second moment (mean of squared gradients). For an 8B model, that’s 8B × 2 = 16B extra values. At FP32, the optimizer states alone need 64 GB of memory — more than the model weights! This is why training requires far more memory than inference.
AdamW in Practice
# AdamW update rule (per parameter):
#   m  = β₁·m + (1-β₁)·g        (momentum)
#   v  = β₂·v + (1-β₂)·g²       (variance)
#   m̂  = m / (1-β₁ᵗ)            (bias correct)
#   v̂  = v / (1-β₂ᵗ)            (bias correct)
#   w  = w - lr·(m̂/(√v̂+ε) + λ·w)

# Typical hyperparameters (Llama 3):
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1.5e-4,            # peak learning rate
    betas=(0.9, 0.95),    # momentum params
    eps=1e-8,             # numerical stability
    weight_decay=0.1      # regularization
)

# Memory per parameter:
#   Weight:   2 bytes (BF16)
#   Gradient: 2 bytes (BF16)
#   Adam m:   4 bytes (FP32)
#   Adam v:   4 bytes (FP32)
#   Total:    12 bytes/param
# 8B model: 96 GB just for training state!
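The update rule can be run by hand on a single scalar parameter. A minimal sketch that transcribes the commented formulas directly (betas match the Llama-style values above; the starting weight and gradient are made up):

```python
def adamw_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.95, eps=1e-8, wd=0.1):
    m = b1 * m + (1 - b1) * g        # first moment (momentum)
    v = b2 * v + (1 - b2) * g * g    # second moment (variance)
    m_hat = m / (1 - b1 ** t)        # bias correction (t = step number)
    v_hat = v / (1 - b2 ** t)
    # Decoupled weight decay: the λ·w term is applied outside the
    # adaptive ratio — this is what distinguishes AdamW from Adam
    w = w - lr * (m_hat / (v_hat ** 0.5 + eps) + wd * w)
    return w, m, v

w, m, v = adamw_step(w=1.0, g=0.5, m=0.0, v=0.0, t=1)
# At t=1 bias correction makes m̂ = g and v̂ = g², so the adaptive
# ratio is ≈1 and the update is ≈ lr·(1 + 0.1·w) → w ≈ 0.9989
```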
Learning Rate Schedule: Warmup + Cosine Decay
Start slow, peak, then gradually cool down
The Analogy
Learning rate is like the speed of a car on a mountain road. At the start (warmup), you accelerate gradually — going too fast before you know the terrain causes crashes (loss spikes). At peak speed, you’re making maximum progress. Then you slowly decelerate (cosine decay) for the final stretch, making finer and finer adjustments. This warmup → peak → decay pattern is used by virtually every LLM.
Key insight: The warmup phase is critical for training stability. Without it, the randomly initialized model receives huge gradient updates that can cause irreversible divergence. Llama 3 uses 2,000 warmup steps out of ~millions total. The cosine decay ensures the model makes progressively smaller adjustments as it converges, similar to simulated annealing in optimization (MathForAI Ch 10).
The Schedule
# Learning rate schedule:
#
# lr │    ╭─╮
#    │   /   ╲
#    │  /      ╲
#    │ /         ╲____
#    └──────────────────── step
#     warmup  cosine decay
#
# Llama 3 schedule:
#   Warmup: 2,000 steps (linear ramp)
#   Peak lr: 1.5e-4
#   Decay: cosine to 1.5e-5 (10% of peak)
#   Total steps: ~millions

import math

def cosine_schedule(step, warmup, total, lr_max, lr_min):
    if step < warmup:
        return lr_max * step / warmup
    progress = (step - warmup) / (total - warmup)
    return lr_min + 0.5 * (lr_max - lr_min) * (
        1 + math.cos(math.pi * progress)
    )
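The schedule's key checkpoints can be sanity-checked numerically. The function is restated here so the block runs standalone, with the total step count shrunk to 10,000 for illustration:

```python
import math

def cosine_schedule(step, warmup, total, lr_max, lr_min):
    if step < warmup:
        return lr_max * step / warmup          # linear ramp up
    progress = (step - warmup) / (total - warmup)
    return lr_min + 0.5 * (lr_max - lr_min) * (
        1 + math.cos(math.pi * progress)       # smooth decay to lr_min
    )

args = dict(warmup=2000, total=10_000, lr_max=1.5e-4, lr_min=1.5e-5)
mid_warmup = cosine_schedule(1000, **args)   # halfway up the ramp
peak       = cosine_schedule(2000, **args)   # exactly lr_max
mid_decay  = cosine_schedule(6000, **args)   # midpoint of the cosine
final      = cosine_schedule(10_000, **args) # exactly lr_min
```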
Training Data: The Secret Ingredient
Where 15 trillion tokens come from and how they’re prepared
The Analogy
Training data is the curriculum for the model. Just as a student who only reads comic books won’t write great essays, a model trained on low-quality web scrapes won’t produce high-quality output. Modern LLM training involves massive data engineering: crawling the web, deduplicating, filtering toxic/low-quality content, mixing in curated sources (books, code, academic papers), and carefully balancing the proportions.
Key insight: Llama 3 used 15T tokens from a mix of web data (Common Crawl), code (GitHub), books, academic papers, Wikipedia, and more. The data pipeline includes: language detection, quality filtering (perplexity-based), deduplication (MinHash), PII removal, and domain-specific heuristics. Data quality is now considered more important than model architecture — it’s the biggest differentiator between models.
Data Sources & Mixing
# Typical data mix (approximate):
#   Web text:     ~67%  (Common Crawl, filtered)
#   Code:         ~17%  (GitHub, Stack Overflow)
#   Books:         ~5%  (BookCorpus, etc.)
#   Academic:      ~5%  (ArXiv, papers)
#   Wikipedia:     ~3%  (high quality)
#   Math/Science:  ~3%  (curated)

# Data pipeline stages:
#   1. Crawl: ~100T raw tokens
#   2. Language filter: keep English + top langs
#   3. Quality filter: perplexity scoring
#   4. Dedup: exact + fuzzy (MinHash)
#   5. Safety filter: remove toxic content
#   6. PII removal: emails, phones, etc.
#   7. Domain mixing: set proportions
# Result: ~15T clean tokens

# Key open datasets:
#   FineWeb (HuggingFace): 15T tokens
#   The Pile (EleutherAI): 800B tokens
#   RedPajama: 1.2T tokens
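The fuzzy-dedup stage can be illustrated with a toy MinHash. This is a sketch of the idea only: character shingles plus seeded MD5 hashes stand in for random hash permutations, and production pipelines additionally use banded locality-sensitive hashing to compare billions of documents without all-pairs scans:

```python
import hashlib

def shingles(text, k=5):
    """Break a document into overlapping k-character pieces."""
    return {text[i:i + k] for i in range(len(text) - k + 1)}

def minhash(shingle_set, n_hashes=64):
    """One minimum per seeded hash ≈ one random permutation's minimum."""
    return [
        min(int(hashlib.md5(f"{seed}:{s}".encode()).hexdigest(), 16)
            for s in shingle_set)
        for seed in range(n_hashes)
    ]

def est_jaccard(sig_a, sig_b):
    """Fraction of matching signature slots estimates Jaccard similarity."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)

a = minhash(shingles("the cat sat on the mat"))
b = minhash(shingles("the cat sat on the mat!"))     # near-duplicate
c = minhash(shingles("completely different text here"))
# est_jaccard(a, b) is high, est_jaccard(a, c) is near zero,
# so b would be flagged as a duplicate of a and dropped
```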
Training Stability: Preventing Catastrophe
Gradient clipping, mixed precision, and surviving loss spikes
The Analogy
Training an LLM is like flying a plane on autopilot for 78 days straight. Turbulence (bad data batches) can cause sudden jolts. Without safety systems, one bad batch could crash the entire training run — wasting millions of dollars. Gradient clipping limits how large any single update can be (like speed limiters). Mixed precision uses lower-precision numbers to save memory while keeping critical calculations in high precision.
Key insight: Loss spikes are a real danger in LLM training. Llama 3’s training report mentions multiple loss spikes that required intervention. Common causes: corrupted data batches, numerical instability in attention, or learning rate too high. The standard fix: roll back to a recent checkpoint and skip the problematic data. At $60M+ per training run, stability engineering is worth millions.
Stability Techniques
# 1. Gradient clipping: cap gradient norm
torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0
)
# Prevents catastrophic updates from bad batches

# 2. Mixed precision training (BF16)
#    Forward/backward: BF16 (16-bit brain float)
#    Optimizer states: FP32 (32-bit float)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    loss = model(batch)
loss.backward()
optimizer.step()
# Note: FP16 (unlike BF16) also needs loss scaling
# (torch.amp.GradScaler) to prevent gradient underflow;
# BF16's wider exponent range makes scaling unnecessary.

# 3. Gradient accumulation
#    Effective batch = micro_batch × accum_steps
#    Llama 3: micro_batch=1, accum=1024
#    Effective batch: 4M tokens per step

# 4. Checkpointing: save every N steps
#    If loss spikes → rollback to last good ckpt
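Gradient-norm clipping is simple enough to sketch in plain Python. A simplified scalar-list version of what `clip_grad_norm_` does across parameter tensors:

```python
def clip_grad_norm(grads, max_norm=1.0):
    """Rescale gradients so their global L2 norm is at most max_norm."""
    total_norm = sum(g * g for g in grads) ** 0.5
    if total_norm > max_norm:
        scale = max_norm / total_norm   # shrink every gradient uniformly
        grads = [g * scale for g in grads]
    return grads, total_norm

# A "bad batch" produces huge gradients with norm 5.0;
# clipping rescales them so the norm is exactly 1.0
clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
```

The direction of the update is preserved; only its magnitude is capped, which is why one corrupted batch can no longer blow up the run.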
The Complete Training Loop
Everything together in ~30 lines of PyTorch
Simplified Training Loop
import torch
import torch.nn.functional as F

model = LLM(vocab=128256, d=4096, n_heads=32,
            d_ff=14336, n_layers=32)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1.5e-4,
    betas=(0.9, 0.95), weight_decay=0.1
)

for step, batch in enumerate(dataloader):
    # Update learning rate
    lr = cosine_schedule(step, warmup=2000,
                         total=total_steps,
                         lr_max=1.5e-4, lr_min=1.5e-5)
    for pg in optimizer.param_groups:
        pg['lr'] = lr

    # Forward: predict next tokens
    input_ids = batch[:, :-1]
    targets = batch[:, 1:]
    logits = model(input_ids)
    loss = F.cross_entropy(
        logits.reshape(-1, 128256),
        targets.reshape(-1)
    )

    # Backward: compute gradients
    loss.backward()
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), 1.0
    )

    # Update: adjust parameters
    optimizer.step()
    optimizer.zero_grad()
That’s the Core
The entire pretraining process is: predict next token → measure error → compute gradients → update weights → repeat. This loop runs billions of times. For Llama 3 8B trained on 15T tokens with batch size 4M: that’s ~3.75 million optimizer steps. Each step processes 4 million tokens and updates 8 billion parameters. The simplicity of the loop belies the engineering complexity of making it work at scale.
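The step-count arithmetic above checks out directly:

```python
tokens_total = 15e12        # 15T training tokens
tokens_per_step = 4e6       # 4M-token effective batch
steps = tokens_total / tokens_per_step
# 3.75 million optimizer steps, each touching all 8B parameters
```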
Real World
Student reads a sentence, covers the next word, guesses, checks answer, adjusts understanding. Repeat for every sentence in every book in the library.
In LLMs
Model sees context, predicts next token, computes cross-entropy loss, backpropagates gradients, AdamW updates 8B parameters. Repeat 3.75M times.
Evaluation: Is the Model Learning?
Tracking loss curves, perplexity, and downstream benchmarks
What to Watch
During training, teams monitor several signals: Training loss should decrease smoothly. Validation loss (on held-out data) should track training loss — if it diverges, the model is overfitting. Perplexity (e^loss) gives an intuitive measure: a perplexity of 10 means the model is as confused as if choosing between 10 equally likely tokens. Downstream benchmarks (MMLU, HumanEval, GSM8K) test real capabilities at checkpoints.
Key insight: The loss curve is remarkably predictable thanks to scaling laws. Labs can train a small model (1B), measure its loss curve, and extrapolate what a 70B model will achieve. This is how training runs costing $100M+ are planned with confidence. If the actual loss deviates from the predicted curve, something is wrong — bad data, bug, or hardware failure.
Key Metrics
# Training metrics:
#   Loss: cross-entropy (lower = better)
#   Perplexity: e^loss (lower = better)
#     PPL=10 → "choosing between 10 options"
#     PPL=5  → "choosing between 5 options"

# Typical loss values:
#   Random init:     ~11.8  (ln(128256) ≈ 11.8)
#   After warmup:    ~8.0
#   Mid-training:    ~3.0
#   Final (good):    ~2.0-2.5
#   Theoretical min: ~1.0-1.5

# Downstream benchmarks (Llama 3 8B):
#   MMLU (knowledge):          66.6%
#   HumanEval (code):          62.2%
#   GSM8K (math):              79.6%
#   ARC-Challenge (reasoning): 78.6%

# These are evaluated periodically during
# training to track capability emergence
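The loss-to-perplexity conversion and the random-init loss both follow directly from the formulas above (Llama 3's vocabulary size is used for illustration):

```python
import math

def perplexity(loss_nats):
    # PPL = e^loss: "how many equally likely options is the model choosing among?"
    return math.exp(loss_nats)

V = 128256                   # Llama 3 vocabulary size
random_loss = math.log(V)    # uniform guessing over the vocab ≈ 11.76 nats
# A loss of ln(10) ≈ 2.303 corresponds to perplexity 10,
# i.e. "as confused as choosing among 10 tokens"
```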