TROISINH
FrontierLong Context

Ring Attention — K/V truyền vòng giữa GPU, context triệu token

Distributed sequence parallelism biến GPU thành vòng tròn xử lý, mở khóa context triệu token bằng cách xoay K/V thay vì all-to-all communication.

Bạn đã có FlashAttention — kỹ thuật giúp attention nhanh gấp 3 lần trên một GPU. Nhưng khi context dài đến mức 1 triệu token, ngay cả FlashAttention cũng bó tay vì single-GPU memory wall. Ring Attention là lời giải: thay vì nhét toàn bộ sequence vào một GPU, nó biến nhiều GPU thành một "vòng truyền" K/V, cho phép training và inference với context dài vô hạn (về lý thuyết) chỉ bằng cách thêm GPU.

Vấn đề

FlashAttention giải quyết vấn đề bộ nhớ O(n²) bằng cách không materialize attention matrix, giữ mọi thứ trong SRAM. Nhưng nó vẫn là single-device algorithm. Với model 8B parameters và context 128K tokens, KV cache cần ~32GB FP16 — đã gần đầy một A100 40GB. Muốn 1 triệu token? Bạn cần 10 GPU chỉ để chứa cache, chưa tính weights và activations.

Cách tiếp cận naive là data parallelism: chia batch ra nhiều GPU. Nhưng attention là toàn cục — mỗi token cần nhìn tất cả token khác. Nếu shard sequence theo chiều dài (sequence parallelism đơn giản), bạn cần all-to-all communication: mỗi GPU phải gửi K/V của mình cho tất cả GPU khác. Với 8 GPU và 1M tokens, đây là bão tố giao tiếp (communication storm) làm chết throughput.

Vậy làm sao để distribute mà không bị nghẽn cổ chai giao tiếp?

Ý tưởng cốt lõi

Ring Attention xoay K/V trong vòng tròn thay vì broadcast toàn bộ.

Thay vì mỗi GPU giữ một đoạn sequence và liên tục all-gather dữ liệu từ peer, Ring Attention sắp xếp GPU thành một vòng tròn topology. Query được shard cố định trên mỗi GPU, còn Key và Value thì "xoay vòng" — mỗi GPU nhận block K/V từ láng giềng trái, tính toán partial attention, rồi truyền sang láng giềng phải.

Đây là "aha moment": Attention là phép lấy trung bình có trọng số, và trung bình có tính kết hợp (associative). Bạn có thể tính trung bình từng phần rồi merge lại, miễn là bạn biết cách xử lý denominator của softmax.

Ring Attention dùng online softmax (cùng engine bên trong FlashAttention) để maintain running statistics: giá trị max hiện tại mm và tổng exponential ll. Khi nhận thêm một block K/V mới từ vòng truyền, nó update statistics và accumulator output mà không cần lưu toàn bộ attention matrix.

Ví dụ cụ thể: Giả sử bạn có 4 GPU và sequence 1M tokens.

  • GPU 0 giữ Q của tokens 0-250K. Nó bắt đầu với block K/V 0-250K của chính mình (local), tính partial attention.
  • Đồng thời, nó nhận block K/V 250K-500K từ GPU 3 (láng giềng trái), tính tiếp.
  • Trong lúc đó, nó gửi block K/V 0-250K sang GPU 1 (láng giềng phải).
  • Sau 4 bước truyền, mỗi GPU đã "thấy" toàn bộ 1M tokens, nhưng chỉ giữ một phần nhỏ trong RAM tại một thời điểm.

Overlap là chìa khóa: Nếu bạn tính toán đủ nhanh (FlashAttention trên block local), thời gian truyền K/V qua NVLink/InfiniBand sẽ bị hide hoàn toàn behind computation. Bạn trả giá ~0% overhead communication nếu balance đúng.

Tại sao nó hoạt động

