TROISINH
FoundationsNeural Network Primitives

Gradient Descent & Adam — Tại sao Adam thắng SGD

Adam không chỉ là 'SGD có momentum' — nó là việc mỗi tham số trong mạng neuron được học với tốc độ riêng. Hiểu bản chất adaptive learning rate qua analog chiếc xe địa hình.

Training neural networks giống như việc tìm điểm thấp nhất của một dãy núi trong khi bị bịt mắt. Bạn chỉ cảm nhận được độ dốc dưới chân. SGD (Stochastic Gradient Descent) là cách đi bộ cố định: bước xuống theo hướng dốc nhất, dù đó là vực sâu hay đồng bằng. Adam là cách lái xe địa hình thông minh hơn — nó nhớ đà tăng tốc, tự điều chỉnh ga cho từng bánh xe, và biết khi nào cần phanh để không lao quá đà.

Vấn đề

SGD bị mắc kẹt với "one size fits all".

Khi bạn dùng SGD thuần túy, toàn bộ mạng neuron phải học với cùng một learning rate (tốc độ học). Điều này tạo ra ba vấn đề chết người:

  1. Thung lũng hẹp (Ill-conditioned loss landscape): Gradient chỉ chỉ hướng xuống dốc cục bộ, không phải hướng trực tiếp đến đích. SGD sẽ "zigzag" qua lại hai bên thung lũng thay vì đi thẳng, mất hàng trăm lần cập nhật để đến nơi. Với learning rate 0.01, bạn có thể dao động mãi không hội tụ; với 0.0001, bạn cần 100 lần lặp lại.

  2. Sparse gradients: Trong NLP, embedding của từ hiếm chỉ xuất hiện trong 0.1% số batch. Nếu learning rate quá nhỏ, trọng số này hầu như không thay đổi. Nếu quá lớc, các trọng số phổ biến (như "the", "and") sẽ bùng nổ.

  3. Momentum thiếu: SGD dừng lại ngay khi gradient bằng 0, ngay cả khi đó chỉ là "đồng bằng nhỏ" trước một đỉnh dốc khác. Nó không có khái niệm "đà" để lăn qua những chỗ bằng phẳng.

Điểm then chốt: Chúng ta cần một cơ chế cho phép mỗi tham số (weight) có tốc độ học riêng, và cần nhớ lịch sử gradient để biết khi nào nên đi nhanh, khi nào nên đi chậm.

Ý tưởng cốt lõi

Adam là sự kết hợp của hai insight: Momentum + Adaptive Learning Rates.

Hãy tưởng tượng bạn đang điều chỉnh 1000 chiếc nút vặn (knobs) trên một bảng điều khiển âm thanh. Một số nút cần xoay nhẹ (fine-tuning), một số cần xoay mạnh (hiệu chỉnh lớn). SGD giống như việc xoay tất cả các nút cùng một góc độ — dẫn đến nút này bị xoay quá tay, nút kia chưa đủ.

Insight thứ nhất — Momentum: Thay vì bước theo gradient hiện tại, Adam tích lũy "vận tốc" (velocity). Giống như quả bóng lăn xuống dốc — nó không dừng lại ngay khi mặt đất bằng phẳng một chút, mà tiếp tục lăn nhờ quán tính. Trong Adam, điều này được tính bằng trung bình động (moving average) của gradient với hệ số beta1 (thường 0.9). Nếu gradient liên tục chỉ cùng một hướng, vận tốc tăng dần; nếu gradient dao động, chúng triệt tiêu nhau.

Insight thứ hai — Adaptive (RMSprop-style): Mỗi tham số có learning rate riêng. Nếu một tham số thường xuyên nhận gradient lớn (biến động mạnh), Adam tự động giảm learning rate cho nó để tránh nhảy loạn. Nếu tham số hiếm khi thay đổi (sparse), Adam tăng learning rate để nó có cơ hội học khi xuất hiện. Điều này thực hiện bằng cách chia learning rate cho căn bậc hai của trung bình bình phương gradient (second moment).

That's it. Adam không phải là phép màu. Nó chỉ là: "Hãy nhớ hướng vừa đi (momentum), và đừng bước quá lớn vào những chỗ địa hình gập ghềnh (adaptive)".

Misconception phổ biến: Nhiều người nghĩ Adam "phức tạp" vì có nhiều hyperparameter. Thực ra Adam ít nhạy cảm với learning rate hơn SGD rất nhiều — bạn thường chỉ cần để learning rate mặc định 0.001 và nó tự điều chỉnh phần còn lại.

Tại sao nó hoạt động

Toán học đằng sau là việc duy trì hai bộ nhớ song song:

Adam lưu trữ hai trạng thái cho mỗi tham số:

  1. First moment (m): Trung bình động của gradient — đại diện cho "hướng trung bình" mà chúng ta nên đi.
  2. Second moment (v): Trung bình động của bình phương gradient — đại diện cho "độ gập ghềnh" của địa hình.

