Ch 3 — Attention: The Core Innovation

How every token decides who to listen to — the mechanism that made transformers possible
Why Attention? The Problem It Solves
Words need context from other words to have meaning
The Analogy
Imagine you’re at a dinner party with 10 people talking. When someone says “it,” you automatically know they’re referring to the topic from 3 sentences ago. Your brain “attends” to the relevant earlier words. Self-attention gives the model this same ability: every token can look at every other token and decide which ones are relevant to understanding its own meaning.
Key insight: Before transformers, models processed text sequentially (RNNs) or with fixed windows (CNNs). A word at position 500 could barely “see” a word at position 1. Attention lets every word see every other word directly, regardless of distance. This is why the 2017 paper was titled “Attention Is All You Need” (Vaswani et al.).
The Problem
# "The cat sat on the mat because it was tired"
# What does "it" refer to?
# → "the cat" (not "the mat")

# "The trophy didn't fit in the suitcase
#  because it was too big"
# What does "it" refer to?
# → "the trophy" (it was too big)

# "The trophy didn't fit in the suitcase
#  because it was too small"
# Now "it" = "the suitcase"!

# Attention lets the model figure this out
# by looking at ALL tokens simultaneously
Real World
At a dinner party, you focus on the most relevant speaker for each topic
In LLMs
Each token computes how much to “attend” to every other token
Query, Key, Value: The Three Roles
Every token plays three roles simultaneously
The Analogy
Think of a library search. You have a Query (“I want books about cooking”). Each book has a Key (its title/tags: “Italian Recipes”). You match your Query against all Keys to find relevant books. Then you read the Value (the actual content of the matching books). In attention, every token generates all three: Q (what am I looking for?), K (what do I contain?), V (what information do I provide?).
Key insight: Q, K, and V are created by multiplying the token’s embedding by three different learned weight matrices: W_Q, W_K, W_V. These matrices are the main learnable parameters in attention. The model learns what to ask (Q), how to describe itself (K), and what to share (V) through training.
The Math
# Input: token embedding x (shape: d_model)
# Three learned weight matrices:
#   W_Q: (d_model, d_k)  "what am I looking for?"
#   W_K: (d_model, d_k)  "what do I contain?"
#   W_V: (d_model, d_v)  "what info do I share?"

Q = x @ W_Q  # Query vector
K = x @ W_K  # Key vector
V = x @ W_V  # Value vector

# For GPT-3 (d_model=12288, 96 heads):
#   d_k = d_v = d_model / n_heads
#   d_k = 12288 / 96 = 128 per head

# For Llama 3 8B (d_model=4096, 32 heads):
#   d_k = 4096 / 32 = 128 per head
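As a concrete sketch, the three projections can be written as a few tensor operations. The dimensions here are toy values chosen so the snippet runs anywhere, not the dimensions of any real model:

```python
# Q/K/V projections for a single token (toy dimensions).
import torch

torch.manual_seed(0)
d_model, d_k = 8, 4

x = torch.randn(d_model)          # one token's embedding
W_Q = torch.randn(d_model, d_k)   # learned: "what am I looking for?"
W_K = torch.randn(d_model, d_k)   # learned: "what do I contain?"
W_V = torch.randn(d_model, d_k)   # learned: "what do I share?"

Q = x @ W_Q   # Query vector, shape (d_k,)
K = x @ W_K   # Key vector,   shape (d_k,)
V = x @ W_V   # Value vector, shape (d_k,)
print(Q.shape, K.shape, V.shape)
```

In a real model the same three matrices are applied to every token in the sequence at once, so `x` becomes a `(seq_len, d_model)` matrix and Q, K, V become `(seq_len, d_k)` matrices.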
Attention Scores: Who Matches Whom?
Dot product measures how relevant each token is to every other
The Analogy
At the dinner party, you subconsciously rate how relevant each speaker is to your current thought. “Very relevant” = high score. “Completely off-topic” = low score. The dot product between a Query and each Key gives this relevance score. High dot product = the Key matches the Query well. The scores are then divided by √d_k to prevent them from getting too large (which would make softmax too “spiky”).
Key insight: The scaling factor √d_k is critical. Without it, when d_k is large (e.g., 128), dot products can be huge, pushing softmax into regions where gradients vanish. Dividing by √128 ≈ 11.3 keeps values in a reasonable range. This is why it’s called “scaled dot-product attention.”
Worked Example
# Sentence: "The cat sat"
# 3 tokens, each with Q, K, V vectors

# Step 1: Compute scores = Q × K^T
# For token "sat" (its Query vs all Keys):
#   score("sat","The") = Q_sat · K_The = 2.1
#   score("sat","cat") = Q_sat · K_cat = 8.7
#   score("sat","sat") = Q_sat · K_sat = 3.2

# Step 2: Scale by √d_k
#   d_k = 128, √128 ≈ 11.3
#   scaled: [2.1/11.3, 8.7/11.3, 3.2/11.3]
#         = [0.19, 0.77, 0.28]

# Step 3: Softmax → probabilities
#   weights = softmax([0.19, 0.77, 0.28])
#           = [0.26, 0.46, 0.28]
#   "sat" attends most to "cat" (0.46)

# Step 4: Weighted sum of Values
#   output = 0.26·V_The + 0.46·V_cat + 0.28·V_sat
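The steps above can be checked by running the numbers. The raw scores are the made-up values from the example, not outputs of a trained model:

```python
# Reproducing the worked example numerically.
import math
import torch

scores = torch.tensor([2.1, 8.7, 3.2])   # Q_sat · K for each token
scaled = scores / math.sqrt(128)         # √128 ≈ 11.31
weights = torch.softmax(scaled, dim=-1)
print(weights)  # ≈ [0.26, 0.46, 0.28]
```

Note how much softmax compresses the gap: an 8.7-vs-2.1 score difference becomes a 0.46-vs-0.26 weight difference after scaling, which is exactly why the √d_k divisor keeps the distribution from becoming too "spiky".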
The Complete Attention Formula
Attention(Q, K, V) = softmax(QKᵀ/√d_k) · V
The Full Picture
The entire self-attention mechanism in one formula: Attention(Q, K, V) = softmax(QKᵀ/√d_k) · V. For a sequence of n tokens with embedding dimension d: Q and K are (n × d_k), so QKᵀ is (n × n) — the attention matrix. Each row is one token’s attention weights over all tokens. Multiply by V (n × d_v) to get the output (n × d_v).
Key insight: The attention matrix is (n × n), meaning compute scales quadratically with sequence length. A 128K context window means a 128K × 128K = 16 billion element matrix per head per layer. This is why long-context models are so expensive, and why Flash Attention (Ch 11) was such a breakthrough.
Tensor Shapes
# Full attention computation:
import math
import torch

# Input X: (batch, seq_len, d_model)
# Example: (32, 2048, 4096)

# Project to Q, K, V (one head, d_k=128):
Q = X @ W_Q  # (32, 2048, 128)
K = X @ W_K  # (32, 2048, 128)
V = X @ W_V  # (32, 2048, 128)

# Attention scores:
scores = Q @ K.transpose(-2, -1)
# (32, 2048, 128) @ (32, 128, 2048)
# = (32, 2048, 2048) ← attention matrix!

# Scale and softmax:
scores = scores / math.sqrt(d_k)
weights = torch.softmax(scores, dim=-1)

# Weighted sum:
output = weights @ V
# (32, 2048, 2048) @ (32, 2048, 128)
# = (32, 2048, 128)
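Packaged as a function, the whole formula is a few lines. This is a minimal sketch with small toy shapes so it runs quickly; real models use the large dimensions shown above:

```python
# Minimal scaled dot-product attention: softmax(QKᵀ/√d_k) · V
import math
import torch

def attention(Q, K, V):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, n, n)
    weights = torch.softmax(scores, dim=-1)            # rows sum to 1
    return weights @ V                                 # (B, n, d_v)

B, n, d_k = 2, 16, 8
torch.manual_seed(0)
Q, K, V = (torch.randn(B, n, d_k) for _ in range(3))
out = attention(Q, K, V)
print(out.shape)  # torch.Size([2, 16, 8])
```

The intermediate `(B, n, n)` scores tensor is the quadratic cost discussed above: doubling `n` quadruples its size.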
Multi-Head Attention: Multiple Perspectives
Different heads learn different types of relationships
The Analogy
Imagine watching a movie with multiple critics. One focuses on plot, another on cinematography, another on acting. Each critic (head) watches the same movie but notices different things. Multi-head attention runs multiple attention operations in parallel, each with its own W_Q, W_K, W_V. One head might learn syntax, another semantics, another coreference.
Key insight: GPT-3 has 96 attention heads. Each head operates on d_k = 12288/96 = 128 dimensions. The heads are independent — they learn different attention patterns. After all heads compute their outputs, the results are concatenated and projected back to d_model. Total compute is the same as single-head attention with full d_model, but the model gets 96 different “views.”
How It Works
# Multi-head attention:
# d_model = 4096, n_heads = 32
# d_k = 4096 / 32 = 128 per head

# Each head has its own W_Q, W_K, W_V:
#   head_1  = Attention(Q₁, K₁, V₁)
#   head_2  = Attention(Q₂, K₂, V₂)
#   ...
#   head_32 = Attention(Q₃₂, K₃₂, V₃₂)

# Concatenate all heads:
#   concat = [head_1; head_2; ...; head_32]
#   Shape: (seq_len, 32 × 128) = (seq_len, 4096)

# Final projection:
#   output = concat @ W_O
#   W_O: (4096, 4096)

# Real head counts:
#   BERT-base:   12 heads, d_k=64
#   GPT-2:       12 heads, d_k=64
#   GPT-3:       96 heads, d_k=128
#   Llama 3 8B:  32 heads, d_k=128
#   Llama 3 70B: 64 heads, d_k=128
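The split-and-concatenate bookkeeping is just reshaping. A small sketch with toy sizes (real models use the dimensions listed above) shows that splitting d_model into heads and concatenating back is lossless:

```python
# Multi-head split and concat as tensor reshapes (toy sizes).
import torch

B, S, d_model, n_heads = 1, 6, 32, 4
d_k = d_model // n_heads                 # 8 dims per head

torch.manual_seed(0)
x = torch.randn(B, S, d_model)

# Split d_model into n_heads independent subspaces:
heads = x.view(B, S, n_heads, d_k).transpose(1, 2)   # (B, H, S, d_k)

# ... each head would run attention on its own (S, d_k) slice ...

# Concatenate the heads back to d_model:
concat = heads.transpose(1, 2).contiguous().view(B, S, d_model)
assert torch.equal(concat, x)   # split + concat recovers the input
```

Because each head works on its own d_k-sized slice, the total matmul cost matches a single head over the full d_model, which is the "free" part of the multiple-views trade-off.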
Causal Masking: No Peeking Ahead
How GPT-style models prevent tokens from seeing the future
The Analogy
Imagine writing an exam where you can only see questions you’ve already answered — you can’t peek at future questions. Causal masking (also called “masked self-attention”) enforces this rule: when predicting the next token, the model can only attend to previous tokens, never future ones. This is essential for autoregressive generation — you can’t use the answer to generate the answer.
Key insight: BERT uses bidirectional attention (sees all tokens). GPT uses causal attention (sees only past tokens). This is the fundamental architectural difference. BERT is great for understanding (classification, NER). GPT is great for generation (writing, coding, chat). All modern chat LLMs (GPT-4, Claude, Llama) use causal masking.
The Mask
# Causal mask for "The cat sat":
#        The  cat  sat
# The  [  1    0    0 ]  ← sees only itself
# cat  [  1    1    0 ]  ← sees The, cat
# sat  [  1    1    1 ]  ← sees all three
# 1 = can attend, 0 = masked (set to -∞)

# In code:
import torch

seq_len = 3
mask = torch.tril(torch.ones(seq_len, seq_len))
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]])

# Apply mask before softmax:
scores = scores.masked_fill(mask == 0, float('-inf'))
# -inf → softmax → 0 (zero attention)
weights = torch.softmax(scores, dim=-1)
Bidirectional (BERT)
Every token sees all tokens. Good for understanding, can’t generate.
Causal (GPT)
Tokens only see the past. Enables autoregressive generation.
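The two regimes differ only in whether the mask is applied. Running both on the same (illustrative, random) scores makes the difference visible:

```python
# Bidirectional vs causal attention weights on the same scores.
import torch

torch.manual_seed(0)
S = 4
scores = torch.randn(S, S)             # illustrative raw scores

# BERT-style: every row is a full distribution over all tokens.
bidir = torch.softmax(scores, dim=-1)

# GPT-style: mask out the future, then softmax.
mask = torch.tril(torch.ones(S, S))
causal = torch.softmax(
    scores.masked_fill(mask == 0, float('-inf')), dim=-1
)
print(causal)   # upper triangle is exactly 0
```

In the causal output, row i places all of its probability mass on positions 0..i, so the weights above the diagonal are exactly zero while each row still sums to 1.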
Self-Attention in PyTorch
The complete implementation in ~20 lines
Complete Implementation
import torch
import torch.nn as nn
import math

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, S, D = x.shape
        qkv = self.W_qkv(x)              # (B, S, 3D)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head
        q = q.view(B, S, self.n_heads, self.d_k)
        q = q.transpose(1, 2)            # (B, H, S, d_k)
        k = k.view(B, S, self.n_heads, self.d_k)
        k = k.transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.d_k)
        v = v.transpose(1, 2)
Continued
        # Scaled dot-product attention
        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(self.d_k)

        # Causal mask (built on the input's device so GPU inputs work)
        mask = torch.tril(torch.ones(S, S, device=x.device))
        scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values
        out = weights @ v                # (B, H, S, d_k)

        # Concatenate heads
        out = out.transpose(1, 2).contiguous()
        out = out.view(B, S, -1)         # (B, S, D)
        return self.W_o(out)

# Usage:
attn = CausalSelfAttention(d_model=4096, n_heads=32)
x = torch.randn(1, 128, 4096)
out = attn(x)  # (1, 128, 4096)
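A quick sanity check on causality: if the mask works, perturbing a future token must leave every earlier position's output unchanged. This standalone sketch uses the input itself as Q, K, and V for brevity, so it checks the masking logic rather than the full class:

```python
# Causality check: changing the LAST token must not affect
# the outputs at earlier positions.
import math
import torch

def causal_attention(x):
    S, d = x.shape
    scores = (x @ x.T) / math.sqrt(d)    # Q = K = V = x, for brevity
    mask = torch.tril(torch.ones(S, S))
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(5, 8)
y1 = causal_attention(x)

x2 = x.clone()
x2[4] += 1.0                             # perturb only the last token
y2 = causal_attention(x2)

assert torch.allclose(y1[:4], y2[:4])    # earlier outputs unchanged
```

The same property is what makes KV caching possible during generation: past positions never need to be recomputed when a new token is appended.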
What Attention Heads Actually Learn
Researchers have visualized what different heads focus on
Discovered Patterns
Researchers (Clark et al., 2019; Olsson et al., 2022) have found that attention heads specialize in specific tasks: Positional heads attend to the previous or next token. Syntactic heads attend to the verb’s subject or an adjective’s noun. Induction heads copy patterns (“if A B ... A” → predict B). Coreference heads link pronouns to their referents (“it” → “cat”).
The complete picture: Self-attention is the heart of every transformer. It lets every token see every other token, weighted by relevance. Multi-head attention gives the model multiple perspectives. Causal masking enables generation. The formula is elegant: softmax(QKᵀ/√d_k) · V. Everything else in the transformer — feed-forward layers, normalization, residual connections — supports and refines what attention discovers.
Induction Heads
# Induction heads (Olsson et al., 2022):
# The most important attention pattern

# Pattern: "... A B ... A" → predict B
# Example:
#   "Harry Potter is a wizard. Harry" → "Potter"

# How it works (two heads cooperate):
# Head 1 (previous token head):
#   Current "Harry" attends to first "Harry"
#   Finds the previous occurrence
# Head 2 (induction head):
#   Looks at what followed first "Harry"
#   Finds "Potter" → predicts "Potter"

# This emerges during training around
# the same time as a sharp drop in loss
# Called a "phase change" in training
Real World
A detective connecting clues from different parts of a case file
In LLMs
Attention heads connecting related tokens across thousands of positions