TROISINH
BreakthroughsAttention Efficiency

GQA — Goldilocks giữa MHA và MQA: Bí quyết giảm 4x KV Cache không mất chất lượng

Grouped Query Attention là kỹ thuật chia nhóm attention heads để chia sẻ Key/Value, giảm memory bandwidth trong inference xuống 75% mà vẫn giữ 90%+ chất lượng MHA. Đây là lý do Llama 2 và Mistral chạy được context dài trên hardware consumer.

Inference LLM hiện đại không chết vì thiếu FLOPs — chết vì memory bandwidth. Khi bạn generate token thứ 10.000, GPU phải load toàn bộ KV cache của 9.999 token trước đó từ VRAM, và con số này nổ tung với Multi-Head Attention (MHA) truyền thống. Multi-Query Attention (MQA) thì tiết kiệm nhưng thường làm giảm chất lượng đáng kể. GQA (Grouped Query Attention) tìm được điểm ngọt — giảm 4x lượng memory cần đọc mỗi step, nhưng vẫn giữ gần như đầy đủ khả năng biểu diễn. Đây là lý do Llama 2 và Mistral có thể phục vụ context window lớn trên GPU consumer mà không cần chia model.

Vấn đề

Bottleneck là memory, không phải compute. Trong autoregressive decoding, mỗi token mới cần tính attention với tất cả token trước đó. Với MHA cổ điển, mỗi attention head (thường 64-128 heads) giữ bản sao riêng của Key và Value. Cache size = 2 × số heads × hidden_dim × sequence_length. Với Llama 2 70B (64 heads, 8192 hidden, 4K context), KV cache chiếm hơn 30GB VRAM — chỉ để lưu "trí nhớ" của conversation.

MQA quá cực đoan. Multi-Query Attention gom tất cả query heads về chung 1 K/V head, giảm memory xuống 64 lần. Nhưng điều này tạo bottleneck biểu diễn nghiêm trọng — 64 "đầu đọc" khác nhau buộc phải chia sẻ 1 "bộ nhớ" chung, làm giảm khả năng capture thông tin đa dạng. Kết quả là perplexity tăng, đặc biệt trên task đòi hỏi reasoning phức tạp.

Cần một giải pháp trung gian: chia nhóm để chia sẻ, không phải chia toàn bộ.

Ý tưởng cốt lõi

GQA như thư viện với nhiều phòng đọc thay vì 1 phòng chung hoặc 64 phòng riêng.

Trong MHA, bạn có 64 bản sao hoàn toàn giống nhau của bách khoa toàn thư — lãng phí. Trong MQA, bạn có 1 cuốn sách duy nhất cho 64 người đọc — chen chúc. GQA đặt 8 cuốn sách (g=8 nhóm), mỗi nhóm 8 người đọc chia sẻ 1 cuốn.

Điểm mấu chốt: Query heads vẫn giữ sự đa dạng, vì mỗi head học cách "hỏi" khác nhau (project khác nhau từ hidden state). Nhưng Key và Value là thông tin thô từ input — chúng không cần 64 biến thể. Bằng cách gom 8 query heads vào 1 nhóm để share K/V head, bạn giảm memory bandwidth đi 8 lần, nhưng vẫn giữ 8 "góc nhìn" khác nhau để truy vấn thông tin.

Cách implement trong Llama 2 rất tinh tế: dùng n_rep = num_heads / num_key_value_heads để repeat (broadcast) K/V tensor theo chiều heads. Mỗi query head thấy cùng K/V nhưng dot-product với Q khác nhau, cho kết quả attention pattern riêng. Đơn giản vậy thôi — không cần thay đổi architecture sâu, chỉ là reshape và share.

Tại sao nó hoạt động

1. Memory Bandwidth là mục tiêu chính GPU A100/H100 có ~2 TB/s memory bandwidth nhưng 100+ TFLOPS compute. Khi generate token, bạn không tính toán nhiều — bạn chỉ load KV cache. GQA giảm số lượng K/V tensors từ 64 xuống 8 (với Llama 2 70B), nghĩa là giảm 87.5% bandwidth pressure. Đây là lý do GQA cho phép throughput cao hơn đáng kể trong inference dài.

2. Redundancy trong Attention Heads Phân tích weights của MHA cho thấy các K/V heads học biểu diễn tương quan cao (high cosine similarity). Việc giữ 64 bản sao là over-parameterization không cần thiết cho inference. GQA exploit redundancy này bằng cách force sharing, ép các query heads trong cùng nhóm phải "cộng tác" trên cùng K/V representation — giống như ensemble learning nhưng efficient hơn.

3. Uptraining từ MHA Điểm hay nhất: bạn có thể chuyển model MHA sang GQA bằng mean pooling. Lấy trung bình weights của 8 K heads thành 1 K head mới, tương tự cho V. Sau đó fine-tune chỉ 5% lượng data gốc (vài chục tỷ tokens) để model thích nghi. Điều này cho phép các lab như Meta upgrade Llama 2 từ MHA sang GQA mà không cần pre-train lại từ đầu.

4. Trade-off có kiểm soát Với g = h/8 (8 nhóm), GQA đạt ~90-95% chất lượng MHA trên hầu hết benchmarks, trong khi chỉ tốn 12.5% memory của MHA. Đây là "Goldilocks zone": không quá nóng (MQA, quality loss) không quá lạnh (MHA, memory explosion).

Ý nghĩa thực tế

ArchitectureK/V HeadsMemory (relative)Quality (vs MHA)Use case
MHA64 (full)100%100%Training, research flexibility
GQA8 (groups)12.5%~92%Production inference, long context
MQA1 (shared)1.6%~85%Aggressive compression, edge devices

Thực tế triển khai:

  • Llama 2 70B: Dùng 64 query heads, 8 K/V heads → giảm 4x memory so với MHA tương đương, cho phép 4K context trên single A100 80GB.
  • Mistral 7B: Dùng GQA với sliding window attention, đạt 32K context trên consumer GPU 24GB.
  • Throughput: VLLM và TensorRT-LLM báo cáo 2-3x throughput improvement khi switch từ MHA sang GQA trên long-context workloads.

Các giới hạn: GQA vẫn không thể capture một số long-range dependency phức tạp như MHA, nhưng sự khác biệt thường < 5% perplexity. Trên code generation và math reasoning, GQA thường không thua MHA nếu uptraining đúng cách.

Đào sâu hơn

Paper gốc:

Cùng cụm — Attention Efficiency:

  • Flash Attention — Cùng phép toán, nhanh 3x nhờ tiling trên SRAM, kết hợp với GQA để tối ưu cả compute và memory.
  • Multi-Query Attention — Extreme version của GQA (1 group), hiểu rõ trade-off khi giảm xuống 1 head.
  • Sliding Window Attention — Kết hợp với GQA trong Mistral để xử lý 32K+ context.
  • RoPE — Position encoding thường dùng với GQA; chú ý cách RoPE áp dụng trên shared K heads.
  • ALiBi — Alternative cho long context, có thể combine với GQA.

Đọc tiếp:

  • KV Cache — Hiểu rõ vấn đề memory mà GQA giải quyết; cách quantization kết hợp với GQA để nén thêm.
  • Sequence Modeling (Level 0) — Prerequisite về self-attention mechanism.
  • Long Context (Level 2) — Các kỹ thuật như YaRN kết hợp với GQA để đạt 100K+ context.

On this page