Công thức cập nhật (giải thích trước, code sau):

  • Tính gradient gtg_t tại bước tt
  • Cập nhật moment: mt=β1mt1+(1β1)gtm_t = \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t
  • Cập nhật second moment: vt=β2vt1+(1β2)gt2v_t = \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot g_t^2
  • Hiệu chỉnh bias (quan trọng): m^t=mt/(1β1t)\hat{m}_t = m_t / (1-\beta_1^t), v^t=vt/(1β2t)\hat{v}_t = v_t / (1-\beta_2^t) — điều này sửa lỗi bias về 0 trong các bước đầu.
  • Cập nhật tham số: θt=θt1αm^t/(v^t+ϵ)\theta_t = \theta_{t-1} - \alpha \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

Tại sao chia cho \sqrt[v] lại hiệu quả? Nếu gradient của một tham số luôn lớn (v lớn), mẫu số lớn làm bước nhảy nhỏ lại — tránh dao động. Nếu gradient thưa thớt (v nhỏ), mẫu số nhỏ giúp bước nhảy lớn hơn — bù đắp cho việc hiếm được cập nhật.

# PyTorch-style pseudocode cho Adam update
m = beta1 * m + (1 - beta1) * g  # velocity
v = beta2 * v + (1 - beta2) * (g ** 2)  # variance
m_hat = m / (1 - beta1**t)  # bias correction
v_hat = v / (1 - beta2**t)
theta = theta - lr * m_hat / (v_hat.sqrt() + 1e-8)

Điểm then chốt: Adam tự động chuẩn hóa learning rate theo "scale" của gradient. Đây là lý do nó hoạt động tốt với sparse gradients trong NLP — embedding của từ "xylophone" (hiếm) sẽ có v nhỏ, nên khi xuất hiện nó được cập nhật mạnh; embedding của "the" có v lớn nên được cập nhật nhẹ nhàng.

Ý nghĩa thực tế

Adam trở thành optimizer mặc định cho hầu hết Deep Learning hiện đại không phải vì nó hoàn hảo, mà vì nó chịu đựng được sự bất cẩn của người dùng và hội tụ nhanh hơn đáng kể trong giai đoạn đầu training.

SGDAdam
Memory1x (chỉ weights)3x (weights + m + v)
HyperparameterRất nhạy cảm với LRÍt nhạy cảm, beta mặc định 0.9/0.999 hoạt động tốt
Hội tụChậm, dễ zigzagNhanh, ổn định hơn ở early training
Sparse gradientsTệ (cùng LR cho tất cả)Tốt (adaptive per-parameter)
GeneralizationThường tốt hơn nếu tune kỹĐôi khi overfit hơn (cần weight decay đặc biệt — AdamW)

Benchmarks thực tế:

  • Trong training Transformer (BERT, GPT), Adam thường hội tụ sau 3-5 epochs so với 15-20 epochs của SGD với cùng loss.
  • Adam chiếm ~2-3x bộ nhớ so với SGD (phải lưu m và v cho mỗi tham số), đây là lý do training model lớn cần sharding optimizer states (ZeRO).

Ai đang dùng:

  • Transformer models: BERT, GPT, T5, LLaMA đều dùng Adam hoặc AdamW (phiên bản có weight decay đúng cách).
  • Computer Vision: Một số kỹ sư vẫn thích SGD với Momentum cho ResNet vì generalization tốt hơn khi train đủ lâu, nhưng Adam vẫn phổ biến cho fine-tuning.

Limitations — Khi nào KHÔNG dùng Adam:

  • Adam không giải quyết được vanishing gradient — đó là vấn đề của activation function và architecture (Residual Connections mới giải quyết điều này).
  • AdamW vs Adam: Adam cũ có vấn đề với weight decay (L2 regularization). AdamW tách biệt weight decay khỏi gradient update, giúp generalization tốt hơn. Hầu hết code hiện đại dùng AdamW thay vì Adam nguyên bản.
  • Memory bottleneck: Với model 175B parameters, Adam cần lưu 525GB cho optimizer states (3x 175B). Đây là lý do cần Quantization hoặc sharding.

Đào sâu hơn

  • Paper gốc: "Adam: A Method for Stochastic Optimization" (Kingma & Ba, 2015) — arXiv:1412.6980
  • AdamW: "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, 2019) — sửa lỗi weight decay trong Adam gốc, hiện là chuẩn vàng.

Cùng cụm (nn-primitives)

  • Backpropagation — Hiểu gradient tính ngược như thế nào trước khi optimize
  • Weight Initialization — Xavier & He Init: Nếu khởi tạo sai, Adam cũng không cứu được
  • Activation Functions — ReLU, GELU: Gradient chết (vanishing) là vấn đề khác với optimizer

Đọc tiếp

On this page