Toán học đằng sau online softmax: Softmax của vector xxeximexjm\frac{e^{x_i - m}}{\sum e^{x_j - m}} với m=max(x)m = \max(x). Khi bạn chia nhỏ sequence thành nhiều block, bạn không thể tính max toàn cục ngay. Online softmax lưu lại mprevm_\text{prev}lprevl_\text{prev} từ block trước, khi gặp block mới với mnewm_\text{new}, nó rescale lại:

  • mupdated=max(mprev,mnew)m_\text{updated} = \max(m_\text{prev}, m_\text{new})
  • lupdated=emprevmupdatedlprev+emnewmupdatedlnewl_\text{updated} = e^{m_\text{prev} - m_\text{updated}} \cdot l_\text{prev} + e^{m_\text{new} - m_\text{updated}} \cdot l_\text{new}

Output cũng được rescale tương tự. Điều này cho phép streaming reduction: bạn merge partial attention outputs từng block một mà không cần random access toàn bộ K/V.

Vòng tròn vs All-to-All: All-to-all communication có complexity O(P2)O(P^2) với PP là số GPU — mỗi GPU phải nói chuyện với tất cả. Ring topology chỉ có O(P)O(P) point-to-point communication. Với high-bandwidth interconnect (NVLink 900GB/s), truyền tuần tự nhanh hơn broadcast storm rất nhiều.

RoPE và Position Embedding: Vấn đề tinh tế: Khi K/V xoay vòng giữa GPU, position embeddings (RoPE) phải được apply đúng. Nếu GPU 1 nhận K từ GPU 0 (vốn là tokens 0-250K), nó phải rotate đúng góc mθm\theta cho từng position mm. Ring Attention xử lý điều này bằng cách lưu metadata position cùng với K/V blocks, hoặc recompute RoPE on-the-fly.

Blockwise Backward: Trong backpropagation, bạn không thể lưu toàn bộ activation vì sẽ tái lại vấn đề O(n²) memory. Ring Attention recompute attention statistics trong backward pass, giống FlashAttention, nhưng làm việc này distributed — mỗi GPU chỉ recompute trên block K/V đang "đến thăm" nó.

Ý nghĩa thực tế

Mở khóa 10M+ tokens: Với 64 GPU, Ring Attention cho phép training với context length tỷ lệ thuận với số GPU. Paper gốc demonstrate 1M+ tokens trên cluster nhỏ, và theoretically có thể scale lên vô hạn (practically bị giới hạn bởi network bandwidth và load balancing).

Trade-off thực tế:

  • Tốc độ: ~23% slower than single-GPU FlashAttention do overhead communication và causal inefficiency (trong autoregressive models, ~50% FLOPs bị waste vì masked positions, nhưng Ring Attention vẫn phải truyền cả phần bị mask).
  • Phần cứng: Yêu cầu interconnect tốc độ cao (NVLink, InfiniBand). Trên Ethernet thông thường, communication không thể hide behind computation, làm mất lợi ích.
  • Complexity: Khó debug hơn FlashAttention đơn lẻ. Position synchronization, load balancing với sequence lengths khác nhau (variable length) là thách thức.

So sánh với các kỹ thuật long-context khác:

  • YaRN / LongRoPE: Extrapolate context bằng cách "kéo dãn" position embeddings, nhưng vẫn chạy trên single GPU, bị giới hạn bởi VRAM.
  • Native Long-Context: Train model từ đầu với context dài, nhưng vẫn cần distributed training để fit vào memory — Ring Attention là engine cho phép điều đó.
  • Linear Attention / Mamba: Thay đổi architecture để tránh O(n²), nhưng trade-off chất lượng. Ring Attention giữ nguyên exact softmax attention, chỉ distribute nó.

Đào sâu hơn

Paper gốc:

  • Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context" (arXiv:2310.01889, 2023) — Phát minh ra ring topology và online softmax distributed.
  • "Striped Attention: Faster Ring Attention for Causal Transformers" (arXiv:2311.09431, 2023) — Tối ưu causal masking để giảm waste FLOPs.

Cùng cụm (Long Context):

  • YaRN & LongRoPE — Kéo dài context window sau training bằng position interpolation thông minh.
  • Native Long-Context Training — Train model với context triệu token từ đầu, cần Ring Attention để thực thi.

Đọc tiếp:

  • Flash Attention — Nền tảng O(n) memory cho single GPU, prerequisite để hiểu Ring Attention.
  • KV Cache — Vấn đề memory mà Ring Attention giải quyết.
  • Inference Frontier — Nơi các kỹ thuật như Ring Attention được triển khai production.

On this page