TROISINH
FoundationsNeural Network Primitives

BatchNorm vs LayerNorm — Tại sao Transformer chọn LayerNorm

BatchNorm thống trị CNN nhưng Transformer lại dùng LayerNorm. Hiểu vì sao normalization theo chiều batch thất bại với sequence, và cách LayerNorm giải cứu ổn định training.

Bạn đã từng thắc mắc tại sao ResNet và hầu hết CNN dùng BatchNorm, nhưng GPT, BERT và mọi Transformer hiện đại lại dùng LayerNorm? Không phải ngẫu nhiên. Đó là câu chuyện về "chiều nào chứa thông tin quan trọng" — và lý do khiến BatchNorm thất bại thảm hại khi gặp câu văn dài ngắn khác nhau.

Vấn đề

Khi train neural network sâu, hiện tượng internal covariate shift là kẻ thù vô hình: phân phối đầu vào của mỗi layer liên tục thay đổi vì layer trước đó vẫn đang học và cập nhật weights. Điều này khiến gradient bị nhiễu, training chậm, và dễ rơi vào vanishing/exploding gradients — như cố gắng đứng trên tàu đang chạy mà không có tay vịn.

BatchNorm ra đời năm 2015 để "bình ổn" điều này bằng cách chuẩn hóa dữ liệu theo thống kê của toàn bộ batch. Nhưng khi chuyển sang xử lý ngôn ngữ (NLP), chúng ta gặp ba vấn đề chí mạng:

  1. Batch size nhỏ: Train LLM thường dùng batch size 1–32 (do model lớn, sequence dài), trong khi BatchNorm cần batch 256+ để thống kê ổn định. Batch nhỏ khiến mean/variance nhảy loạn như ECG bệnh nhân đang hoảng loạn.
  2. Variable length: Câu văn có độ dài khác nhau. BatchNorm tính mean trên cả batch, nhưng khi pad ngắn để bằng dài, các token <pad> sẽ kéo mean về zero, làm sai lệch normalization.
  3. Autoregressive inference: Khi generate từng token, BatchNorm phải dùng "running statistics" từ training, nhưng distribution của text generated thay đổi liên tục, khiến stats cũ nhanh chóng lỗi thời.

Ý tưởng cốt lõi

BatchNorm chuẩn hóa theo chiều batch (dimension N). Với mỗi feature/channel, nó tính mean và variance trên toàn bộ các sample trong batch. Giả định ngầm: thống kê của batch đại diện cho toàn bộ population dữ liệu.

LayerNorm chuẩn hóa theo chiều feature (dimension C). Với mỗi sample riêng lẻ, nó tính mean và variance trên tất cả các features của sample đó. Giả định: các features bên trong một sample tự cân bằng lẫn nhau, không cần quan tâm sample khác đang làm gì.

Đây là "aha moment": LayerNorm làm mỗi sample độc lập tuyệt đối với batch. Dù batch có 1 câu hay 1000 câu, dù câu dài 10 token hay 10k token, cách normalize vẫn như nhau. Không còn vấn đề padding, không còn lo batch size nhỏ làm nhiễu thống kê.

Transformer chọn LayerNorm vì ba lý do cốt lõi:

  • Batch-agnostic: Train với batch 1 cũng ổn, phù hợp inference real-time.
  • Variable-length native: Mỗi token sequence tự normalize, padding không ảnh hưởng.
  • Ổn định residual: Khi kết hợp với Skip Connections, Pre-LayerNorm (đặt LayerNorm trước Attention/FFN) tạo "đường cao tốc" cho gradient, cho phép train model sâu hàng trăm layer mà không bị gradient vanish.

Tại sao nó hoạt động

Công thức toán học của cả hai đều là:
output = γ × (x - μ) / √(σ² + ε) + β

Khác biệt nằm ở chỗ tính μ và σ:

  • BatchNorm: μ_B = mean(x over N samples), σ_B = std(x over N samples). Nếu N=1 (batch size 1), μ_B chính là x, normalize ra zero vector → mất thông tin hoàn toàn. Nếu có padding tokens (giá trị 0), mean bị kéo về gần 0 dù dữ liệu thực không phải vậy.
  • LayerNorm: μ_L = mean(x over C features), σ_L = std(x over C features). Mỗi token (sample) tự chứa đủ thông tin để tính thống kê, không liên quan đến batch.

Trong Transformer, LayerNorm còn giải quyết vấn đề distribution drift khi generate. BatchNorm lưu running mean/variance từ training để dùng lúc inference, nhưng text generation là quá trình autoregressive — distribution của token thứ 100 rất khác token thứ 5. LayerNorm không cần lưu trữ gì cả, mỗi bước tính toán độc lập, nên ổn định suốt quá trình sinh văn bản dài.

Ý nghĩa thực tế

Đặc điểmBatchNorm (ResNet, CNN)LayerNorm (Transformer, RNN)
Chiều normalizeBatch (N)Features (C)
Yêu cầu batch sizeLớn (256+) để ổn địnhBất kỳ, từ 1 trở lên
Variable lengthKhó xử lý (padding làm méo mean)Tự nhiên, không ảnh hưởng
InferenceDùng running stats, risk driftTính real-time, không drift
Use caseImages (fixed size, large batch)Text/Sequence (variable, small batch)

Trong thực tế:

  • CNNs (ResNet, EfficientNet, ConvNeXt) dùng BatchNorm vì ảnh có kích thước cố định (224×224), batch size lớn (256–1024), và spatial statistics ổn định trên toàn batch.
  • Transformers (BERT, GPT, ViT, Llama) dùng LayerNorm vì sequence length biến động (128–128k tokens), batch size nhỏ (thường 1–32 do memory), và cần independence giữa các sample để dễ parallelize inference.

Biến thể hiện đại: RMSNorm (Root Mean Square Layer Normalization) được dùng trong Llama và nhiều LLM mới. Nó bỏ luôn bước tính mean (chỉ dùng RMS), giảm computation nhẹ nhưng vẫn giữ được tính batch-independent của LayerNorm. Đây là minh chứng rằng cộng đồng đã chấp nhận: với sequence modeling, normalization theo sample là "chân lý" không thể thay đổi.

Đào sâu hơn

Paper gốc:

  • Batch Normalization — Ioffe & Szegedy (ICML 2015): Khởi đầu kỷ nguyên normalization cho deep learning.
  • Layer Normalization — Ba et al. (arXiv 2016): Giải pháp cho RNN và sequence modeling.

Cùng cụm (nn-primitives):

Đọc tiếp:

  • Decoder-Only — Kiến trúc Transformer tại sao lại chọn LayerNorm thay vì BatchNorm trong kiến trúc autoregressive.
  • Training at Scale — Vấn đề batch size nhỏ và cách giải quyết khi train LLM, liên quan trực tiếp đến lý do LayerNorm thắng thế.

On this page