Batch Normalization

What Batch Norm normalizes and why, the critical train-vs-inference distinction, BN vs. Layer Norm, with the math and a from-scratch PyTorch implementation.

8 min read · Reviewed May 2026

1. Big Picture

Batch Normalization normalizes the activations of a layer across the current mini-batch, so each feature has roughly zero mean and unit variance before it flows to the next layer. Introduced by Ioffe & Szegedy (2015), it let practitioners train much deeper networks, use higher learning rates, and worry less about weight initialization.

The original motivation was "internal covariate shift": the idea that the distribution of each layer's inputs keeps shifting as earlier layers update. That explanation is now contested (Santurkar et al., 2018, argue that BN mainly smooths the loss landscape), but the empirical payoff is not: BN reliably speeds up and stabilizes training. The frame to hold: BN is a per-feature, per-batch standardization with two learnable parameters that let the network undo the normalization if it needs to. Know the train-vs-eval distinction cold; it is the point interviewers check most often.

2. Intuition + Visual

For a batch of activations shaped (N, C), with N examples and C features, BN computes the mean and variance down each column (across the N examples) and standardizes that feature. Every feature ends up centered and scaled, regardless of what the previous layer's weights are doing. Then two learnable parameters, $\gamma$ and $\beta$, rescale and re-shift, so if the network is better off not normalizing a feature, it can learn $\gamma = \sqrt{\text{Var}}$, $\beta = \text{mean}$ and recover the original.

flowchart TB
    A["Batch activations (N × C)"] --> B["per-feature mean μ_B, var σ²_B (over N)"]
    B --> C["normalize: x̂ = (x − μ_B)/√(σ²_B + ε)"]
    C --> D["scale + shift: y = γ·x̂ + β"]
    D --> E["to next layer"]
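The "undo" property can be checked numerically. The snippet below is a minimal sketch, not part of the original walkthrough: the tensor shape, the seed, and the variable names are illustrative assumptions. It standardizes each column, then picks $\gamma$ and $\beta$ by hand so that the scale-and-shift step returns the input.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 3) * 4.0 + 2.0    # (N=8, C=3), deliberately not standardized
eps = 1e-5

mean = x.mean(dim=0)                  # one mean per feature (column)
var = x.var(dim=0, unbiased=False)    # one biased variance per feature
x_hat = (x - mean) / torch.sqrt(var + eps)

# Each column of x_hat is centered with (almost exactly) unit variance.
assert torch.allclose(x_hat.mean(dim=0), torch.zeros(3), atol=1e-5)
assert torch.allclose(x_hat.var(dim=0, unbiased=False), torch.ones(3), atol=1e-3)

# Choosing gamma = sqrt(var + eps) and beta = mean inverts the normalization.
gamma, beta = torch.sqrt(var + eps), mean
y = gamma * x_hat + beta
assert torch.allclose(y, x, atol=1e-5)
```

In a trained network these values are learned by gradient descent rather than set by hand; the point is only that the identity mapping stays reachable.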

Contrast with Layer Norm, which normalizes across the features of a single example (along C, per row) — independent of batch size, which is why Transformers use LN, not BN.
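The axis difference is easiest to see side by side. This is an illustrative sketch: the `standardize` helper is a hypothetical name introduced here, and the learnable scale and shift are left out so that only the statistics differ.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 6)   # (N=4 examples, C=6 features)
eps = 1e-5

def standardize(t, dim):
    """Zero-mean, unit-variance along `dim` (hypothetical helper)."""
    mean = t.mean(dim=dim, keepdim=True)
    var = t.var(dim=dim, unbiased=False, keepdim=True)
    return (t - mean) / torch.sqrt(var + eps)

bn_like = standardize(x, dim=0)   # statistics per column, shared across the batch
ln_like = standardize(x, dim=1)   # statistics per row, one example at a time

# Both match PyTorch's functional versions when no affine parameters are used.
assert torch.allclose(bn_like, F.batch_norm(x, None, None, training=True, eps=eps), atol=1e-5)
assert torch.allclose(ln_like, F.layer_norm(x, normalized_shape=(6,), eps=eps), atol=1e-5)

# Layer Norm of the first example does not change when the other rows are removed.
alone = standardize(x[:1], dim=1)
assert torch.allclose(ln_like[:1], alone, atol=1e-5)
```

The last assertion is the batch-size independence mentioned above: a row's Layer Norm output is a function of that row only.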

3. The Math

For a mini-batch $\mathcal{B} = \{x_1, \dots, x_N\}$ (per feature), BN computes:

$$\mu_\mathcal{B} = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad \sigma_\mathcal{B}^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu_\mathcal{B})^2$$

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$

$\epsilon$ is a small constant for numerical stability. $\gamma$ and $\beta$ are learned per feature.
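To make the four equations concrete, here is a small hand-checkable sketch for a single feature, written in plain Python with no library dependencies. The function name `batch_norm_1feature` and the sample values are assumptions chosen for illustration.

```python
import math

def batch_norm_1feature(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Apply the four BN equations to one feature across a batch (hypothetical helper)."""
    n = len(xs)
    mu = sum(xs) / n                                        # batch mean
    var = sum((x - mu) ** 2 for x in xs) / n                # biased batch variance
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in xs]   # normalize
    y = [gamma * xh + beta for xh in x_hat]                 # scale and shift
    return mu, var, x_hat, y

mu, var, x_hat, y = batch_norm_1feature([1.0, 2.0, 3.0, 4.0], gamma=2.0, beta=0.5)
print(mu, var)   # 2.5 1.25
```

By hand: the deviations from the mean are -1.5, -0.5, 0.5, and 1.5, so the variance is (2.25 + 0.25 + 0.25 + 2.25) / 4 = 1.25, and each normalized value is its deviation divided by $\sqrt{1.25 + \epsilon}$.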

Train vs. eval: the critical distinction. At training time, $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ are computed from the current batch. The layer also maintains running estimates via an exponential moving average:

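$$\mu_{\text{run}} \leftarrow (1-m)\,\mu_{\text{run}} + m\,\mu_\mathcal{B}, \qquad \sigma^2_{\text{run}} \leftarrow (1-m)\,\sigma^2_{\text{run}} + m\,\sigma_\mathcal{B}^2$$

Here $m$ is the momentum (0.1 by default in PyTorch).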

At inference, BN uses these fixed running statistics, not the batch — so a single example (or any batch size) produces deterministic outputs. Forgetting to switch to eval mode (model.eval()) is a classic bug: predictions then depend on whatever else is in the batch.
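The dependence on batch composition can be demonstrated directly with PyTorch's built-in `nn.BatchNorm1d`. This is a sketch under assumed shapes and a fixed seed: the same example `x` is placed in two different batches and run in both modes.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(1, 4)                               # the example we care about
batch_a = torch.cat([x, torch.randn(7, 4)])
batch_b = torch.cat([x, torch.randn(7, 4) + 5.0])   # same example, different neighbors

bn.train()
with torch.no_grad():   # no_grad only skips autograd; running stats still update
    train_a = bn(batch_a)[0]
    train_b = bn(batch_b)[0]
# In training mode the output for x depends on the rest of the batch.
assert not torch.allclose(train_a, train_b)

bn.eval()
with torch.no_grad():
    eval_a = bn(batch_a)[0]
    eval_b = bn(batch_b)[0]
    eval_alone = bn(x)[0]
# In eval mode only the running statistics are used, so x maps to one output.
assert torch.allclose(eval_a, eval_b)
assert torch.allclose(eval_a, eval_alone)
```

The first assertion is the production bug in miniature: leave the layer in training mode and the prediction for `x` changes with its neighbors.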

4. Implementation
```python
import torch
from torch import Tensor, nn


class BatchNorm1d(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1) -> None:
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.gamma = nn.Parameter(torch.ones(num_features))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(num_features))   # learnable shift
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: Tensor) -> Tensor:  # x: (N, num_features)
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)               # biased var, as in BN
            with torch.no_grad():                            # update running stats
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var  # frozen at inference
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


bn = BatchNorm1d(16)
out = bn(torch.randn(32, 16))
assert out.shape == (32, 16)
bn.eval()                                                    # switch to running stats
assert bn(torch.randn(1, 16)).shape == (1, 16)              # works for batch size 1
```
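As a sanity check on the update rules used above, the sketch below runs one training step through PyTorch's own `nn.BatchNorm1d` and reproduces its output and running statistics by hand. The input scale and seed are arbitrary assumptions. One detail worth knowing: PyTorch normalizes with the biased batch variance but stores the unbiased variance in `running_var`, so the from-scratch class above, which uses the biased variance in both places, will differ slightly in that buffer.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 16) * 2.0 + 3.0
momentum = 0.1

ref = nn.BatchNorm1d(16, momentum=momentum)
ref.train()
out = ref(x)

# Normalization uses the biased batch variance (gamma = 1, beta = 0 at init).
manual = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + ref.eps)
assert torch.allclose(out, manual, atol=1e-4)

# Running stats start at mean 0 and variance 1, then take one EMA step.
expected_mean = (1 - momentum) * torch.zeros(16) + momentum * x.mean(dim=0)
expected_var = (1 - momentum) * torch.ones(16) + momentum * x.var(dim=0, unbiased=True)
assert torch.allclose(ref.running_mean, expected_mean, atol=1e-5)
assert torch.allclose(ref.running_var, expected_var, atol=1e-5)
```

With N = 32 the biased and unbiased variances differ by a factor of 32/31, which is usually negligible but becomes noticeable at very small batch sizes.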
5. Interview Questions
  1. Conceptual: What does Batch Norm actually normalize, and along which axis? (Each feature, across the examples in the batch — per-column for (N, C) input.)
  2. Implementation: What changes between training and inference, and why? (Train uses batch stats + updates running averages; eval uses frozen running stats so output is independent of batch composition.)
  3. Applied: Why does BN struggle with very small batch sizes, and what would you use instead? (Batch stats become noisy/unreliable; use Group Norm or Layer Norm, which don't depend on batch size.)
  4. Systems-level: Why do Transformers use Layer Norm rather than Batch Norm? (LN normalizes per-example across features — independent of batch size and sequence length, and well-behaved for variable-length sequences and small/streaming batches.)
  5. Failure modes: Your model trains well but predictions are unstable in production. What's a likely BN-related cause? (Forgot model.eval(), so BN uses batch stats at inference — outputs depend on what else is batched together.)
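The small-batch answer in question 3 can be made tangible. The following is a rough sketch under assumed sizes (4,096 samples of a single standard-normal feature, and a 16-channel Group Norm with 4 groups). It measures how much the batch mean fluctuates at two batch sizes, then checks that Group Norm's output for an example does not depend on the rest of the batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
data = torch.randn(4096, 1)   # one feature, true mean 0 and variance 1

def batch_mean_spread(batch_size):
    """Std of the per-batch mean over consecutive batches (hypothetical helper)."""
    means = data.view(-1, batch_size).mean(dim=1)
    return means.std().item()

small, large = batch_mean_spread(2), batch_mean_spread(128)
assert small > 3 * large   # batch statistics are far noisier at batch size 2

# Group Norm computes statistics within each example, so removing the
# other rows of the batch leaves an example's output unchanged.
gn = nn.GroupNorm(num_groups=4, num_channels=16)
x = torch.randn(2, 16)
with torch.no_grad():
    full = gn(x)[:1]
    alone = gn(x[:1])
assert torch.allclose(full, alone, atol=1e-5)
```

The spread of a batch mean shrinks roughly as one over the square root of the batch size, which is why BN's normalization (and its running estimates) become unreliable with only a handful of examples per batch.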
6. Retrieval Check

Without looking: write the four BN equations (mean, variance, normalize, scale-shift), name what $\gamma$ and $\beta$ are for, and explain in one sentence exactly what differs at inference time. Then compare against Stage 3.
