Build Mini-GPT from Scratch với KV Cache Optimization bằng PyTorch
Bài viết này hướng dẫn bạn xây dựng mini-GPT từ đầu bằng PyTorch, bao gồm toàn bộ kiến trúc Transformer (Multi-Head Attention, LayerNorm, FeedForward...) và đặc biệt là implement KV Cache để tăng tốc quá trình sinh text.
Giới thiệu
Khi bạn chat với ChatGPT hay Claude, mỗi token được sinh ra theo kiểu autoregressive, tức là mô hình phải đọc lại toàn bộ chuỗi đầu vào để dự đoán token tiếp theo. Với chuỗi dài, điều này sẽ cực kỳ tốn kém về mặt tính toán.
KV Cache là kỹ thuật giải quyết bài toán đó: thay vì tính lại Key và Value của những token đã xử lý, ta có thể lưu chúng vào bộ nhớ và tái sử dụng. Nghe đơn giản, nhưng implement đúng cần khá nhiều thứ phải xử lý cẩn thận.
Trong bài này, tôi sẽ cùng các bạn:
- Build một mini-GPT từ đầu
- Implement KV Cache vào Multi-Head Attention
Lưu ý: Code tham khảo từ cuốn "Build a Large Language Model From Scratch" (Sebastian Raschka).
Kiến trúc tổng quan
Trước khi đi vào code, hãy nhìn qua cấu trúc các module chúng ta sẽ xây dựng:
GPTModel
├── tok_emb → Token Embedding
├── pos_emb → Positional Embedding
├── drop_emb → Embedding Dropout
├── trf_blocks[] → Danh sách TransformerBlock
│ ├── norm1 → LayerNorm
│ ├── att → MultiHeadAttention (có KV Cache)
│ ├── norm2 → LayerNorm
│ └── ff → FeedForward (GELU)
├── final_norm → LayerNorm
└── out_head → Linear → vocab logits
Config mô hình dùng cho bài này (GPT-2 small):
GPT_CONFIG_124M = {
"vocab_size": 50257,
"context_length": 1024,
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 1024 # KV Cache window size
}
Các building block cơ bản
1. LayerNorm
PyTorch có sẵn nn.LayerNorm, nhưng implement thủ công giúp bạn hiểu rõ hơn:
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift
scale và shift là các learnable parameter, cho phép mô hình "undo" normalization nếu cần.
2. GELU Activation
Công thức của hàm GELU như sau:
GPT dùng GELU thay vì ReLU vì nó smooth hơn và có gradient ở vùng âm:
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
Đây là xấp xỉ toán học của hàm GELU gốc, được dùng trong GPT-2.
3. FeedForward Network
Mỗi Transformer Block có một FFN với kiến trúc: Linear → GELU → Linear, với hidden dim = 4 * emb_dim:
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
def forward(self, x):
return self.layers(x)
Multi-Head Attention với KV Cache
Đây là phần quan trọng nhất của bài. Hãy bắt đầu từ base, rồi dần thêm KV Cache vào.
Cấu trúc cơ bản
Khởi tạo class MultiheadAttention với các parameters cơ bản như sau:
- d_in: dimension của input embedding
- d_out: dimension đầu ra của attention
- context_length: độ dài ngữ cảnh
- dropout: tỷ lệ dropout
- num_heads: số lượng attention head
- qkv_bias: có dùng bias trong Linear Layer của QKV không
- max_seq_len: độ dài tối đa cho KV cache
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads,
qkv_bias=False, max_seq_len=None, window_size=None):
super().__init__()
assert d_out % num_heads == 0
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
# KV Cache setup
self.max_seq_len = max_seq_len or context_length
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
register_buffer với persistent=False nghĩa là cache sẽ không được save vào checkpoint, đúng ý nghĩa của nó là bộ nhớ tạm thời khi inference.
Forward pass — phần attention scores
Trước tiên, tính Q, K, V và reshape cho multi-head:
def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape
keys_new = self.W_key(x) # (b, num_tokens, d_out)
values_new = self.W_value(x)
queries = self.W_query(x)
# Split d_out thành (num_heads, head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
# Shape sau transpose: (b, num_heads, num_tokens, head_dim)
Implement KV Cache
Đây là phần thú vị nhất trong bài. KV Cache hoạt động theo cơ chế sliding window: chúng ta giữ một buffer có kích thước window_size, và liên tục append Key/Value mới vào.
if use_cache:
# Khởi tạo cache nếu chưa có hoặc batch size thay đổi
if self.cache_k is None or self.cache_k.size(0) != b:
self.cache_k = torch.zeros(
b, self.num_heads, self.window_size, self.head_dim,
device=x.device
)
self.cache_v = torch.zeros_like(self.cache_k)
self.ptr_cur = 0
# Nếu sắp tràn buffer → shift left (bỏ token cũ nhất)
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow
# Ghi key/value mới vào vị trí con trỏ hiện tại
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
self.ptr_cur += num_tokens
# Đọc toàn bộ context đã cache
keys = self.cache_k[:, :, :self.ptr_cur, :]
values = self.cache_v[:, :, :self.ptr_cur, :]
else:
keys, values = keys_new, values_new
self.ptr_cur = 0
Tại sao dùng clone() khi shift?
--> Vì nếu không clone, phép gán [:-overflow] và [overflow:] có thể chồng chéo vùng nhớ, gây ra kết quả sai.
Sao không dịch window sang phải mà là dịch cache sang trái? (đây là câu hỏi tôi thắc mắc khi lần đầu học kiến trúc này 🤦♂️)
--> Bạn suy nghĩ trả lời thử xem, tôi sẽ để đáp án ở cuối bài.
Causal Mask với Cache
Đây là điểm tinh tế: khi có cache, số cột của attn_scores (K) lớn hơn số hàng (num_tokens) — vì query chỉ là chunk mới, nhưng key là toàn bộ context cũ + mới.
K = attn_scores.size(-1)
if num_tokens == K:
# Không có cache → dùng mask tam giác tiêu chuẩn
causal_mask = torch.triu(
torch.ones(num_tokens, K, device=x.device, dtype=torch.bool),
diagonal=1
)
else:
# Có cache → phải offset diagonal
offset = K - num_tokens
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
causal_mask = row_idx + offset < col_idx # True = vị trí cần mask
Ví dụ: cache đang giữ 5 token, query là 1 token mới → offset = 5, token query này được "nhìn thấy" tất cả 5 token trước đó và chính nó, nhưng không nhìn thấy tương lai (không có tương lai trong trường hợp này).
Phần còn lại của forward:
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = (attn_weights @ values).transpose(1, 2)
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)
return context_vec
Và method để reset cache khi bắt đầu inference mới:
def reset_cache(self):
self.cache_k, self.cache_v = None, None
TransformerBlock và GPTModel
TransformerBlock
Mỗi block theo kiến trúc Pre-LN (LayerNorm trước Attention, khác với paper gốc dùng Post-LN):
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"],
window_size=cfg.get("kv_window_size", cfg["context_length"])
)
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x, use_cache=False):
# Attention block với residual connection
shortcut = x
x = self.norm1(x)
x = self.att(x, use_cache=use_cache)
x = self.drop_shortcut(x)
x = x + shortcut
# FeedForward block với residual connection
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut
return x
GPTModel
Điểm khác biệt quan trọng so với GPT thông thường: Positional Embedding phải track vị trí hiện tại khi có KV Cache, vì mỗi chunk đến sẽ cần vị trí chính xác thay vì luôn bắt đầu từ 0.
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
# Dùng ModuleList thay vì Sequential để truyền use_cache vào từng block
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
)
self.ptr_current_pos = 0 # track vị trí cho positional embedding
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
self.kv_window_size = cfg.get("kv_window_size", cfg["context_length"])
def forward(self, in_idx, use_cache=False):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
if use_cache:
context_length = self.pos_emb.num_embeddings
assert self.ptr_current_pos + seq_len <= context_length, \
f"Position overflow: {self.ptr_current_pos + seq_len} > {context_length}"
pos_ids = torch.arange(
self.ptr_current_pos,
self.ptr_current_pos + seq_len,
device=in_idx.device, dtype=torch.long
)
self.ptr_current_pos += seq_len
else:
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
x = self.drop_emb(tok_embeds + pos_embeds)
for blk in self.trf_blocks:
x = blk(x, use_cache=use_cache)
x = self.final_norm(x)
return self.out_head(x)
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
Tại sao dùng nn.ModuleList thay vì nn.Sequential? --> Vì nn.Sequential chỉ forward qua từng module với 1 argument. Chúng ta cần truyền thêm use_cache, nên phải dùng ModuleList và loop thủ công.
Text Generation
Không có KV Cache (baseline)
def generate_text_simple(model, idx, max_new_tokens, context_size):
for _ in range(max_new_tokens):
idx_cond = idx[:, -context_size:]
with torch.no_grad():
logits = model(idx_cond)
logits = logits[:, -1, :]
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
idx = torch.cat((idx, idx_next), dim=1)
return idx
Mỗi step: toàn bộ chuỗi được tính lại từ đầu. Với 200 token cần generate, chuỗi dài 1000 token → ~200 lần forward qua 1000 token.
Có KV Cache (optimized)
def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
model.eval()
ctx_len = context_size or model.pos_emb.num_embeddings
kv_window_size = model.kv_window_size
with torch.no_grad():
if use_cache:
model.reset_kv_cache()
# Prefill phase: xử lý prompt ban đầu theo chunk
input_tokens = idx[:, -ctx_len:]
input_tokens_length = input_tokens.size(1)
for i in range(0, input_tokens_length, kv_window_size):
chunk = input_tokens[:, i:i+kv_window_size]
logits = model(chunk, use_cache=True)
# Giới hạn số token có thể generate (do position embedding có giới hạn)
max_new_tokens = min(max_new_tokens, ctx_len - input_tokens_length)
# Decode phase: mỗi step chỉ cần forward 1 token
for _ in range(max_new_tokens):
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx
Hai giai đoạn của KV Cache:
| Giai đoạn | Tên | Mô tả |
|---|---|---|
| Xử lý prompt | Prefill | Forward toàn bộ prompt (theo chunk nếu dài), tích lũy KV cache |
| Sinh token | Decode | Mỗi step chỉ forward 1 token, dùng lại cache |
Trong decode phase, mỗi bước chỉ tốn O(1) thay vì O(n) như trước. Đây là lý do KV Cache cực kỳ quan trọng trong production.
Running
def main():
GPT_CONFIG_124M = {
"vocab_size": 50257,
"context_length": 1024,
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 1024
}
torch.manual_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG_124M).to(device)
model.eval()
tokenizer = tiktoken.get_encoding("gpt2")
start_context = "Hello, I am"
encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
token_ids = generate_text_simple_cached(
model=model,
idx=encoded_tensor,
max_new_tokens=200,
)
decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
print("Output:", decoded_text)
Lưu ý: Vì model được khởi tạo ngẫu nhiên (không load pretrained weights), output sẽ là gibberish. Để có text có nghĩa, bạn cần load weights từ GPT-2 của OpenAI.
Kết luận
Chúng ta đã xây dựng hoàn chỉnh một mini-GPT với KV Cache optimization. Tóm lại những điểm quan trọng:
Về kiến trúc:
- GPT dùng Pre-LN (LayerNorm trước Attention/FFN), giúp training ổn định hơn
nn.ModuleListcần thiết khi muốn truyền thêm argument vào từng blockregister_buffer(persistent=False)là cách đúng để lưu cache — không bị save vào checkpoint
Về KV Cache:
- Chia làm 2 phase: Prefill (xử lý prompt) và Decode (sinh từng token)
- Sliding window giúp xử lý context dài hơn
window_sizebằng cách discard token cũ - Causal mask cần offset khi
num_tokens < K(tức là đang có cache) - Positional Embedding phải track vị trí tuyệt đối, không reset về 0 mỗi step
Bước tiếp theo nếu bạn muốn đi sâu hơn:
- Load pretrained GPT-2 weights từ OpenAI để có text có nghĩa
- Thêm temperature sampling, top-k, top-p thay vì greedy argmax
- Benchmark thực tế giữa cached và non-cached trên GPU
- Thử Flash Attention để tối ưu memory thêm nữa
PS: Trả lời cho câu hỏi ở phần Cache --> Việc dịch cache sang trái để duy trì một vùng nhớ liên tục, có kích thước cố định nhằm tính toán attention hiệu quả. Nếu ta dịch window, ta sẽ phải giữ lại toàn bộ chuỗi, điều này phá vỡ giả định bộ nhớ bị giới hạn của sliding window attention. Do đó để đạt hiệu năng cao hơn, có thể dùng buffer dạng vòng (circular buffer) để tránh việc di chuyển dữ liệu, nhưng đổi lại việc indexing sẽ phức tạp hơn.
Tham khảo
- Sebastian Raschka — Build a Large Language Model From Scratch (Manning, 2024)
- Vaswani et al. — Attention Is All You Need (2017)
- Radford et al. — Language Models are Unsupervised Multitask Learners (OpenAI GPT-2, 2019)
Nếu bài viết có ích, hãy upvote và để lại comment nếu bạn có câu hỏi hoặc gặp bug khi chạy code nhé!
All Rights Reserved