How It Works
# Token embeddings (from lookup table):
# "The" → [0.1, 0.3, -0.2, ...]
# "cat" → [0.8, -0.1, 0.5, ...]
# "sat" → [0.2, 0.7, -0.3, ...]
# Position embeddings:
# pos 0 → [0.01, 0.02, -0.01, ...]
# pos 1 → [0.03, -0.01, 0.02, ...]
# pos 2 → [-0.02, 0.04, 0.01, ...]
# Final input = token + position:
# "The" at pos 0 → [0.11, 0.32, -0.21, ...]
# "cat" at pos 1 → [0.83, -0.11, 0.52, ...]
# "sat" at pos 2 → [0.18, 0.74, -0.29, ...]
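The element-wise sums above can be verified directly in plain Python (vectors truncated to the three values shown):

```python
token = {"The": [0.1, 0.3, -0.2], "cat": [0.8, -0.1, 0.5], "sat": [0.2, 0.7, -0.3]}
pos = [[0.01, 0.02, -0.01], [0.03, -0.01, 0.02], [-0.02, 0.04, 0.01]]

# Final input vector for each token = token embedding + its position embedding
final = {w: [round(t + p, 2) for t, p in zip(token[w], pos[i])]
         for i, w in enumerate(["The", "cat", "sat"])}
# final["cat"] → [0.83, -0.11, 0.52]
```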
# In PyTorch (learned positional embeddings):
import torch; from torch import nn
token_embed = nn.Embedding(vocab_size, d_model)
pos_embed = nn.Embedding(max_seq_len, d_model)
positions = torch.arange(ids.size(1))  # [0, 1, ..., seq_len-1]
x = token_embed(ids) + pos_embed(positions)  # broadcasts over the batch dim
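The snippet above can be expanded into a minimal runnable sketch; the vocabulary size, model dimension, and token ids here are illustrative values, not from the original:

```python
import torch
from torch import nn

vocab_size, max_seq_len, d_model = 100, 16, 8

token_embed = nn.Embedding(vocab_size, d_model)
pos_embed = nn.Embedding(max_seq_len, d_model)

# A batch with one 3-token sequence, e.g. "The cat sat" → ids [5, 12, 7] (hypothetical ids)
ids = torch.tensor([[5, 12, 7]])          # shape (batch=1, seq_len=3)
positions = torch.arange(ids.size(1))     # tensor([0, 1, 2])

x = token_embed(ids) + pos_embed(positions)  # shape (1, 3, d_model)
```

Because the position embedding is added in, the same token id produces a different final vector at a different position, which is what lets the model distinguish "The cat sat" from "sat cat The".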