Multi-Head Attention — Nhìn cùng lúc nhiều góc khác nhau
Thay vì một góc nhìn, Multi-Head Attention chia không gian embedding thành nhiều subspace song song, giúp model bắt pattern từ ngữ nghĩa đến cú pháp chỉ trong một lớp duy nhất.
Transformer không chỉ "nhìn" một cách duy nhất. Giống như khi phân tích một câu văn, bạn vừa để ý ngữ pháp, vừa theo dõi ý nghĩa từng từ, vừa nhớ ngữ cảnh xa — Multi-Head Attention cho phép mô hình làm điều đó bằng cách chia nhỏ không gian tính toán thành nhiều "chuyên gia" chạy song song. Đây là lý do tại sao GPT có thể viết code và BERT có thể hiểu ngữ cảnh tinh tế đến vậy.
Vấn đề
Nếu chỉ dùng Self-Attention đơn thuần với một bộ ma trận Q, K, V duy nhất, model sẽ bị giới hạn nghiêm trọng: nó chỉ học được một kiểu quan hệ tại một thời điểm. Nhưng ngôn ngữ tự nhiên có nhiều lớp thông tin đan xen — có quan hệ cú pháp (chủ ngữ kết nối với động từ), quan hệ ngữ nghĩa (đại từ nhân xưng thay thế cho danh từ nào), và quan hệ khoảng cách (từ đầu đoạn liên kết với từ cuối đoạn).
Một "đầu" Attention duy nhất cố gắng bắt tất cả các loại quan hệ này cùng lúc giống như cố gắng bắt tất cả bóng bay bằng một tay — bạn sẽ bỏ sót những pattern quan trọng hoặc trộn lẫn chúng thành mớ hỗn độn. Model cần khả năng "đa nhiệm" ngay trong một lớp: vừa nhìn xa để hiểu ngữ cảnh, vừa nhìn gần để bắt chi tiết ngữ pháp.
Ý tưởng cốt lõi
Thay vì một bộ Q, K, V, hãy dùng h "đầu" (heads) chạy song song. Mỗi đầu là một bộ projection Q, K, V riêng biệt, học trên subspace nhỏ hơn (thường là d_k = d_model / h). Chúng tính Attention độc lập, mỗi người săn một loại pattern khác nhau, rồi kết quả được nối lại (concatenate) và trộn qua ma trận W^O.
Hãy tưởng tượng bạn xem một bức tranh phức tạp với kính 3D: một mắt thấy màu sắc, một mắt thấy chiều sâu, một mắt thấy chuyển động. Hoặc như một nhóm biên tập viên cùng đọc một đoạn văn — người chuyên check ngữ pháp, người chuyên check logic lập luận, người chuyên check phong cách viết. Không ai thay thế ai, nhưng cả nhóm cùng cho ra cái nhìn toàn diện mà một người không thể có.
Nhiều người nghĩ Multi-Head làm Attention phức tạp hơn — thực ra đơn giản hơn theo cách đẹp đẽ. Đó chỉ là việc chạy h lần attention đơn giản (scaled dot-product) trên các phần nhỏ của vector embedding, rồi ghép lại. Điều quan trọng là các head này tự phân công vai trò trong quá trình training: một số head sẽ chuyên theo dõi đại từ (như "it" trỏ về "dog"), một số head chuyên phát hiện quan hệ từ vựng (từ đồng nghĩa), một số khác chuyên bắt pattern cú pháp xa (liên kết giữa hai mệnh đề cách nhau cả đoạn văn). That's it — chỉ vậy thôi, nhưng hiệu quả gấp bội.
Tại sao nó hoạt động
Bản chất toán học là phân rã không gian vector. Nếu embedding có kích thước 512, thay vì nhét tất cả thông tin vào một ma trận Attention 512×512, ta chia thành 8 head × 64 chiều. Mỗi head học một "perspective" khác nhau trong không gian con riêng.
# Pseudo-code PyTorch cho Multi-Head Attention
def multi_head_attention(X, h=8):
# X: (batch, seq_len, d_model)
d_k = d_model // h
# Linear projections cho tất cả head cùng lúc
Q = X @ W_q # (batch, seq, d_model) -> (batch, seq, d_model)
K = X @ W_k
V = X @ W_v
# Reshape thành nhiều head: (batch, h, seq, d_k)
Q = Q.view(batch, seq, h, d_k).transpose(1, 2)
# Attention score: (batch, h, seq, seq)
scores = (Q @ K.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores) @ V # (batch, h, seq, d_k)
# Concat các head lại: (batch, seq, d_model)
concat = attn.transpose(1, 2).reshape(batch, seq, d_model)
return concat @ W_o # Final linear projectionTại sao phải chia nhỏ? Vì softmax trên toàn bộ 512 chiều dễ bị chi phối bởi một vài chiều lớn (outliers), làm mất thông tin ở các chiều nhỏ. Khi chia thành 8 head × 64 chiều, mỗi head có thể "zoom in" vào tín hiệu riêng. Hệ số chia √d_k giữ cho phương sai của dot-product ổn định, tránh gradient bị nổ.
Điểm then chốt là ma trận W^O (output projection) ở cuối: nó không chỉ nối các head lại mà còn trộn chúng, cho phép thông tin từ head "cú pháp" kết hợp với head "ngữ nghĩa" để tạo representation cuối cùng. Đây là cách model "tổng hợp ý kiến chuyên gia" — mỗi head vote một kiểu, W^O tính trung bình có trọng số.
Ý nghĩa thực tế
Impact thực tế: Multi-Head Attention là tiêu chuẩn trong mọi kiến trúc Transformer hiện đại. BERT-base dùng 12 heads, GPT-3 dùng 96 heads xếp chồng qua 96 lớp. Không có nó, các model này sẽ bị mù mờ trước những quan hệ ngôn ngữ tinh vi.
Benchmarks: Các thí nghiệm ablation cho thấy giảm số head từ 8 xuống 1 khiến perplexity tăng vọt và khả năng coreference resolution (tìm đại từ thay thế) giảm 30-40%. Tuy nhiên, thú vị là sau khi train xong, một số head có thể bị "cắt tỉa" (prune) mà không mất nhiều chất lượng — chứng tỏ có sự dư thừa nhẹ, nhưng quá trình training cần đa dạng này để tìm ra cách biểu diễn tối ưu.
Ai đang dùng: Tất cả — từ BERT đến GPT-4, LLaMA, Claude, Gemini. Đặc biệt trong Encoder-Decoder như T5, Multi-Head cho phép encoder nhìn toàn bộ input trong khi decoder tập trung vào từng vị trí generate.
Hạn chế:
- Memory: Phải lưu KV cache cho từng head trong inference, làm tăng bộ nhớ tuyến tính theo số head (mặc dù song song về mặt tính toán).
- Redundancy: Một số head học pattern tương tự nhau, đặc biệt ở các lớp sâu, nhưng khó xác định trước head nào thừa.
Đào sâu hơn
- Paper gốc: Attention Is All You Need (Vaswani et al., 2017) — phần 3.2.2 "Multi-Head Attention".
- Cùng cụm:
- Self-Attention — Nền tảng đơn giản trước khi chia multi-head
- Transformer Architecture — Vị trí của MHA trong khối Transformer đầy đủ
- Positional Encoding — Làm việc cùng MHA để giữ thông tin vị trí
- Autoregressive LM — Cách MHA được sử dụng trong GPT để generate text
- Đọc tiếp:
- Flash Attention — Cách tối ưu bộ nhớ I/O cho Multi-Head Attention khi context dài
- Grouped Query Attention — Giảm bộ nhớ KV cache bằng cách chia sẻ K/V giữa các head
- Mamba và SSM — Kiến trúc thay thế Attention, loại bỏ hoàn toàn khái niệm head
Self-Attention — Mỗi token tự hỏi 'ai quan trọng với tôi?'
Hiểu bản chất Self-Attention: mỗi token trong câu tự động 'nhìn' tất cả các token khác để tìm ngữ cảnh, thay vì chỉ xử lý tuần tự như RNN.
Positional Encoding — Transformer không biết thứ tự, phải dạy
Transformer xử lý song song nên mất khái niệm trước/sau. Positional Encoding 'dán số thứ tự' vào embedding để model phân biệt 'chó cắn người' với 'người cắn chó'.