Speculative Decoding — Model nhỏ draft, model lớn verify song song
Dùng model nhỏ dự đoán trước 5-10 token, rồi để model lớn verify cả chuỗi trong một forward pass — giảm latency 2-3x mà không làm giảm chất lượng output.
Tại sao GPU hiện đại tính toán nhanh đến vậy mà LLM vẫn chậm? Bởi vì inference là bài toán memory-bound, không phải compute-bound — GPU ngồi chờ dữ liệu từ VRAM nhiều hơn là tính toán. Speculative decoding là mẹo biến "thời gian chờ" thành "tính toán song song" bằng cách dùng một model nhỏ đoán trước, rồi để model lớn kiểm tra cả loạt cùng lúc.
Vấn đề
Autoregressive generation là quá trình tuần tự từng token một. Để sinh token thứ n, model phải load toàn bộ weights (có thể 70B-140B parameters) qua GPU memory bus chỉ để tính một vector logits duy nhất. Trên A100, bạn có ~2 TB/s memory bandwidth nhưng ~300+ TFLOPS compute — nghĩa là GPU có thể thực hiện hàng trăm phép tính trong thời gian chờ 1 byte dữ liệu từ VRAM. Kết quả: GPU utilization thường dưới 10%, phần lớn thời gian là "dead time" chờ weights load xong.
Các giải pháp như quantization (giảm kích thước weight) hay KV Cache (tránh tính lại attention) chỉ giảm chi phí mỗi bước, nhưng không xóa bỏ tính chất tuần tự — bạn vẫn phải đi từng bước một. Muốn nhanh hơn, cần sinh nhiều token trong một lần forward, nhưng model lớn không thể "nhìn trước" tương lai để tự sinh chuỗi dài (do autoregressive masking).
Ý tưởng cốt lõi
Sử dụng hai model thay vì một.
Thay vì ép model 70B sinh từng token chậm rãi, ta cho một "model nháp" nhỏ (ví dụ 7B hoặc thậm chí 1B) chạy nhanh trước, tự do sinh ra γ (gamma) token tiếp theo — thường 3-5 token. Sau đó, model lớn (target model) chạy một forward pass duy nhất để verify cả chuỗi γ token đó cùng lúc. Nó so sánh distribution của mình với distribution draft đã sinh: nếu khớp, chấp nhận (accept) hết; nếu gặp token đầu tiên bị lệch, reject từ đó trở đi và resample từ distribution chính xác của model lớn.
Đây chính là "aha moment": Verify song song 5 token rẻ hơn nhiều so với Generate tuần tự 5 token.
Khi verify, model lớn tính logits cho cả 5 vị trí tương lai cùng lúc bằng matrix multiplication — tận dụng được parallelism của GPU. Còn khi generate tuần tự, mỗi bước phải load weights riêng lẻ. Ví dụ: verify 5 token có thể chỉ tốn 20% thời gian so với generate 5 token riêng lẻ, miễn là draft model đoán đúng ít nhất 60-70% số token.
Mẹo này gọi là speculative decoding — tương tự như "speculative execution" trong CPU: thực hiện optimistic execution trước, rồi rollback nếu sai. Điều quan trọng là nhờ kỹ thuật rejection sampling, output cuối cùng có distribution hoàn toàn giống với việc dùng model lớn sinh từng token (lossless), không như các heuristic greedy hay beam search thay đổi distribution.
Tại sao nó hoạt động
Thống kê của "Easy vs Hard tokens": Trong văn bản tự nhiên, khoảng 60-80% token là "dễ đoán" — như ngắt câu, từ nối, từ chức năng (the, and, of). Chỉ 20-40% là "khó" — tên riêng, số liệu, từ chuyên ngành. Model nhỏ (7B) đủ sức đoán đúng phần lớn token dễ, nhưng sẽ fail ở token khó. Model lớn (70B) chỉ cần "sửa" những chỗ sai đó.
Tốc độ tăng lên phụ thuộc vào acceptance rate (α) — xác suất model lớn chấp nhận token của draft. Công thức tốc độ kỳ vọng là:
Speedup ≈ 1 / (1 - α^γ)Với α = 0.8 (chấp nhận 80%) và γ = 4, tốc độ tăng ~2.5x. Nếu draft quá tệ (α < 0.5), lợi ích mất đi vì liên tục phải reject và resample.
Tại sao không dùng draft model đứng alone? Vì quality drop. Draft model nhỏ thường hallucinate hoặc sai syntax. Bằng cách dùng nó chỉ để "đề xuất" và để model lớn "phê duyệt", ta giữ được chất lượng của model lớn nhưng tốc độ gần model nhỏ.
Memory Bandwidth vs Compute: Mỗi lần forward của model lớn để verify là compute-intensive (matrix multiply), không phải memory-bound như generate tuần tự. Vì khi verify, ta dùng cùng một bộ weights để tính cho cả chuỗi draft, tận dụng được data locality trong SRAM. Đây là cách biến memory-bound problem thành compute-bound — đánh đổi FLOPs dư thừa (GPU có nhiều) để giảm memory round-trips (bottleneck thực sự).
Ý nghĩa thực tế
Hiệu quả thực tế: Trong production (Google Search AI Overviews, IBM Watsonx), speculative decoding mang lại 2-3x speedup cho latency end-to-end, đặc biệt với batch size nhỏ (interactive chat). Ví dụ: Llama 3 70B từ 20 token/s lên 50-60 token/s khi dùng Llama 3 8B làm draft.
Rủi ro và giới hạn:
- VRAM pressure: Phải load cả hai model cùng lúc (draft + target) và giữ cả hai KV cache. Với model 70B + 7B, cần thêm ~10-15GB VRAM so với chạy đơn.
- Alignment requirement: Draft và target phải "nói cùng ngôn ngữ" — thường là cùng họ (Llama 3 8B draft cho Llama 3 70B target). Dùng Mistral draft cho Llama target cho acceptance rate thấp (~30%), biến lợi thành hại.
- Batch size conflict: Khi batch size lớn (offline inference), GPU đã bão hòa compute, thêm draft model chỉ gây contention. Speculative decoding chỉ hiệu quả với batch size nhỏ (online serving).
Tiến hóa: EAGLE. Thay vì dùng model nhỏ riêng biệt, EAGLE gắn một "draft head" nhẹ (1-2 layer MLP) vào chính model lớn tại layer giữa, đọc hidden states nội bộ để đoán token tiếp theo. Đạt 3x speedup mà không cần VRAM cho model thứ hai, acceptance rate ~70-75%.
Đào sâu hơn
Paper gốc: Leviathan et al., "Fast Inference from Transformers via Speculative Decoding" (2022) — đặt nền móng cho kỹ thuật rejection sampling và phân tích speedup lý thuyết.
Bài liên quan TroiSinh:
- Cùng cụm:
- KV Cache — Cơ chế cache K/V để tránh tính lại, nền tảng cho speculative decoding
- PagedAttention — Quản lý KV cache hiệu quả khi dùng speculative decoding với nhiều request
- Continuous Batching — Kết hợp với speculative decoding để tối ưu throughput hệ thống
- Đọc tiếp:
- EAGLE — Draft head gắn vào internal layer, cách tiếp cận hiện đại thay thế draft model riêng biệt
- Quantization — Giảm VRAM để chứa cả draft và target model trên cùng GPU
- Beam Search — Phương pháp tìm kiếm thay thế, trade-off quality vs speed khác với speculative
External resources:
- IBM PyTorch Blog: Hitchhiker's Guide to Speculative Decoding — Hướng dẫn triển khai với vLLM và TGI
- Google Research Retrospective — Kinh nghiệm production từ Google AI Overviews
Prefix Caching — System prompt giống nhau? Tính 1 lần, dùng ngàn lần
Kỹ thuật Prefix Caching giúp giảm 90% chi phí inference bằng cách tái sử dụng KV cache cho phần prompt chung, biến GPU từ máy tính lại thành máy tra cứu.
Beam Search — Giữ N candidate thay vì greedy, tìm output tốt hơn
Beam Search giữ N giả thuyết song song để tránh bẫy tối ưu cục bộ của greedy decoding, tìm câu trả lời toàn cục tốt hơn trong dịch máy và tạo code.