Flash Attention — Cùng phép toán, nhanh 3x nhờ tiling trên SRAM
Flash Attention giải quyết bottleneck bộ nhớ O(n²) của Attention bằng cách tính toán trên SRAM thay vì HBM, giúp LLM xử lý context dài nhanh gấp 3 lần mà không làm mất độ chính xác.
Khi context window của LLM mở rộng từ 4K lên 128K tokens, bạn không thể chỉ "thêm GPU" để giải quyết vấn đề. Bộ nhớ GPU bị đánh chặn bởi một con số tử thần: O(n²). Flash Attention là lời giải được tích hợp mặc định trong PyTorch 2.0 và vLLM, giúp bạn chạy model lớn với context dài gấp 3 lần mà không tốn thêm VRAM, đồng thời là bước đệm bắt buộc để tiến tới long-context hàng triệu token.
Vấn đề
Standard Attention không chậm vì thiếu FLOPs — nó chậm vì quá nhiều chuyến đi tới kho hàng (HBM).
Khi tính attention, bạn tạo ra ma trận QKᵀ (scores) với kích thước N×N. Với sequence dài 4K tokens, đó là 16 triệu phần tử — 32MB dữ liệu. Standard workflow như sau: tính QKᵀ → ghi xuống HBM → đọc lên để softmax → ghi P (attention weights) xuống HBM → đọc lên để nhân với V. Tổng cộng 4 chuyến đi lên xuống, chỉ để xử lý vài vector.
Vấn đề là HBM (High Bandwidth Memory) trên A100 chỉ có ~2 TB/s, trong khi SRAM (on-chip cache) đạt ~20 TB/s — nhanh gấp 10 lần. Attention trở thành bài toán memory-bound: GPU ngồi chờ 95% thời gian để dữ liệu di chuyển, không phải để tính toán.
Ý tưởng cốt lõi
Đây không phải bài toán toán học, đây là bài toán logistics.
Hãy tưởng tượng GPU như một văn phòng: bạn có một cái bàn nhỏ (SRAM) ngay cạnh máy tính, và một kho hàng khổng lồ (HBM) cách đó 100 mét. Standard Attention như thể bạn làm việc trên bàn, nhưng cứ sau mỗi phép tính lại ôm hết giấy tờ chạy ra kho để lưu, rồi chạy lại vào lấy tiếp — dù rằng bạn chỉ cần xử lý chúng ngay tại chỗ.
Flash Attention có một insight "bướng bỉnh" về data locality: "Nếu không cần thiết, đừng bao giờ đặt hộp trở lại kho."
Thay vì tính toàn bộ ma trận QKᵀ cùng lúc, Flash Attention chia nhỏ K và V thành các khối (tiles) vừa với SRAM — ví dụ mỗi khối 64×64. Nó nạp một khối K/V vào bàn làm việc, rồi "tuôn" toàn bộ các khối Q qua đó. Mỗi lần tính partial attention scores, nó áp dụng online softmax — một thuật toán thống kê chạy (running statistics) cho phép cập nhật giá trị max và tổng exponential mà không cần nhìn thấy toàn bộ hàng. Kết quả partial được tích lũy ngay trên chip, và chỉ có output cuối cùng mới được ghi xuống HBM.
Kernel fusion là phần còn lại: thay vì 3 kernel riêng biệt (matmul → softmax → matmul), Flash Attention viết một kernel CUDA duy nhất thực hiện cả chuỗi phép tính trong SRAM. Không có tensor trung gian nào "chạm" HBM cả.
Đó là tất cả. Không có approximate heuristic, không có sparse mask — chỉ đơn giản là: đừng đọc/ghi dữ liệu vô ích.
Tại sao nó hoạt động
Bản chất toán học nằm ở tính kết hợp của softmax. Thay vì tính softmax([x1, x2, ..., xn]) cùng lúc, bạn có thể tính tích lũy từng khối:
m_new = max(m_old, max(x_block))
l_new = l_old * exp(m_old - m_new) + sum(exp(x_block - m_new))Với công thức này, bạn có thể rescale accumulator cũ bằng giá trị max mới, rồi cộng dồn. Điều này cho phép tiling: mỗi khối Q chỉ cần thấy khối K/V hiện tại để cập nhật output, không cần toàn bộ sequence.
Trong backward pass, thay vì lưu ma trận attention P (N² memory), Flash Attention tái tính (recompute) nó on-the-fly từ O (output) và các thống kê softmax đã lưu, đánh đổi ~20% compute để tiết kiệm hàng GB bộ nhớ — một lựa chọn dễ dàng vì compute rẻ còn memory bandwidth đắt.
# Pseudo-code ý tưởng tiling
for i in range(T_r): # Chia Q thành T_r blocks
Q_i = load(Q[i]) # Nạp vào SRAM (~20KB)
m, l = -inf, 0 # Khởi tạo running stats
for j in range(T_c): # Stream qua K,V blocks
K_j = load(K[j])
V_j = load(V[j])
S_ij = Q_i @ K_j.T # Partial scores
# Online softmax update
m_new = max(m, rowmax(S_ij))
P_ij = exp(S_ij - m_new)
l = l * exp(m - m_new) + rowsum(P_ij)
# Cập nhật output accumulator
O_i = diag(exp(m - m_new)) @ O_i + P_ij @ V_j
m = m_new
store(O_i) # Chỉ ghi kết quả cuối xuống HBMÝ nghĩa thực tế
Flash Attention không làm thay đổi độ phức tạp lý thuyết — vẫn là O(n²) FLOPs — nhưng nó chuyển bottleneck từ memory-bound sang compute-bound thực sự. Kết quả thực tế:
| Metric | Standard Attention | Flash Attention |
|---|---|---|
| Memory Activation | O(n²) | O(n) |
| Context 16K trên A100 | Out of Memory | Chạy được |
| Throughput (GPT-2 1K) | Baseline | 3× nhanh |
| Độ chính xác | Exact | Exact (không approximation) |
Ai đang dùng: PyTorch 2.0 (scaled_dot_product_attention tự động chọn Flash Attention khi khả dụng), HuggingFace Transformers, vLLM (inference engine mặc định), và mọi stack training của LLaMA/GPT-4.
Giới hạn:
- Yêu cầu GPU hiện đại (Ampere/Ada/Hopper) với đủ SRAM; lợi ích giảm trên GPU cũ (Turing/Pascal).
- Head dimension bị giới hạn bởi kích thước SRAM (thường d ≤ 128 cho FP16).
- Không phải silver bullet cho attention — nếu bạn cần sparse attention hay linear attention để đạt O(n) thực sự, Flash Attention vẫn là O(n²) (chỉ là IO-efficient).
Đào sâu hơn
-
Paper gốc: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (Tri Dao et al., 2022) — Giới thiệu tiling và online softmax; "FlashAttention-3" (2024) tối ưu cho H100 với FP8 và asynchrony.
-
Cùng cụm (attention-efficiency):
- Multi-Query Attention — Share K/V heads để giảm KV cache
- Grouped-Query Attention — Goldilocks giữa MHA và MQA
- Sliding Window Attention — Chỉ nhìn N token gần nhất
- RoPE — Xoay vector thay vì cộng position
- ALiBi — Phạt khoảng cách trực tiếp
-
Đọc tiếp:
- KV Cache — Cách lưu trữ K/V để decode nhanh hơn
- Long Context (Level 2) — Ring Attention và các kỹ thuật context triệu token
-
External: Tri Dao's blog giải thích FlashAttention-3 và cách tận dụng Tensor Cores.
Knowledge Distillation — Model lớn dạy model nhỏ, nén intelligence
Cách chuyển giao 'bí mật' từ LLM 100B xuống model 7B bằng soft targets: giảm 90% kích thước mà giữ 95% khả năng suy luận.
Multi-Query Attention — Share K/V heads, KV cache nhỏ hẳn
Tại sao LLaMA 2 và PaLM lại chia sẻ Key/Value giữa các attention heads? MQA và GQA giảm KV cache 8-64 lần, biến mô hình 70B từ 'không thể serve' thành 'chạy mượt' trên consumer GPU.