Gradient Descent (SGD, Momentum, Adam)
SGD, momentum, and Adam explained — the update rules, why mini-batching wins, Adam's bias correction, and when plain SGD generalizes better — with from-scratch implementations.
Gradient descent is how neural networks learn: iteratively nudge the parameters in the direction that most reduces the loss, where "direction" is the negative gradient. Pure (batch) gradient descent uses the whole dataset per step — accurate but slow and memory-hungry. Stochastic gradient descent (SGD) uses one example, and mini-batch SGD uses a small batch — the universal default, because it trades a little gradient noise for huge speed and good generalization.
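Here is a minimal sketch of the three variants as one loop that differs only in batch size. It assumes a hypothetical 1-D linear-regression problem in pure Python. The data, the true slope of 3.0, the learning rate, and the step count are illustrative choices, not values from any particular library or paper.

```python
import random

# Hypothetical toy data: y ≈ 3.0 * x plus small Gaussian noise.
random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(1000)]
ys = [3.0 * x + random.gauss(0.0, 0.1) for x in xs]


def grad_mse(w, idx):
    """Gradient of mean((w*x - y)^2) w.r.t. w, averaged over the examples in idx."""
    return sum(2.0 * (w * xs[i] - ys[i]) * xs[i] for i in idx) / len(idx)


def sgd(w, batch_size, lr=0.1, steps=200):
    """batch_size=1 is 'stochastic', batch_size=len(xs) is full-batch, anything between is mini-batch."""
    n = len(xs)
    for _ in range(steps):
        idx = random.sample(range(n), batch_size)  # draw a batch without replacement
        w -= lr * grad_mse(w, idx)                 # same update rule for every variant
    return w


for b in (1, 32, len(xs)):
    print(b, round(sgd(0.0, b), 3))
```

All three runs should land near the true slope. The single-example run is the noisiest. The full-batch run pays for 1,000 per-example gradients on every step to get its slightly cleaner answer.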
On top of plain SGD sit two ideas every interviewer expects you to know: momentum (accumulate a velocity so you power through flat regions and damp oscillations) and adaptive methods like Adam (a per-parameter learning rate plus momentum). The frame to hold: they all share the same skeleton — estimate the gradient on a batch, then take a step — and differ only in how they shape the step from the gradient history.
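The shared skeleton can be sketched as a single loop that takes a pluggable update rule. The `Rule` interface, the `state` dict, and the function names below are hypothetical illustrations, not any library's API. Plain SGD keeps no history, while momentum stores a velocity in `state`.

```python
from typing import Callable, Dict, List

Vec = List[float]
# Hypothetical interface: a rule maps (params, grads, state) -> new params and
# may keep its gradient history in the mutable `state` dict.
Rule = Callable[[Vec, Vec, Dict], Vec]


def sgd_rule(lr: float) -> Rule:
    def step(p: Vec, g: Vec, state: Dict) -> Vec:
        return [pi - lr * gi for pi, gi in zip(p, g)]  # uses no history at all
    return step


def momentum_rule(lr: float, beta: float = 0.9) -> Rule:
    def step(p: Vec, g: Vec, state: Dict) -> Vec:
        v = state.setdefault("v", [0.0] * len(p))
        for i, gi in enumerate(g):
            v[i] = beta * v[i] + gi  # history: a running velocity per parameter
        return [pi - lr * vi for pi, vi in zip(p, v)]
    return step


def optimize(grad_fn: Callable[[Vec], Vec], p: Vec, rule: Rule, steps: int) -> Vec:
    state: Dict = {}
    for _ in range(steps):
        g = grad_fn(p)         # 1) estimate the gradient (exact here, for simplicity)
        p = rule(p, g, state)  # 2) shape a step from the gradient and its history
    return p


def grad(p: Vec) -> Vec:
    """Gradient of the toy loss L(p) = sum(p_i^2)."""
    return [2.0 * pi for pi in p]


print(optimize(grad, [1.0, -2.0], sgd_rule(0.1), 200))
print(optimize(grad, [1.0, -2.0], momentum_rule(0.1), 200))
```

Swapping optimizers only swaps `rule`. The loop, and the way the gradient is estimated, stay the same.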
Imagine a ball rolling on the loss surface. Plain SGD takes a fixed step straight downhill at each point — it zig-zags across narrow valleys and crawls across plateaus. Momentum gives the ball mass: it builds up velocity in consistent directions and cancels back-and-forth oscillation. Adam additionally gives each parameter its own step size, scaled down for parameters with large, noisy gradients and up for those with small, steady ones.
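The narrow-valley picture can be checked on a hypothetical quadratic with curvature 1 along x and 100 along y. With plain gradient descent at learning rate 0.015, y is multiplied by 1 − 0.015·100 = −0.5 each step, so it flips sign every step (the zig-zag). Meanwhile, x shrinks by only 0.985 per step (the crawl). The sketch below counts the steps needed to push the loss under a tolerance. The surface, learning rate, and tolerance are illustrative assumptions.

```python
from typing import Tuple


# Hypothetical ill-conditioned "narrow valley": curvature 1 along x, 100 along y.
def loss(x: float, y: float) -> float:
    return 0.5 * (x * x + 100.0 * y * y)


def grad(x: float, y: float) -> Tuple[float, float]:
    return x, 100.0 * y


def steps_to_converge(lr: float, beta: float, tol: float = 1e-6, max_steps: int = 10_000) -> int:
    """Heavy-ball momentum (v = beta*v + g; p -= lr*v); beta = 0 is plain gradient descent."""
    x, y = 1.0, 1.0
    vx = vy = 0.0
    for t in range(1, max_steps + 1):
        gx, gy = grad(x, y)
        vx = beta * vx + gx
        vy = beta * vy + gy
        x -= lr * vx
        y -= lr * vy
        if loss(x, y) < tol:
            return t
    return max_steps


# lr = 0.015 keeps plain GD stable along y (stability needs lr < 2/100).
print("plain GD :", steps_to_converge(lr=0.015, beta=0.0))
print("momentum :", steps_to_converge(lr=0.015, beta=0.9))
```

Here momentum should need well under half as many steps. It shrinks both directions by about √0.9 ≈ 0.95 per step, while plain GD is stuck at 0.985 along the flat x direction.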
```mermaid
flowchart LR
L["Loss L(θ)"] --> G["Gradient ∇L on a mini-batch"]
G --> M["Momentum: smooth with past gradients"]
M --> A["Adam: per-parameter adaptive step"]
A --> U["Update θ ← θ − step"]
U --> L
```
Mini-batching is the practical sweet spot: batches of 32–1024 give a low-variance gradient estimate, use hardware efficiently, and the residual noise actually helps escape sharp minima.
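To see why larger batches give a lower-variance estimate, the sketch below draws repeated mini-batches from a hypothetical pool of per-example gradients for one scalar parameter (mean 1.0, std 2.0; these numbers are illustrative assumptions). For independent draws, the variance should fall roughly as σ²/B. The standard deviation therefore falls only as 1/√B. Going from 32 to 1024 costs 32× more compute per step but cuts the noise's standard deviation by only about 5.7×.

```python
import random
import statistics

random.seed(1)
# Hypothetical per-example gradients; in practice these would come from the model.
per_example = [random.gauss(1.0, 2.0) for _ in range(10_000)]


def minibatch_grad(batch_size: int) -> float:
    """Average gradient over a random mini-batch sampled without replacement."""
    batch = random.sample(per_example, batch_size)
    return sum(batch) / batch_size


def estimate_variance(batch_size: int, trials: int = 500) -> float:
    """Empirical variance of the mini-batch gradient across repeated draws."""
    return statistics.variance([minibatch_grad(batch_size) for _ in range(trials)])


for b in (1, 32, 256, 1024):
    print(b, round(estimate_variance(b), 5))
```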
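SGD with learning rate $\eta$ uses a mini-batch gradient $g_t$, the average of the per-example loss gradients $\nabla_\theta \ell_i$ over a batch $B$:

$$g_t = \frac{1}{|B|} \sum_{i \in B} \nabla_\theta \ell_i(\theta_t), \qquad \theta_{t+1} = \theta_t - \eta\, g_t$$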
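As a quick numeric check of this rule, take a hypothetical one-parameter loss $L(\theta) = \theta^2$. Its gradient is exactly $g = 2\theta$, so there is no batch noise. Use $\eta = 0.1$ and $\theta_0 = 2$:

$$g_0 = 2\theta_0 = 4, \qquad \theta_1 = \theta_0 - \eta\, g_0 = 2 - 0.1 \cdot 4 = 1.6$$

Each step multiplies $\theta$ by $1 - 2\eta = 0.8$, so the iterate shrinks geometrically toward the minimum at $0$.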
SGD with momentum (coefficient $\beta$) maintains a velocity $v_t$, an exponentially weighted sum of past gradients:

$$v_t = \beta\, v_{t-1} + g_t, \qquad \theta_{t+1} = \theta_t - \eta\, v_t$$
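Unrolling the recursion from $v_0 = 0$ gives $v_t = \sum_k \beta^k g_{t-k}$. The sketch below checks this numerically and shows the steady state under a constant gradient. The gradient values are arbitrary illustrations. The convention $v_t = \beta v_{t-1} + g_t$ matches `sgd_momentum_step` in this section.

```python
def velocity_recursive(grads, beta):
    """v_t = beta * v_{t-1} + g_t, starting from v_0 = 0."""
    v = 0.0
    for g in grads:
        v = beta * v + g
    return v


def velocity_unrolled(grads, beta):
    """Same quantity written as an exponentially weighted sum: sum_k beta^k * g_{t-k}."""
    t = len(grads)
    return sum(beta ** (t - 1 - i) * g for i, g in enumerate(grads))


grads = [0.5, -1.0, 2.0, 1.5, 1.0]  # arbitrary example gradients
print(velocity_recursive(grads, 0.9), velocity_unrolled(grads, 0.9))

# With a constant gradient g, v_t approaches g / (1 - beta): beta = 0.9 gives ~10x the step.
print(velocity_recursive([1.0] * 200, 0.9))
```

Some texts write the velocity as an EMA, $v_t = \beta v_{t-1} + (1-\beta) g_t$. With constant $\beta$ that is just $(1-\beta)$ times the velocity above, so the two forms match once $\eta$ is rescaled by $1/(1-\beta)$.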
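Adam combines a first moment $m_t$ (the mean, like momentum) and a second moment $v_t$ (the uncentered variance; here $v_t$ is not the momentum velocity) with bias correction. Squares, square roots, and division are elementwise:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$

$$\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_{t+1} = \theta_t - \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$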
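A small sketch makes both the per-parameter scaling and the bias correction concrete. It computes only Adam's very first update ($t = 1$, $m_0 = v_0 = 0$) for a single scalar parameter, with and without correction. The function name and the scalar setting are illustrative assumptions. With correction, $\hat m_1 = g$ and $\hat v_1 = g^2$, so the first step is about $\eta \cdot \mathrm{sign}(g)$ whatever the gradient's scale.

```python
import math


def adam_first_step(g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, bias_correction=True):
    """Size of Adam's first update (t = 1, m_0 = v_0 = 0) for a scalar gradient g."""
    m = (1 - b1) * g        # m_1, shrunk toward zero by the zero init
    v = (1 - b2) * g * g    # v_1, shrunk even more
    if bias_correction:
        m /= 1 - b1         # m_hat_1 = g
        v /= 1 - b2         # v_hat_1 = g^2
    return lr * m / (math.sqrt(v) + eps)


for g in (1e-3, 1.0, 1e3):  # gradients spanning six orders of magnitude
    print(g, adam_first_step(g), adam_first_step(g, bias_correction=False))
```

Without correction, the default betas give a first step of $0.1/\sqrt{0.001} \approx 3.16$ times $\eta$. The biased-low $v$ in the denominator outweighs the biased-low $m$ in the numerator.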
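The bias correction matters early in training, when $m_t$ and $v_t$ are still biased toward their zero initialization ($m_0 = v_0 = 0$). Typical defaults: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. AdamW decouples weight decay from the gradient step: it shrinks the weights directly, in proportion to $\lambda\theta_t$, instead of adding $\lambda\theta_t$ to $g_t$. It is the standard optimizer for training Transformers.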
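The difference between folding L2 into Adam's gradient and decoupled decay shows up in a minimal scalar sketch. The function `adam_like_step` and its `decoupled` flag are hypothetical. The decoupled ordering (shrink the weight by lr·wd·p, then take the Adam step) is an assumption modeled on PyTorch's AdamW. The demo uses a weight of 10 with zero loss gradient, on the first step.

```python
import math


def adam_like_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8,
                   weight_decay=0.0, decoupled=False):
    """One scalar Adam step. decoupled=False folds L2 into the gradient ("Adam + L2");
    decoupled=True decays the weight directly (AdamW-style, assumed ordering)."""
    if weight_decay and not decoupled:
        g = g + weight_decay * p             # L2 term gets normalized by sqrt(v_hat) below
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    if weight_decay and decoupled:
        p = p - lr * weight_decay * p        # decay is NOT rescaled by the adaptive denominator
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


for wd in (0.01, 0.1):
    p_l2, _, _ = adam_like_step(10.0, 0.0, 0.0, 0.0, 1, weight_decay=wd)
    p_w, _, _ = adam_like_step(10.0, 0.0, 0.0, 0.0, 1, weight_decay=wd, decoupled=True)
    print(f"wd={wd}: Adam+L2 shrinks p by {10.0 - p_l2:.2e}, AdamW by {10.0 - p_w:.2e}")
```

With L2 in the gradient, Adam's normalization turns the decay into a roughly lr-sized step no matter how large wd is. Decoupled decay scales with wd as intended. Over a longer run, weights with a large gradient history get their L2 term divided by a large $\sqrt{\hat v_t}$ and are regularized less. Decoupling removes that interaction.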
```python
import torch
from torch import Tensor


@torch.no_grad()  # in-place updates on a leaf tensor that requires grad must bypass autograd
def sgd_momentum_step(p: Tensor, grad: Tensor, v: Tensor, lr=0.01, beta=0.9) -> Tensor:
    v.mul_(beta).add_(grad)  # v = beta*v + grad
    p.add_(v, alpha=-lr)     # p = p - lr*v
    return v


class Adam:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr, (self.b1, self.b2), self.eps = lr, betas, eps
        self.m = [torch.zeros_like(p) for p in self.params]  # first moment (mean)
        self.v = [torch.zeros_like(p) for p in self.params]  # second moment (uncentered variance)
        self.t = 0  # step counter, used by bias correction

    @torch.no_grad()
    def step(self):
        self.t += 1
        for i, p in enumerate(self.params):
            if p.grad is None:  # parameter did not take part in this loss
                continue
            g = p.grad
            self.m[i].mul_(self.b1).add_(g, alpha=1 - self.b1)         # m = b1*m + (1-b1)*g
            self.v[i].mul_(self.b2).addcmul_(g, g, value=1 - self.b2)  # v = b2*v + (1-b2)*g^2
            m_hat = self.m[i] / (1 - self.b1**self.t)  # bias correction
            v_hat = self.v[i] / (1 - self.b2**self.t)
            p.addcdiv_(m_hat, v_hat.sqrt().add_(self.eps), value=-self.lr)  # p -= lr*m_hat/(sqrt(v_hat)+eps)


w = torch.randn(4, requires_grad=True)
opt = Adam([w])
(w.pow(2).sum()).backward()  # toy loss = ||w||²
opt.step()
```

- Conceptual: Why use mini-batch SGD instead of full-batch gradient descent? (Far cheaper per step, uses hardware efficiently, and the gradient noise improves generalization and helps escape sharp minima.)
- Implementation: What does momentum add to plain SGD? (A velocity term — an EMA of past gradients — that accelerates consistent directions and damps oscillation in narrow valleys.)
- Applied: What two statistics does Adam track, and what does each buy you? (First moment = momentum; second moment = per-parameter adaptive scaling of the step.)
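- Systems-level: Why is bias correction needed in Adam? (m and v start at zero, so early estimates are biased toward zero; dividing by 1−βᵗ removes that bias so early steps have the intended scale. Uncorrected, the shrunken m alone would make steps too small, but with the default betas the shrunken v in the denominator dominates, so first steps would come out about 3× too large.)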
- Failure modes: When might SGD-with-momentum generalize better than Adam? (In large-scale vision/CNN training, tuned SGD+momentum often finds flatter minima with better test accuracy; Adam can overfit or converge to sharper minima.)
Without looking: write the update rule for SGD, SGD+momentum, and Adam. State what β₁ and β₂ control in Adam and why bias correction exists. Then name one case where SGD beats Adam. Compare against Stage 3.