The Analogy
Imagine comparing skyscrapers by height. Instead of measuring from sea level (huge numbers, hard to compare), you measure from the shortest building (small, manageable numbers). The log-sum-exp trick does the same: subtract the maximum value before exponentiating. This keeps numbers in a safe range without changing the mathematical result.
Key insight: Every time you call torch.logsumexp(), nn.CrossEntropyLoss(), or F.log_softmax(), PyTorch applies this trick internally. Without it, training any model with a softmax (which includes essentially every classifier and every transformer) would fail from overflow or underflow. This one trick makes modern deep learning numerically possible.
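As a quick sanity check (a minimal illustrative sketch, not part of the original walkthrough): a hand-rolled softmax on extreme logits produces nan, while PyTorch's built-in cross-entropy, which applies the log-sum-exp trick internally, stays finite.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 1001.0, 999.0]])  # extreme values
target = torch.tensor([1])

# Naive softmax: exp(1000) overflows to inf, so inf/inf gives nan.
naive_probs = torch.exp(logits) / torch.exp(logits).sum()

# Built-in loss subtracts the max internally and stays finite.
loss = F.cross_entropy(logits, target)  # ≈ 0.41
```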
The Math & Code
# Problem: log(Σ e^xᵢ) overflows
# e.g., x = [1000, 1001, 999]
# e^1000 overflows to inf → useless result
# Solution: subtract max first
# log(Σ e^xᵢ) = M + log(Σ e^(xᵢ - M))
# where M = max(x)
import numpy as np

x = np.array([1000, 1001, 999])
# Naive (overflows):
# np.log(np.sum(np.exp(x))) → inf, with a RuntimeWarning
# Safe (log-sum-exp trick):
M = np.max(x) # 1001
result = M + np.log(np.sum(np.exp(x - M)))
# x - M = [-1, 0, -2]
# e^[-1, 0, -2] = [0.37, 1.0, 0.14]
# sum = 1.50, log = 0.41
# result = 1001.41 ✓
# PyTorch does this automatically:
import torch
torch.logsumexp(torch.tensor(x, dtype=torch.float), dim=0)  # tensor(1001.4076)
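The same identity gives a stable log-softmax: log_softmax(x) = x − logsumexp(x). A minimal NumPy sketch (illustrative, with a hypothetical helper name) shows it produces finite values on the same extreme inputs:

```python
import numpy as np

def stable_log_softmax(x):
    # log_softmax(x) = x - logsumexp(x), with the max subtracted for safety
    M = np.max(x)
    lse = M + np.log(np.sum(np.exp(x - M)))
    return x - lse

x = np.array([1000.0, 1001.0, 999.0])
log_probs = stable_log_softmax(x)  # ≈ [-1.41, -0.41, -2.41], all finite
```

Exponentiating log_probs recovers a proper probability distribution (it sums to 1), which is exactly what F.log_softmax computes under the hood.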