Backpropagation — Chain rule chạy ngược, đó là tất cả
Bản chất của việc dạy neural network: không phải ma thuật, chỉ là quy tắc chuỗi chạy ngược từ lỗi để điều chỉnh triệu tham số cùng lúc.
Mọi ứng dụng AI bạn dùng — từ ChatGPT tạo văn bản đến Midjourney vẽ tranh — đều dựa trên một thuật toán cũ kỹ từ năm 1986. Vấn đề là mạng neural hiện đại có hàng tỷ tham số. Làm sao để biết cần xoay nút nào, vặn gì để giảm lỗi? Thử từng nút một thì đến năm 3000 cũng chưa xong.
Vấn đề
Mạng neural là một hàm tổng hợp khổng lồ: input đi qua lớp này đến lớp khác, qua activation functions, cuối cùng ra dự đoán. So sánh dự đoán với đáp án đúng, bạn được một con số gọi là loss (lỗi). Nhiệm vụ là điều chỉnh hàng tỷ weight và bias để loss giảm xuống bằng gradient descent.
Nếu tính "thủ công", bạn phải thử thay đổi từng tham số một chút, chạy forward pass lại để xem loss thay đổi thế nào. Với 1 tỷ tham số, điều này đòi hỏi 1 tỷ lần chạy mạng — hoàn toàn không khả thi. Cần một cách để tính đạo hàm (gradient) của tất cả tham số chỉ trong một lần chạy duy nhất từ cuối về đầu.
Ý tưởng cốt lõi
Bản chất của backpropagation chỉ là quy tắc chuỗi (chain rule) từ giải tích cấp 2, nhưng áp dụng theo thứ tự ngược lại. Thay vì tính từ đầu vào đến đầu ra, ta tính từ lỗi truyền ngược về nguồn.
Hãy tưởng tượng một dây chuyền lỗi. Khi mạng dự đoán sai, tín hiệu lỗi bắt đầu từ lớp cuối cùng. Lớp đó hỏi: "Tôi đã đóng góp bao nhiêu vào lỗi này?" rồi truyền thông tin đó ngược lên lớp trước. Mỗi lớp tiếp theo làm tương tự, nhân thêm đạo hàm cục bộ của chính mình, rồi tiếp tục truyền lên. Đến khi tín hiệu đến weight ở lớp đầu tiên, nó đã mang theo đầy đủ thông tin: "Nếu thay đổi weight này, loss sẽ thay đổi bao nhiêu."
Cơ chế chi tiết hoạt động như sau:
-
Forward pass: Chạy mạng từ đầu đến cuối, tính toán và lưu trữ tất cả các giá trị trung gian (activations) tại mỗi lớp. Đây là "bản đồ" của toàn bộ quá trình tính toán.
-
Backward pass: Bắt đầu từ loss, tính gradient của loss theo output cuối cùng. Rồi sử dụng chain rule: gradient theo lớp trước bằng gradient hiện tại nhân với đạo hàm local (local gradient) của phép toán ở lớp đó.
-
Tái sử dụng tính toán: Khi đi ngược lên, mỗi node chỉ cần nhận "gradient từ sau truyền lên" và nhân với đạo hàm của chính nó. Kết quả được dùng ngay lập tức cho các nhánh khác nhau của đồ thị tính toán. Điều này tránh việc tính toán lặp lại — một dạng dynamic programming trên đồ thị.
Điểm mấu chốt khiến backpropagation nhanh đến vậy là nó tính toán song song trên cấu trúc đồ thị. Thay vì O(N) lần chạy mạng cho N tham số, nó chỉ cần O(N) phép tính đơn giản trong một lần backward duy nhất. Mỗi weight nhận được gradient chính xác như thể bạn đã tính riêng lẻ cho nó, nhưng với chi phí tính toán tương đương chỉ một lần forward pass.
That's it. Không có ma thuật. Chỉ là quy tắc chuỗi chạy ngược, tận dụng cấu trúc đồ thị để chia sẻ tính toán.
Tại sao nó hoạt động
Về mặt toán học, nếu là loss và là một weight ở lớp đầu, quan hệ giữa chúng là một chuỗi các hàm hợp qua các lớp trung gian :
Backpropagation tính tích này từ phải sang trái (ngược lại với forward). Tại sao hướng ngược lại hiệu quả hơn? Vì khi đi ngược, bạn tính một lần, rồi dùng nó cho tất cả các weight kết nối đến . Nếu đi xuôi, bạn phải tính đường đi riêng cho từng weight, dẫn đến số lượng phép tính bùng nổ.
Tuy nhiên, điều này đòi hỏi bộ nhớ lớn. Bạn phải lưu toàn bộ đồ thị tính toán (activations) từ forward pass để dùng trong backward. Đây là lý do training model lớn tốn VRAM gấp 2-3 lần so với inference — bạn không chỉ lưu weight, mà còn lưu "lịch sử" của mọi phép toán trung gian.
Ý nghĩa thực tế
Backpropagation là nền tảng của cuộc cách mạng deep learning hiện đại. Trước năm 1986, người ta chỉ huấn luyện được mạng nông 2-3 lớp vì không thể điều chỉnh hiệu quả các tham số sâu bên trong. Giờ đây, nhờ backprop, chúng ta huấn luyện được mạng 100+ lớp, tạo ra các mô hình như GPT-4.
Tuy nhiên, thuật toán này có những hạn chế quan trọng:
-
Vanishing gradients: Khi nhân nhiều đạo hàm nhỏ (< 1) qua hàng chục lớp, gradient ở lớp đầu tiên tiêu biến về 0, khiến các lớp đầu không học được. Đây là lý do residual connections và BatchNorm/LayerNorm ra đời — chúng tạo "đường cao tốc" cho gradient chạy ngược mà không bị suy giảm.
-
Memory wall: Với sequence dài (như trong sequence modeling), lưu activations cho cả chuỗi trở nên không khả thi. Các kỹ thuật như gradient checkpointing (trade compute for memory) trong training at scale giúp giải quyết vấn đề này bằng cách tính lại một số activations thay vì lưu chúng.
-
Không sinh học: Não người không tính toán gradient chính xác theo cách này. Backprop là một cơ chế "engineering" tinh khiết, không phải mô phỏng sinh học.
Mọi framework hiện đại (PyTorch, TensorFlow, JAX) đều implement autograd — tự động tính gradient bằng backpropagation. Bạn chỉ định nghĩa forward pass, framework tự xây dựng đồ thị tính toán và chạy backward khi gọi .backward().
Đào sâu hơn
- Paper gốc: Rumelhart, Hinton, Williams (1986) — "Learning representations by back-propagating errors" (tạp chí Nature), đặt nền móng cho việc huấn luyện mạng neural đa tầng.
Bài liên quan TroiSinh:
-
Cùng cụm (nn-primitives):
- Gradient Descent & Adam — Sau khi có gradient từ backprop, dùng nó để cập nhật weight thế nào cho hiệu quả.
- Activation Functions — Đạo hàm của ReLU, GELU, SiLU là gì và tại sao chúng quan trọng cho backward pass.
- Residual Connections — Cứu gradient khỏi chết khi chạy ngược qua hàng trăm lớp.
- BatchNorm vs LayerNorm — Cách normalization ổn định gradient flow và tăng tốc hội tụ.
- Weight Initialization — Khởi tạo đúng để gradient không bị nổ hoặc chết ngay từ epoch đầu.
-
Đọc tiếp:
- Sequence Modeling — Backpropagation Through Time (BPTT) và cách gradient chảy qua attention mechanism trong Transformer.
- Training at Scale — Cách backprop hoạt động khi chia nhỏ batch qua nhiều GPU (data parallelism) và các kỹ thuật tối ưu bộ nhớ như gradient accumulation.
Learning Rate Scheduling — Warmup rồi decay, tại sao cần cả hai
Tại sao training LLM tỷ parameter cần warmup để tránh nổ ngay từ bước đầu, và decay để không lãng phí triệu đô cuối quá trình. Giải thích intuition đằng sau cosine annealing.
Gradient Descent & Adam — Tại sao Adam thắng SGD
Adam không chỉ là 'SGD có momentum' — nó là việc mỗi tham số trong mạng neuron được học với tốc độ riêng. Hiểu bản chất adaptive learning rate qua analog chiếc xe địa hình.