Complete Implementation
import torch
import torch.nn as nn
import math

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # One fused projection produces Q, K, and V in a single matmul
        self.W_qkv = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, S, D = x.shape
        qkv = self.W_qkv(x)             # (B, S, 3D)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, S, D)

        # Reshape for multi-head: split D into n_heads heads of size d_k,
        # then move the head axis ahead of the sequence axis
        q = q.view(B, S, self.n_heads, self.d_k).transpose(1, 2)  # (B, H, S, d_k)
        k = k.view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention with a causal (upper-triangular) mask:
        # position i may only attend to positions <= i
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)    # (B, H, S, S)
        mask = torch.triu(torch.ones(S, S, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        out = attn @ v                                            # (B, H, S, d_k)

        # Merge the heads back into d_model, then apply the output projection
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        return self.W_o(out)
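A quick shape check on the finished module; the toy sizes below (d_model=64, n_heads=4, batch of 2, sequence length 10) are arbitrary values chosen for illustration:

attn = CausalSelfAttention(d_model=64, n_heads=4)
x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)
out = attn(x)
print(out.shape)             # torch.Size([2, 10, 64])

Fusing the Q, K, and V projections into one nn.Linear is equivalent to three separate d_model-to-d_model projections, but it issues a single matmul, which is the usual reason for the 3*d_model output width.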