Multi-Query Attention — Share K/V heads, KV cache nhỏ hẳn
Tại sao LLaMA 2 và PaLM lại chia sẻ Key/Value giữa các attention heads? MQA và GQA giảm KV cache 8-64 lần, biến mô hình 70B từ 'không thể serve' thành 'chạy mượt' trên consumer GPU.
Bạn đang chạy inference với model 70B và context 32K tokens. Mỗi token generated ra đòi hỏi phải load toàn bộ KV cache từ GPU memory — với Multi-Head Attention (MHA) truyền thống, đó là 64 bản sao của history, khiến bạn cần 80GB VRAM chỉ để lưu "trí nhớ" chưa tính đến weights. Multi-Query Attention (MQA) và biến thể Grouped-Query Attention (GQA) là câu trả lời: thay vì mỗi query head có bộ K/V riêng, hãy chia sẻ chúng. Bằng cách này, bạn cắt giảm memory bandwidth bottleneck — không phải bằng cách tính toán ít đi, mà bằng cách đọc ít dữ liệu hơn từ VRAM.
Vấn đề
Autoregressive generation là một bài toán bandwidth, không phải compute.
Trong MHA chuẩn, mỗi attention head có bộ ba Q/K/V riêng. Với 64 heads, mỗi token sinh ra cần load 64 Key vectors và 64 Value vectors từ High-Bandwidth Memory (HBM) — đó là 128 tensor riêng biệt cho mỗi layer. Khi context dài ra (32K, 128K tokens), lượng data cần đọc cho mỗi token mới tăng tuyến tính, bão hòa hoàn toàn bandwidth của GPU (A100 chỉ có ~2 TB/s HBM vs 19-312 TFLOPS compute).
Vấn đề không phải là "tính toán chậm" — GPU đang ngồi chơi 95% thời gian chờ data từ VRAM. Với model 70B và context 128K, KV cache chiếm >50GB VRAM, khiến việc serve trở nên bất khả thi trên phần cứng thông thường.
Ý tưởng cốt lõi
Query heads cung cấp tất cả sự đa dạng bạn cần.
Đây là insight cốt lõi: trong MHA, các K/V heads thường học các representation rất tương đồng — chúng đều đang cố gắng mã hóa "token này mang thông tin gì". Sự đa dạng thực sự nằm ở cách bạn truy vấn (Query), không phải ở cách bạn lưu trữ thông tin.
Multi-Query Attention (MQA) đẩy logic này đến cực đoan: chỉ dùng một Key head và một Value head cho tất cả 64 Query heads. Kết quả: KV cache giảm từ O(h × d × L) xuống O(d × L), tức là 64 lần nhỏ hơn.
Nhưng thực tế sản xuất (như LLaMA 2) dùng Grouped-Query Attention (GQA) — "Goldilocks solution" ở giữa: chia 64 query heads thành 8 nhóm, mỗi nhóm chia sẻ 1 K/V head. Bạn vẫn giảm 8x memory, nhưng mỗi query head chỉ "cạnh tranh" với 7 anh em thay vì 63, giữ lại đủ capacity để tránh suy giảm chất lượng.
Analogy: Thư viện và kính lúp
Hãy tưởng tượng bạn có 64 độc giả (query heads) đang tra cứu một bộ sách tham khảo (KV cache). MHA như cho mỗi người một bản photo riêng của toàn bộ thư viện — lãng phí khủng khiếp. MQA là đặt một cuốn sách duy nhất giữa phòng, ai cũng nhìn chung vào đó — tiết kiệm nhưng hỗn loạn khi 64 người chen lấn. GQA là đặt 8 bản sao sách (mỗi nhóm 8 người dùng chung), vừa tiết kiệm 8x lưu trữ, vừa giảm chen lấn.
Tại sao nó hoạt động
1. Sự tương đồng của K/V projections
Trong thực nghiệm, các K/V projections trong MHA học các vector rất tương quan (high cosine similarity) vì chúng đều cố gắng mã hóa semantic content của token. Việc chia sẻ chúng không mất mát nhiều thông tin vì thông tin thực sự nằm ở cách Query "hỏi" thông tin đó — thông qua dot product Q·K.
2. Memory Wall vs Compute
MQA/GQA không giảm FLOPs (vẫn cần tính attention scores cho tất cả heads), nhưng giảm memory bandwidth — đóng vai trò then chót trong inference. Khi KV cache nhỏ đi 8-64 lần, nó có thể fit vào SRAM hoặc ít nhất là giảm pressure trên HBM, cho phép batch size lớn hơn hoặc context dài hơn.
3. Uptraining recipe
Bạn có thể chuyển một model MHA sang GQA bằng cách mean pooling trọng số K/V của các heads trong cùng nhóm, sau đó fine-tune trên ~5% dữ liệu pretraining ban đầu. Cách này phục hồi >90% chất lượng so với MHA gốc trong khi giữ lại toàn bộ lợi ích tốc độ.
4. Cơ chế attention không bị phá vỡ
Softmax attention vẫn hoạt động bình thường với shared K/V. Mỗi query head tính dot product với cùng một Key vector, nhưng vì Query vectors khác nhau, attention scores vẫn đa dạng. Thậm chí, việc này có thể coi là "regularization" — ép các query heads học cách truy xuất thông tin hiệu quả hơn thay vì dựa vào K/V "độc quyền".
Ý nghĩa thực tế
Impact thực tế:
- Throughput: Fireworks.ai báo cáo 11x improvement throughput và giảm 30% latency khi chuyển sang MQA/GQA trên production workloads.
- Democratization: Model 70B với GQA có thể serve context 32K+ trên consumer GPU (RTX 4090 24GB) thay vì cần 8x A100.
- Adoption: PaLM (540B) dùng MQA "cứng"; LLaMA 2 (70B) và Mistral dùng GQA (8 groups/64 heads) như giải pháp thực tiễn tốt nhất.
So sánh nhanh:
| Kiến trúc | Số K/V heads | KV Cache Size | Chất lượng | Use case |
|---|---|---|---|---|
| MHA (Llama 1) | 64 | 100% (baseline) | Cao nhất | Research, training |
| GQA (Llama 2) | 8 | 12.5% | ~95% MHA | Production serving |
| MQA (PaLM) | 1 | 1.6% | ~90% MHA | Cực kỳ dài context, high-latency tolerant |
Các giới hạn:
- Chất lượng: MQA thuần túy (1 head) có thể gây perplexity tăng và training instability. GQA là sweet spot thực tế.
- Không giúp training: Chỉ tối ưu inference memory. Training vẫn cần full attention với materialized K/V cho backprop.
- Alignment quan trọng: Draft model và target model cần tương thích (cùng họ kiến trúc) nếu dùng với Speculative Decoding.
Đào sâu hơn
Paper gốc:
- "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019) — Giới thiệu MQA.
- "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023) — Uptraining recipe và phân tích trade-off.
Bài liên quan TroiSinh:
Cùng cụm (attention-efficiency):
- Flash Attention — Tối ưu bandwidth bằng tiling SRAM, bổ trợ hoàn hảo cho MQA.
- Grouped-Query Attention — Phân tích sâu về "Goldilocks" 8-group strategy.
- Sliding Window Attention — Kết hợp với MQA để xử lý context cực dài (>100k).
- RoPE — Position encoding tương thích với shared K/V heads.
- ALiBi — Alternative cho context extrapolation khi dùng MQA.
Đọc tiếp:
- KV Cache — Hiểu sâu về cơ chế caching và quantization cho K/V.
- Context Extrapolation — Làm thế nào MQA/GQA mở đường cho 1M+ token contexts (Level 2).
External resources:
- Fireworks.ai benchmarks — Production metrics thực tế.
- Tinkerd visual guide — Minh họa code chi tiết MHA→MQA transition.
Flash Attention — Cùng phép toán, nhanh 3x nhờ tiling trên SRAM
Flash Attention giải quyết bottleneck bộ nhớ O(n²) của Attention bằng cách tính toán trên SRAM thay vì HBM, giúp LLM xử lý context dài nhanh gấp 3 lần mà không làm mất độ chính xác.
GQA — Goldilocks giữa MHA và MQA: Bí quyết giảm 4x KV Cache không mất chất lượng
Grouped Query Attention là kỹ thuật chia nhóm attention heads để chia sẻ Key/Value, giảm memory bandwidth trong inference xuống 75% mà vẫn giữ 90%+ chất lượng MHA. Đây là lý do Llama 2 và Mistral chạy được context dài trên hardware consumer.