GQA — Goldilocks giữa MHA và MQA: Bí quyết giảm 4x KV Cache không mất chất lượng
Grouped Query Attention là kỹ thuật chia nhóm attention heads để chia sẻ Key/Value, giảm memory bandwidth trong inference xuống 75% mà vẫn giữ 90%+ chất lượng MHA. Đây là lý do Llama 2 và Mistral chạy được context dài trên hardware consumer.
Inference LLM hiện đại không chết vì thiếu FLOPs — chết vì memory bandwidth. Khi bạn generate token thứ 10.000, GPU phải load toàn bộ KV cache của 9.999 token trước đó từ VRAM, và con số này nổ tung với Multi-Head Attention (MHA) truyền thống. Multi-Query Attention (MQA) thì tiết kiệm nhưng thường làm giảm chất lượng đáng kể. GQA (Grouped Query Attention) tìm được điểm ngọt — giảm 4x lượng memory cần đọc mỗi step, nhưng vẫn giữ gần như đầy đủ khả năng biểu diễn. Đây là lý do Llama 2 và Mistral có thể phục vụ context window lớn trên GPU consumer mà không cần chia model.
Vấn đề
Bottleneck là memory, không phải compute. Trong autoregressive decoding, mỗi token mới cần tính attention với tất cả token trước đó. Với MHA cổ điển, mỗi attention head (thường 64-128 heads) giữ bản sao riêng của Key và Value. Cache size = 2 × số heads × hidden_dim × sequence_length. Với Llama 2 70B (64 heads, 8192 hidden, 4K context), KV cache chiếm hơn 30GB VRAM — chỉ để lưu "trí nhớ" của conversation.
MQA quá cực đoan. Multi-Query Attention gom tất cả query heads về chung 1 K/V head, giảm memory xuống 64 lần. Nhưng điều này tạo bottleneck biểu diễn nghiêm trọng — 64 "đầu đọc" khác nhau buộc phải chia sẻ 1 "bộ nhớ" chung, làm giảm khả năng capture thông tin đa dạng. Kết quả là perplexity tăng, đặc biệt trên task đòi hỏi reasoning phức tạp.
Cần một giải pháp trung gian: chia nhóm để chia sẻ, không phải chia toàn bộ.
Ý tưởng cốt lõi
GQA như thư viện với nhiều phòng đọc thay vì 1 phòng chung hoặc 64 phòng riêng.
Trong MHA, bạn có 64 bản sao hoàn toàn giống nhau của bách khoa toàn thư — lãng phí. Trong MQA, bạn có 1 cuốn sách duy nhất cho 64 người đọc — chen chúc. GQA đặt 8 cuốn sách (g=8 nhóm), mỗi nhóm 8 người đọc chia sẻ 1 cuốn.
Điểm mấu chốt: Query heads vẫn giữ sự đa dạng, vì mỗi head học cách "hỏi" khác nhau (project khác nhau từ hidden state). Nhưng Key và Value là thông tin thô từ input — chúng không cần 64 biến thể. Bằng cách gom 8 query heads vào 1 nhóm để share K/V head, bạn giảm memory bandwidth đi 8 lần, nhưng vẫn giữ 8 "góc nhìn" khác nhau để truy vấn thông tin.
Cách implement trong Llama 2 rất tinh tế: dùng n_rep = num_heads / num_key_value_heads để repeat (broadcast) K/V tensor theo chiều heads. Mỗi query head thấy cùng K/V nhưng dot-product với Q khác nhau, cho kết quả attention pattern riêng. Đơn giản vậy thôi — không cần thay đổi architecture sâu, chỉ là reshape và share.
Tại sao nó hoạt động
1. Memory Bandwidth là mục tiêu chính GPU A100/H100 có ~2 TB/s memory bandwidth nhưng 100+ TFLOPS compute. Khi generate token, bạn không tính toán nhiều — bạn chỉ load KV cache. GQA giảm số lượng K/V tensors từ 64 xuống 8 (với Llama 2 70B), nghĩa là giảm 87.5% bandwidth pressure. Đây là lý do GQA cho phép throughput cao hơn đáng kể trong inference dài.
2. Redundancy trong Attention Heads Phân tích weights của MHA cho thấy các K/V heads học biểu diễn tương quan cao (high cosine similarity). Việc giữ 64 bản sao là over-parameterization không cần thiết cho inference. GQA exploit redundancy này bằng cách force sharing, ép các query heads trong cùng nhóm phải "cộng tác" trên cùng K/V representation — giống như ensemble learning nhưng efficient hơn.
3. Uptraining từ MHA Điểm hay nhất: bạn có thể chuyển model MHA sang GQA bằng mean pooling. Lấy trung bình weights của 8 K heads thành 1 K head mới, tương tự cho V. Sau đó fine-tune chỉ 5% lượng data gốc (vài chục tỷ tokens) để model thích nghi. Điều này cho phép các lab như Meta upgrade Llama 2 từ MHA sang GQA mà không cần pre-train lại từ đầu.
4. Trade-off có kiểm soát
Với g = h/8 (8 nhóm), GQA đạt ~90-95% chất lượng MHA trên hầu hết benchmarks, trong khi chỉ tốn 12.5% memory của MHA. Đây là "Goldilocks zone": không quá nóng (MQA, quality loss) không quá lạnh (MHA, memory explosion).
Ý nghĩa thực tế
| Architecture | K/V Heads | Memory (relative) | Quality (vs MHA) | Use case |
|---|---|---|---|---|
| MHA | 64 (full) | 100% | 100% | Training, research flexibility |
| GQA | 8 (groups) | 12.5% | ~92% | Production inference, long context |
| MQA | 1 (shared) | 1.6% | ~85% | Aggressive compression, edge devices |
Thực tế triển khai:
- Llama 2 70B: Dùng 64 query heads, 8 K/V heads → giảm 4x memory so với MHA tương đương, cho phép 4K context trên single A100 80GB.
- Mistral 7B: Dùng GQA với sliding window attention, đạt 32K context trên consumer GPU 24GB.
- Throughput: VLLM và TensorRT-LLM báo cáo 2-3x throughput improvement khi switch từ MHA sang GQA trên long-context workloads.
Các giới hạn: GQA vẫn không thể capture một số long-range dependency phức tạp như MHA, nhưng sự khác biệt thường < 5% perplexity. Trên code generation và math reasoning, GQA thường không thua MHA nếu uptraining đúng cách.
Đào sâu hơn
Paper gốc:
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (2023) — Joshua Ainslie et al., Google Research. Trình bày thuật toán uptraining và phân tích trade-off giữa số nhóm (g) và chất lượng.
Cùng cụm — Attention Efficiency:
- Flash Attention — Cùng phép toán, nhanh 3x nhờ tiling trên SRAM, kết hợp với GQA để tối ưu cả compute và memory.
- Multi-Query Attention — Extreme version của GQA (1 group), hiểu rõ trade-off khi giảm xuống 1 head.
- Sliding Window Attention — Kết hợp với GQA trong Mistral để xử lý 32K+ context.
- RoPE — Position encoding thường dùng với GQA; chú ý cách RoPE áp dụng trên shared K heads.
- ALiBi — Alternative cho long context, có thể combine với GQA.
Đọc tiếp:
- KV Cache — Hiểu rõ vấn đề memory mà GQA giải quyết; cách quantization kết hợp với GQA để nén thêm.
- Sequence Modeling (Level 0) — Prerequisite về self-attention mechanism.
- Long Context (Level 2) — Các kỹ thuật như YaRN kết hợp với GQA để đạt 100K+ context.
Multi-Query Attention — Share K/V heads, KV cache nhỏ hẳn
Tại sao LLaMA 2 và PaLM lại chia sẻ Key/Value giữa các attention heads? MQA và GQA giảm KV cache 8-64 lần, biến mô hình 70B từ 'không thể serve' thành 'chạy mượt' trên consumer GPU.
Sliding Window Attention — Chỉ nhìn N token gần nhất, đủ rồi
Giảm độ phức tạp attention từ O(n²) xuống O(n×w) bằng cách giới hạn mỗi token chỉ nhìn w token lân cận. Cách Mistral 7B chạy 32K context trên GPU 24GB.