Attention Mechanisms
How scaled dot-product and multi-head attention work — the soft key-value lookup at the heart of every Transformer — with the math, runnable PyTorch, and calibrated interview questions.
Attention is the mechanism that lets a model decide, for each position in a sequence, which other positions matter — and by how much. It replaced the recurrent bottleneck of RNNs/LSTMs, where information from distant tokens had to survive being squeezed through a single fixed-size hidden state, step after step. Attention removes that bottleneck: every position can look directly at every other position in one operation.
This is the core primitive of the Transformer, and therefore of essentially every modern LLM. If you understand scaled dot-product attention and multi-head attention, you understand the load-bearing 20% of the architecture that interviewers probe hardest. The mental frame to hold: attention is a differentiable, soft dictionary lookup — a query retrieves a weighted blend of values, where the weights come from how well the query matches each key.
Think of a soft key-value store. Each token emits three vectors:
- a query (q): "what am I looking for?"
- a key (k): "what do I offer?"
- a value (v): "what I'll hand over if you attend to me."
For a given query, you score it against every key (dot product = similarity), turn those scores into a probability distribution with softmax, and return the weighted sum of the values. A hard dictionary returns exactly one value for an exact key match; attention returns a blend, weighted by match strength — which is what makes it differentiable and trainable.
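The lookup described above can be sketched for a single query in plain Python. This is a minimal illustration, not the batched implementation: the helper names `softmax` and `soft_lookup` and the toy keys and values are made up for the example, and the scaling factor is left out so that only the score, softmax, and blend steps are visible.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Turn raw scores into positive weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_lookup(query, keys, values):
    """Return the attention weights and the weighted blend of values for one query."""
    # Dot product of the query with every key = similarity score.
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    blend = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, blend


# Toy data (hypothetical): three keys, each paired with a 2-d value.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, 0.0]  # aligned with the first key, orthogonal to the second

weights, blend = soft_lookup(query, keys, values)
print([round(w, 3) for w in weights])
print([round(b, 3) for b in blend])
```

The first and third keys score equally against this query, so they receive equal weight, while the orthogonal second key still contributes a small amount. That nonzero leftover is the difference between a soft lookup and a hard dictionary hit.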
flowchart LR
X["Input tokens"] --> Q["Queries Q"]
X --> K["Keys K"]
X --> V["Values V"]
Q --> S["scores = QKᵀ / √dₖ"]
K --> S
S --> A["softmax to weights"]
A --> O["output = weights · V"]
V --> O
"Self-attention" just means Q, K, and V are all projected from the same sequence — each token attends to the whole sequence, including itself.
Scaled dot-product attention, for queries $Q$, keys $K$, and values $V$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

The softmax is taken row-wise, so each query's attention weights sum to 1. Why divide by $\sqrt{d_k}$? If the components of $q$ and $k$ are independent with mean 0 and variance 1, their dot product has variance $d_k$. For large $d_k$, the raw scores grow large in magnitude, pushing softmax into regions where its gradient is almost zero (it saturates toward a one-hot vector). Scaling by $1/\sqrt{d_k}$ normalizes the variance back to ~1, keeping gradients healthy.
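The variance argument can be checked empirically. The sketch below assumes exactly the conditions stated above (independent components with mean 0 and variance 1) and uses only the standard library; the function name `dot_product_variance`, the trial count, and the seed are arbitrary choices for illustration.

```python
import random
import statistics


def dot_product_variance(d_k: int, trials: int, rng: random.Random) -> float:
    """Empirical variance of q . k when all components are i.i.d. N(0, 1)."""
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(dots)


rng = random.Random(0)  # fixed seed so the run is repeatable
for d_k in (4, 64):
    raw = dot_product_variance(d_k, trials=4000, rng=rng)
    scaled = raw / d_k  # Var(x / sqrt(d_k)) = Var(x) / d_k
    print(f"d_k={d_k:3d}  raw variance ~ {raw:6.2f}  scaled variance ~ {scaled:4.2f}")
```

The raw variance should land near each value of `d_k`, and the scaled variance near 1, up to sampling noise. Trained queries and keys are not guaranteed to satisfy the independence assumption, so this checks the motivating argument rather than the behavior of a real model.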
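Multi-head attention runs $h$ attention operations in parallel on lower-dimensional projections, then concatenates:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

Each head can specialize (one tracks syntax, another coreference, etc.). With $d_k = d_v = d_{\text{model}} / h$, multi-head costs roughly the same as single-head of full width.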
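The cost claim comes down to simple arithmetic, sketched here under the assumption that each head has width `d_model // n_heads`. The function names are hypothetical, biases are ignored, and only the projection weights and the score-matrix multiply-adds are counted, so this is a rough accounting rather than a full FLOP model.

```python
def projection_params(d_model: int, n_heads: int) -> int:
    """Weights in the Q, K, V and output projections (biases ignored)."""
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads            # per-head width
    per_head = 3 * d_model * d_k        # one Q, one K, one V projection per head
    output = (n_heads * d_k) * d_model  # output projection on the concatenated heads
    return n_heads * per_head + output


def score_multiply_adds(seq_len: int, d_model: int, n_heads: int) -> int:
    """Multiply-adds needed to form the query-key score matrices across all heads."""
    d_k = d_model // n_heads
    return n_heads * seq_len * seq_len * d_k


for heads in (1, 8):
    print(heads, projection_params(512, heads), score_multiply_adds(10, 512, heads))
```

Both counts are the same for 1 head and 8 heads, because `n_heads * d_k` always multiplies back to `d_model`. Splitting into heads changes how the width is partitioned, not how much of it there is.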
```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


def scaled_dot_product_attention(
    q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None
) -> Tensor:
    # q,k,v: (batch, heads, seq, d_k)
    d_k = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / d_k**0.5  # (b, h, seq, seq)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))  # causal / padding
    weights = F.softmax(scores, dim=-1)
    return weights @ v  # (b, h, seq, d_k)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # fused Q,K,V projection
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        b, seq, _ = x.shape
        qkv = self.qkv(x).view(b, seq, 3, self.n_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (b, heads, seq, d_k)
        out = scaled_dot_product_attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(b, seq, -1)  # recombine heads
        return self.out(out)


# quick shape test
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)
assert mha(x).shape == (2, 10, 512)
```

- Conceptual: Why does attention scale scores by $1/\sqrt{d_k}$, and what fails if you don't? (Variance of the dot product grows with $d_k$; unscaled scores saturate softmax and kill gradients.)
- Implementation: How do you implement a causal mask, and where in the computation does it go? (Set future positions to $-\infty$ in the score matrix before softmax, so their weights become 0.)
- Applied: What's the time and memory complexity of self-attention in sequence length $n$, and why is that a problem for long contexts? ($O(n^2)$ time and memory from the score matrix, which is the motivation for FlashAttention, sparse, and linear-attention variants.)
- Systems-level: What does the KV cache do at inference, and why does it make autoregressive decoding far cheaper? (Caches past keys/values so each new token attends to stored K/V instead of recomputing them — turns per-token cost from quadratic-recompute into linear.)
- Failure modes: Why use multiple heads instead of one wide head? (Separate subspaces let heads attend to different relations simultaneously; a single softmax can only express one weighting per query.)
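The causal-mask and KV-cache answers above are two views of the same computation, which a small plain-Python sketch can make concrete. Everything here is an illustrative assumption: a single head, toy per-token vectors standing in for learned projections, and hypothetical names (`attend`, `k_cache`, `v_cache`). It compares masked full attention against step-by-step decoding with cached keys and values.

```python
import math


def softmax(scores):
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values, mask=None):
    """Scaled dot-product attention for one query; mask[j] == 0 hides key j."""
    d_k = len(query)
    scores = [sum(a * b for a, b in zip(query, key)) / math.sqrt(d_k) for key in keys]
    if mask is not None:
        # Masking happens on the scores, before softmax.
        scores = [s if m else float("-inf") for s, m in zip(scores, mask)]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


# Toy per-token vectors (made up); in a real model they come from learned projections.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
n = len(Q)

# Path 1: full attention with a causal mask (position t may see positions 0..t).
causal = [[1 if j <= t else 0 for j in range(n)] for t in range(n)]
masked = [attend(Q[t], K, V, causal[t]) for t in range(n)]

# Path 2: decode one token at a time, appending each step's key and value to a cache.
k_cache, v_cache, cached = [], [], []
for t in range(n):
    k_cache.append(K[t])
    v_cache.append(V[t])
    # No mask needed here: future keys are simply not in the cache yet.
    cached.append(attend(Q[t], k_cache, v_cache))

same = all(
    math.isclose(a, b, abs_tol=1e-12)
    for row_m, row_c in zip(masked, cached)
    for a, b in zip(row_m, row_c)
)
print(same)
```

The two paths agree because a masked score of negative infinity contributes a weight of exactly 0, which has the same effect as the key never having been present. That equivalence is why reusing cached keys and values during decoding is valid: each step computes one new row of attention against stored K/V instead of rebuilding the whole score matrix.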
Close this page. Derive scaled dot-product attention from scratch: write the formula, state the shapes of $Q$, $K$, $V$, explain the $\sqrt{d_k}$ term, and describe in one sentence how multi-head differs. Then check what you missed against Stage 3.