Linear Attention — Đổi kernel function, quadratic → linear
Thay đổi hàm kernel trong attention để giảm độ phức tạp từ O(N²) xuống O(N), mở khóa khả năng xử lý context triệu token mà không cần GPU khổng lồ.
Bạn đã từng thử chạy Llama 3 với context 1 triệu token và nhận ra GPU 80GB cũng không đủ? Vấn đề không phải ở model quá lớn, mà ở cách attention truyền thống tính toán — nó tốn bộ nhớ theo bình phương độ dài chuỗi. Linear Attention giải quyết điều này bằng một thủ thuật toán học tinh tế: thay vì so sánh từng cặp token riêng lẻ (quadratic), nó tạo một "bản tóm tắt" nén của toàn bộ quá khứ, cho phép mỗi token mới chỉ cần nhìn vào tóm tắt đó thay vì toàn bộ lịch sử (linear). Đây là bí mật đằng sau những model có thể đọc cả cuốn sách trong một lượt xử lý.
Vấn đề
Standard Softmax Attention bị ràng buộc bởi độ phức tạp O(N²). Khi bạn xử lý chuỗi dài N token, attention mechanism phải tính ma trận similarity giữa mọi cặp token, tạo ra một ma trận N×N. Với N = 100K, đây là 10 tỷ phần tử; với N = 1M, con số là 1 nghìn tỷ — tương đương ~4TB bộ nhớ chỉ để lưu ma trận trung gian. Điều này khiến việc training và inference với context dài trở nên bất khả thi trên phần cứng thông thường.
Trong inference autoregressive, vấn đề còn tồi tệ hơn. Mỗi lần sinh token mới, model phải load toàn bộ KV cache (Key và Value của tất cả token trước đó) từ HBM để tính attention. Chi phí này tuyến tính O(N) với độ dài sequence — với 1M tokens, bạn phải load hàng chục GB dữ liệu chỉ để sinh một token tiếp theo. GPU ngồi chờ bộ nhớ (memory-bound) thay vì tính toán.
Các giải pháp như Flash Attention giúp tối ưu IO bằng cách không materialize toàn bộ ma trận N×N, nhưng vẫn không thay đổi bản chất quadratic complexity — nó chỉ khiến O(N²) đỡ tốn kém hơn, chứ không biến nó thành O(N).
Ý tưởng cốt lõi
Linear Attention thay đổi hàm kernel để biến attention thành một RNN với hidden state cố định.
Thay vì dùng exponential similarity exp(QK^T) (softmax), Linear Attention sử dụng một kernel tách rời (separable kernel) φ(Q)φ(K)^T, trong đó φ (phi) là một feature map đơn giản như elu(x)+1 hoặc ReLU. Điểm then chốt nằm ở tính chất kết hợp (associativity) của phép nhân ma trận:
Thay vì tính (φ(Q)φ(K)^T)V (cần materialize ma trận N×N), ta tính φ(Q)(φ(K)^TV). Phép toán φ(K)^TV tạo ra một ma trận S kích thước d×d (với d là chiều hidden), độc lập với N. Ma trận S này đóng vai trò như một hidden state nén — nó chứa tóm tắt toàn bộ thông tin từ quá khứ.
Đây chính là "aha moment":
Softmax Attention giống như một RNN có state vô hạn — nó lưu trữ toàn bộ KV cache của mọi token trước đó. Linear Attention là một RNN có state hữu hạn — nó chỉ lưu một ma trận tóm tắt S cố định kích thước, và mỗi token mới chỉ cần update S bằng phép cộng đơn giản:
S_t = S_{t-1} + φ(k_t)^T v_t.
Khi sinh token mới, thay vì phải dot product với N vectors, model chỉ cần nhân Query hiện tại với ma trận S: output = φ(q_t) S_t. Độ phức tạp giảm từ O(N²d) xuống O(Nd²), và quan trọng hơn, inference mỗi bước chỉ còn O(d²) — hằng số, không phụ thuộc vào độ dài context. Bạn có thể xử lý 1M token hay 10M token với cùng một lượng tính toán cho bước cuối cùng.
Causal masking (chỉ nhìn quá khứ) được tự động thực hiện bởi phép cộng dồn Σ_{i=1}^t — không cần mask matrix đặc biệt. Mỗi token mới chỉ thêm thông tin vào state, không bao giờ cần sửa lại quá khứ.
Tại sao nó hoạt động
Kernel Trick và Geometric Interpretation:
Feature map φ biến không gian vector sang một không gian mới nơi inner product tương quan với similarity cần tính. Với φ(x) = elu(x)+1, ta có một kernel xấp xỉ Gaussian mà không cần tính exp — phép toán expensive trong attention truyền thống. Quan trọng hơn, vì kernel tách rời thành tích của hai hàm độc lập (của Q và K riêng biệt), ta có thể tách phép toán và pre-compute phần chứa K và V.
Online Normalization:
Softmax đòi hỏi normalization toàn cục (chia cho tổng exp). Linear Attention dùng online softmax — duy trì running statistics (max và sum) để normalize từng bước mà không cần nhìn trước toàn bộ sequence. Điều này cho phép recurrent form hoạt động mà không lưu toàn bộ lịch sử.
Trade-off: Low-Rank Dilemma:
Ma trận S có rank bị giới hạn bởi kích thước d×d. Trong khi standard attention có thể truy cập ngẫu nhiên (random access) vào bất kỳ token nào trong quá khứ, Linear Attention bị giới hạn bởi "capacity" của state matrix — giống như cố gắng nén một thư viện 1 triệu cuốn sách vào một bản tóm tắt 64×64 trang. Đây là lý do các biến thể hiện đại như Gated Linear Attention (GLA) thêm gating mechanisms để quên/selective update, giải quyết vấn đề "flat attention" (attention distribution trở nên quá flat, không tập trung vào chi tiết quan trọng).
Ý nghĩa thực tế
Hiệu năng thực tế:
- 4000× nhanh hơn trên các chuỗi rất dài (N > 10,000) so với standard attention trong autoregressive decoding
- 6× speedup tại context 1M tokens so với Flash Attention
- Giảm memory từ O(N²) xuống O(Nd): với d=1024 và N=1M, memory giảm từ ~4TB xuống ~4GB
Giới hạn:
- Diffuse attention: Kernel approximation tạo attention weights "phẳng" hơn softmax — khó tập trung vào một token cụ thể ở xa (needle-in-haystack problem)
- Numerical stability: Cần xử lý cẩn thận để tránh vanishing/exploding states trong chuỗi cực dài
- Không phải drop-in replacement: Cần training từ đầu với architecture này, không thể chuyển đổi model đã train sang linear attention một cách dễ dàng
Ai đang dùng:
- RWKV và RetNet: Early adopters của linear attention form
- Mamba-2: Kết hợp SSM với Linear Attention qua State Space Duality (SSD)
- Based và Gated Linear Attention (GLA): Các model mới đạt chất lượng comparable với Transformers trên long-context tasks
Đào sâu hơn
Paper gốc:
- "Fast Autoregressive Transformers with Linear Attention" (Katharopoulos et al., 2020) — arXiv:2006.16236 — Công trình nền tảng giới thiệu kernel feature map trick
- "Gated Linear Attention Transformers with Hardware-Efficient Training" (Yang et al., 2023) — arXiv:2312.06635 — Thêm gating để fix "flat attention" problem
- "Breaking the Low-Rank Dilemma of Linear Attention" (2024) — arXiv:2411.07635 — Phân tích lý do và cách khắc phục capacity limitation
Bài liên quan TroiSinh:
Cùng cụm (New Architectures):
- Mamba & SSMs — Selective scan thay attention, linear scaling — architecture song song với Linear Attention nhưng dùng State Space Models
- Mamba-2 — SSM = structured matrix multiply, kết nối trực tiếp với Linear Attention qua duality framework
- Hybrid Attention-SSM — Jamba và Zamba, kết hợp SSM hiệu quả với attention chính xác cho long context
- Diffusion LM — Noise-and-denoise cho text, generate song song — một hướng khác để thoát khỏi autoregressive bottleneck
Đọc tiếp:
- Flash Attention — Cùng phép toán, nhanh 3x nhờ tiling trên SRAM — cách tối ưu O(N²) thay vì giảm complexity xuống O(N)
- Ring Attention — K/V truyền vòng giữa GPU, context triệu token — cách khác để xử lý sequence cực dài bằng cách phân tán computation
- KV Cache — Hiểu sâu hơn về bottleneck memory mà Linear Attention giải quyết
Hybrid Attention-SSM (Jamba, Zamba) — SSM + attention khi cần precision
Giải mã kiến trúc lai giữa Mamba SSM và Transformer Attention: tận dụng tốc độ tuyến tính của selective scan để nén context, xen kẽ attention layers để giữ độ chính xác khi cần nhớ chi tiết xa xôi.
Diffusion LM — Noise-and-denoise cho text, generate song song
Hiểu bản chất Diffusion Language Models: tại sao denoise liên tục trong không gian embedding lại cho phép generate text song song thay vì từng token, mở khóa khả năng edit và kiểm soát cao hơn autoregressive.