Batch Normalization
What Batch Norm normalizes and why, the critical train-vs-inference distinction, BN vs. Layer Norm, with the math and a from-scratch PyTorch implementation.
Batch Normalization normalizes the activations of a layer across the current mini-batch, so each feature has roughly zero mean and unit variance before it flows to the next layer. Introduced by Ioffe & Szegedy (2015), it let practitioners train much deeper networks, use higher learning rates, and worry less about weight initialization.
The original motivation was "internal covariate shift": the idea that the distribution of each layer's inputs keeps shifting as earlier layers update. That explanation is now contested. Santurkar et al. (2018) argued that BN mainly smooths the loss landscape. The empirical payoff, however, is not in dispute: BN reliably speeds up and stabilizes training. The frame to hold: BN is a per-feature, per-batch standardization with two learnable parameters that let the network undo the normalization if it needs to. The train-vs-eval distinction is the single most common thing interviewers check, so know it cold.
For a batch of activations shaped (N, C), with N examples and C features, BN computes the mean and variance down each column (across the N examples) and standardizes that feature. Every feature ends up centered and scaled, regardless of what the previous layer's weights are doing. Then two learnable parameters, $\gamma$ (scale) and $\beta$ (shift), rescale and re-shift each feature. If the network is better off not normalizing a feature, it can learn $\gamma = \sqrt{\sigma^2_B + \epsilon}$ and $\beta = \mu_B$ and recover the original activations.
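To see the scale-and-shift escape hatch concretely, here is a minimal PyTorch sketch. The (8, 3) batch, its offset and scale, and eps = 1e-5 are illustrative assumptions. It standardizes each column over the batch dimension, then shows that choosing γ as the batch standard deviation and β as the batch mean reproduces the input.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 3) * 5.0 + 2.0  # batch of 8 examples, 3 features, far from zero-mean/unit-var
eps = 1e-5

# Per-feature batch statistics: reduce over dim=0 (the N examples), one value per column.
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + eps)

print(x_hat.mean(dim=0))                 # close to 0 for every feature
print(x_hat.var(dim=0, unbiased=False))  # close to 1 (a hair below, because of eps)

# With these particular gamma/beta values, the scale-shift exactly undoes the normalization.
gamma = torch.sqrt(var + eps)
beta = mu
y = gamma * x_hat + beta
print(torch.allclose(y, x, atol=1e-5))   # True
```

In practice γ and β start at 1 and 0 and are learned by gradient descent; the network only moves toward this "identity" setting if that helps the loss.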
```mermaid
flowchart TB
    A["Batch activations (N × C)"] --> B["per-feature mean μ_B, var σ²_B (over N)"]
    B --> C["normalize: x̂ = (x − μ_B)/√(σ²_B + ε)"]
    C --> D["scale + shift: y = γ·x̂ + β"]
    D --> E["to next layer"]
```
Contrast with Layer Norm, which normalizes across the features of a single example (along C, per row) — independent of batch size, which is why Transformers use LN, not BN.
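The difference between the two is just the reduction axis. A minimal sketch, assuming a random (4, 6) batch and a hypothetical `standardize` helper with no learnable parameters: reducing over dim=0 gives BN-style statistics, reducing over dim=1 gives LN-style statistics, and PyTorch's built-in modules (with their affine parameters turned off) agree with each.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 6)  # (N=4 examples, C=6 features)
eps = 1e-5

def standardize(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: zero mean, unit variance along `dim` (no gamma/beta)."""
    mean = t.mean(dim=dim, keepdim=True)
    var = t.var(dim=dim, unbiased=False, keepdim=True)
    return (t - mean) / torch.sqrt(var + eps)

bn_like = standardize(x, dim=0)  # Batch Norm: down each column, across the N examples
ln_like = standardize(x, dim=1)  # Layer Norm: along each row, across the C features

# Layer Norm of one example ignores what else is in the batch...
print(torch.allclose(standardize(x[:1], dim=1), ln_like[:1], atol=1e-6))  # True
# ...while Batch Norm output for the same example changes when the batch changes.
print(torch.allclose(standardize(x[:2], dim=0)[:1], bn_like[:1], atol=1e-6))  # False

# PyTorch's modules use the same axes (affine params disabled for a clean comparison).
bn = torch.nn.BatchNorm1d(6, affine=False, eps=eps).train()
ln = torch.nn.LayerNorm(6, elementwise_affine=False, eps=eps)
print(torch.allclose(bn(x), bn_like, atol=1e-5), torch.allclose(ln(x), ln_like, atol=1e-5))  # True True
```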
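For a mini-batch $B = \{x_1, \dots, x_m\}$ of one feature's values, BN computes:

$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i$$

$$\sigma^2_B = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2$$

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma^2_B + \epsilon}}$$

$$y_i = \gamma \hat{x}_i + \beta$$

$\epsilon$ is a small constant for numerical stability. $\gamma$ and $\beta$ are learned per feature.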
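The four steps (batch mean, biased batch variance, normalize, scale and shift) can be checked numerically. This sketch assumes a random batch of m = 5 examples with 3 features and hand-picked γ and β values, all illustrative. It computes the steps by hand and compares the result with `torch.nn.functional.batch_norm` in training mode, with no running statistics tracked.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
m, C, eps = 5, 3, 1e-5
x = torch.randn(m, C)                  # mini-batch: m examples, C features
gamma = torch.tensor([1.5, 0.5, 2.0])  # illustrative per-feature scale
beta = torch.tensor([0.0, -1.0, 0.3])  # illustrative per-feature shift

# Per feature: reduce over the batch dimension (dim=0).
mu_B = x.sum(dim=0) / m                       # 1. batch mean
var_B = ((x - mu_B) ** 2).sum(dim=0) / m      # 2. batch variance (divide by m, not m - 1)
x_hat = (x - mu_B) / torch.sqrt(var_B + eps)  # 3. normalize
y = gamma * x_hat + beta                      # 4. scale and shift

# Reference: PyTorch's functional batch norm using batch statistics.
y_ref = F.batch_norm(x, running_mean=None, running_var=None,
                     weight=gamma, bias=beta, training=True, eps=eps)
print(torch.allclose(y, y_ref, atol=1e-5))  # True
```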
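Train vs. eval: the critical distinction. At training time, $\mu_B$ and $\sigma^2_B$ are computed from the current batch. The layer also maintains running estimates via an exponential moving average with momentum $\alpha$ (0.1 by default in PyTorch):

$$\mu_{\text{run}} \leftarrow (1 - \alpha)\,\mu_{\text{run}} + \alpha\,\mu_B \qquad \sigma^2_{\text{run}} \leftarrow (1 - \alpha)\,\sigma^2_{\text{run}} + \alpha\,\frac{m}{m-1}\,\sigma^2_B$$

The $\frac{m}{m-1}$ factor turns the biased batch variance into an unbiased estimate, which is what PyTorch's nn.BatchNorm1d stores in its running variance.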
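The running-average update can be checked against PyTorch's own nn.BatchNorm1d. A minimal sketch, assuming three random training batches of 16 examples with 4 features (illustrative values). It replays the moving average by hand, feeding the unbiased batch variance into the running variance, and compares with the module's buffers.

```python
import torch
from torch import nn

torch.manual_seed(0)
momentum = 0.1
bn = nn.BatchNorm1d(4, momentum=momentum)  # 0.1 is also PyTorch's default

# Expected running statistics, starting from BN's initial buffer values.
running_mean = torch.zeros(4)
running_var = torch.ones(4)

for _ in range(3):  # three training steps
    x = torch.randn(16, 4) * 2.0 + 1.0
    bn(x)  # train mode: normalizes with batch stats and updates the buffers
    running_mean = (1 - momentum) * running_mean + momentum * x.mean(dim=0)
    # The running variance is fed the unbiased batch variance (divide by N - 1).
    running_var = (1 - momentum) * running_var + momentum * x.var(dim=0, unbiased=True)

print(torch.allclose(bn.running_mean, running_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, running_var, atol=1e-5))    # True
```

Note that PyTorch's `momentum` is the weight on the new batch statistic, so a larger value makes the running estimates track recent batches more closely.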
At inference, BN uses these fixed running statistics, not the batch — so a single example (or any batch size) produces deterministic outputs. Forgetting to switch to eval mode (model.eval()) is a classic bug: predictions then depend on whatever else is in the batch.
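A minimal sketch of the missing-`model.eval()` bug, assuming a hypothetical toy model (nn.Linear followed by nn.BatchNorm1d) and random inputs. The same query row is batched with two different sets of batch-mates: in train mode its output changes, in eval mode it does not, and a batch of one gives the same answer.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))  # hypothetical toy model

# A few training-mode forward passes to populate BN's running statistics.
for _ in range(20):
    model(torch.randn(32, 8))

query = torch.randn(1, 8)          # the example we want a prediction for
mates_a = torch.randn(7, 8)        # two different sets of batch-mates
mates_b = torch.randn(7, 8) + 3.0

with torch.no_grad():
    # Bug: still in train mode, so BN normalizes with this batch's statistics.
    train_a = model(torch.cat([query, mates_a]))[0]
    train_b = model(torch.cat([query, mates_b]))[0]
    print(torch.allclose(train_a, train_b, atol=1e-6))  # False: output depends on the batch

    model.eval()  # BN now uses the frozen running statistics
    eval_a = model(torch.cat([query, mates_a]))[0]
    eval_b = model(torch.cat([query, mates_b]))[0]
    eval_single = model(query)[0]  # a batch of 1 is fine in eval mode
    print(torch.allclose(eval_a, eval_b, atol=1e-6),
          torch.allclose(eval_a, eval_single, atol=1e-6))  # True True
```

A train-mode forward pass also updates the running statistics, even under `torch.no_grad()`, so skipping `eval()` silently drifts the stored statistics toward whatever inference traffic looks like.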
```python
import torch
from torch import Tensor, nn


class BatchNorm1d(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1) -> None:
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.gamma = nn.Parameter(torch.ones(num_features))  # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_features))  # learnable shift
        # Buffers: saved in state_dict and moved by .to(), but not updated by the optimizer.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: Tensor) -> Tensor:  # x: (N, num_features)
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)  # biased variance for normalizing, as in BN
            with torch.no_grad():  # update running stats outside the autograd graph
                # Like nn.BatchNorm1d, store the unbiased variance (divide by N - 1) for inference.
                unbiased_var = x.var(dim=0, unbiased=True)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean, var = self.running_mean, self.running_var  # frozen at inference
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


bn = BatchNorm1d(16)
out = bn(torch.randn(32, 16))
assert out.shape == (32, 16)
bn.eval()  # switch to running stats
assert bn(torch.randn(1, 16)).shape == (1, 16)  # works for batch size 1
```

- Conceptual: What does Batch Norm actually normalize, and along which axis? (Each feature, across the examples in the batch: per column for an (N, C) input.)
- Implementation: What changes between training and inference, and why? (Training uses batch statistics and updates the running averages; eval uses the frozen running statistics, so the output is independent of batch composition.)
- Applied: Why does BN struggle with very small batch sizes, and what would you use instead? (Batch statistics become noisy and unreliable; use Group Norm or Layer Norm, which don't depend on batch size.)
- Systems-level: Why do Transformers use Layer Norm rather than Batch Norm? (LN normalizes each example across its features, so it is independent of batch size and sequence length and well-behaved for variable-length sequences and small or streaming batches.)
- Failure modes: Your model trains well but predictions are unstable in production. What's a likely BN-related cause? (Forgetting model.eval(), so BN uses batch statistics at inference and outputs depend on what else is batched together.)
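For the small-batch question, a minimal sketch comparing nn.BatchNorm2d with nn.GroupNorm on a tiny batch of random 8-channel feature maps. The shapes and the choice of 4 groups are illustrative assumptions. Group Norm computes statistics per example, so an example's output does not change with its batch-mates, while Batch Norm's per-channel statistics do, and training-mode nn.BatchNorm1d refuses a batch with only one value per channel.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 8, 5, 5)  # tiny batch: N=2 feature maps, C=8 channels, 5x5 spatial

bn = nn.BatchNorm2d(8)                           # stats per channel, pooled over (N, H, W)
gn = nn.GroupNorm(num_groups=4, num_channels=8)  # stats per example, over (C/groups, H, W)

# Group Norm normalizes each example on its own, so batch-mates don't matter.
print(torch.allclose(gn(x)[:1], gn(x[:1]), atol=1e-6))  # True
# Batch Norm's statistics for example 0 change when the batch changes.
print(torch.allclose(bn(x)[:1], bn(x[:1]), atol=1e-6))  # False

# With a single value per channel, training-mode Batch Norm cannot estimate a variance.
raised = False
try:
    nn.BatchNorm1d(8)(torch.randn(1, 8))
except ValueError as err:
    raised = True
    print("BatchNorm1d on a batch of 1:", err)
print(raised)  # True
```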
Without looking: write the four BN equations (mean, variance, normalize, scale-shift), name what $\gamma$ and $\beta$ are for, and explain in one sentence exactly what differs at inference time. Then compare against Stage 3.