
Deep Learning

Attention Mechanisms

How scaled dot-product and multi-head attention work — the soft key-value lookup at the heart of every Transformer — with the math, runnable PyTorch, and calibrated interview questions.

9 min read · Reviewed May 2026

1. Big Picture

Attention is the mechanism that lets a model decide, for each position in a sequence, which other positions matter — and by how much. It replaced the recurrent bottleneck of RNNs/LSTMs, where information from distant tokens had to survive being squeezed through a single fixed-size hidden state, step after step. Attention removes that bottleneck: every position can look directly at every other position in one operation.

This is the core primitive of the Transformer, and therefore of essentially every modern LLM. If you understand scaled dot-product attention and multi-head attention, you understand the load-bearing 20% of the architecture that interviewers probe hardest. The mental frame to hold: attention is a differentiable, soft dictionary lookup — a query retrieves a weighted blend of values, where the weights come from how well the query matches each key.

2. Intuition + Visual

Think of a soft key-value store. Each token emits three vectors:

  • a query (q) — "what am I looking for?"
  • a key (k) — "what do I offer?"
  • a value (v) — "what I'll hand over if you attend to me."

For a given query, you score it against every key (dot product = similarity), turn those scores into a probability distribution with softmax, and return the weighted sum of the values. A hard dictionary returns exactly one value for an exact key match; attention returns a blend, weighted by match strength — which is what makes it differentiable and trainable.
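The lookup described above can be sketched in plain Python for a single query. This is a minimal illustration, not a reference implementation: the helper names `softmax` and `soft_lookup` and the toy keys and values are made up for the example, and the $\sqrt{d_k}$ scaling is left out to keep the focus on the score, softmax, blend sequence.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_lookup(query, keys, values):
    """Return (weights, blended value) for one query against all keys."""
    # Dot product of the query with each key acts as the similarity score.
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    # Weighted sum of the value vectors, one output component at a time.
    dim = len(values[0])
    blended = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, blended


# Hypothetical toy data: three keys, each paired with a 2-d value.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, 0.0]  # scores 2, 0, 2: matches keys 0 and 2 equally well

weights, blended = soft_lookup(query, keys, values)
print(weights)
print(blended)
```

Keys 0 and 2 receive equal weight and key 1 receives much less, so the result is a blend dominated by the first and third values rather than a single exact match.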

flowchart LR
    X["Input tokens"] --> Q["Queries Q"]
    X --> K["Keys K"]
    X --> V["Values V"]
    Q --> S["scores = QKᵀ / √dₖ"]
    K --> S
    S --> A["softmax to weights"]
    A --> O["output = weights · V"]
    V --> O

"Self-attention" just means Q, K, and V are all projected from the same sequence — each token attends to the whole sequence, including itself.

3. The Math

Scaled dot-product attention, for queries $Q \in \mathbb{R}^{n \times d_k}$, keys $K \in \mathbb{R}^{m \times d_k}$, and values $V \in \mathbb{R}^{m \times d_v}$:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

The softmax is taken row-wise, so each query's attention weights sum to 1. Why divide by $\sqrt{d_k}$? If the components of $q$ and $k$ are independent with mean 0 and variance 1, each product $q_i k_i$ has mean 0 and variance 1, and because the terms are independent their variances add. The dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ therefore has variance $d_k$. For large $d_k$, the raw scores grow large in magnitude, pushing softmax into regions where its gradient is almost zero (it saturates toward a one-hot vector). Dividing by $\sqrt{d_k}$ brings the variance back to about 1, keeping gradients healthy.
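Both halves of that argument can be checked numerically. The sketch below assumes standard-normal components, as in the text. It estimates the variance of $q \cdot k$ for a few values of $d_k$, then compares softmax with and without the scaling on a hand-picked score vector. The function names, sample sizes, and the example scores are illustrative choices, not part of the original material.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_variance(d_k: int, trials: int, rng: random.Random) -> float:
    """Sample variance of q.k when every component is standard normal."""
    dots = [
        sum(rng.gauss(0, 1) * rng.gauss(0, 1) for _ in range(d_k))
        for _ in range(trials)
    ]
    mean = sum(dots) / trials
    return sum((x - mean) ** 2 for x in dots) / (trials - 1)


rng = random.Random(0)
variances = {d_k: dot_variance(d_k, trials=2000, rng=rng) for d_k in (4, 16, 64)}
for d_k, var in variances.items():
    print(f"d_k={d_k:3d}  sample variance of q.k = {var:.1f}")

# Hypothetical raw scores with the spread you might see at d_k = 64
# (standard deviation around sqrt(64) = 8).
d_k = 64
raw_scores = [16.0, 8.0, 0.0, -8.0]
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

unscaled_weights = softmax(raw_scores)
scaled_weights = softmax(scaled_scores)
print("max weight, unscaled:", max(unscaled_weights))
print("max weight, scaled:  ", max(scaled_weights))
```

The sample variances should land close to $d_k$, and the unscaled softmax puts nearly all of its mass on one key while the scaled version keeps a usable spread.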

Multi-head attention runs $h$ attention operations in parallel on lower-dimensional projections, then concatenates:

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O, \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

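Heads can specialize (for example, one may track syntax while another tracks coreference). With $d_k = d_v = d_{\text{model}} / h$, multi-head attention costs roughly the same as a single head of full width, because the $h$ narrow per-head projections together are the same size as one full-width projection.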
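The cost claim is easy to verify with arithmetic. The sketch below counts projection weights and score-matrix multiplications for different head counts, assuming $d_k = d_v = d_{\text{model}} / h$ and ignoring biases. The function names and the sizes (512, 128) are illustrative.

```python
def projection_weights(d_model: int, n_heads: int) -> int:
    """Weight count (biases ignored) for per-head Q, K, V projections plus W^O."""
    d_k = d_model // n_heads             # per-head width, d_k = d_v = d_model / h
    per_head_qkv = 3 * d_model * d_k     # W_i^Q, W_i^K, W_i^V: each d_model x d_k
    output = (n_heads * d_k) * d_model   # W^O maps the concatenated heads back
    return n_heads * per_head_qkv + output


def score_multiplies(seq_len: int, d_model: int, n_heads: int) -> int:
    """Multiplications needed to form QK^T across all heads."""
    d_k = d_model // n_heads
    return n_heads * seq_len * seq_len * d_k


counts = {h: projection_weights(512, h) for h in (1, 8, 16)}
mults = {h: score_multiplies(128, 512, h) for h in (1, 8, 16)}
print(counts)
print(mults)
```

Every head count gives the same totals: $4\, d_{\text{model}}^2$ weights and $n^2 d_{\text{model}}$ multiplications for the scores. Splitting into heads changes how the width is used, not how much of it there is.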

4. Implementation
```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


def scaled_dot_product_attention(
    q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None
) -> Tensor:
    # q, k, v: (batch, heads, seq, d_k)
    d_k = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / d_k**0.5      # (b, h, seq, seq)
    if mask is not None:
        # mask == 0 marks positions to hide (causal / padding)
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)                 # each row sums to 1
    return weights @ v                                  # (b, h, seq, d_k)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)      # fused Q,K,V projection
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        b, seq, _ = x.shape
        qkv = self.qkv(x).view(b, seq, 3, self.n_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)            # each: (b, heads, seq, d_k)
        out = scaled_dot_product_attention(q, k, v, mask)
        out = out.transpose(1, 2).reshape(b, seq, -1)   # recombine heads
        return self.out(out)


# quick shape test
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)
assert mha(x).shape == (2, 10, 512)
```
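The `mask` argument above treats 0 as "hide this position". As a sketch of what a causal mask looks like under that convention, here is a dependency-free version on a 3×3 score matrix. The helper names and the score values are made up for illustration.

```python
import math


def causal_mask(seq_len: int) -> list[list[int]]:
    """1 where query i may attend to key j (j <= i), else 0."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax_rows(scores, mask):
    """Row-wise softmax after setting masked-out scores to -inf."""
    out = []
    for row, mask_row in zip(scores, mask):
        filled = [s if m else float("-inf") for s, m in zip(row, mask_row)]
        peak = max(filled)  # finite, because the diagonal is always allowed
        exps = [math.exp(s - peak) for s in filled]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Hypothetical pre-softmax scores for a 3-token sequence.
scores = [
    [0.5, 2.0, -1.0],
    [1.0, 0.0, 3.0],
    [0.2, 0.2, 0.2],
]
mask = causal_mask(3)
weights = masked_softmax_rows(scores, mask)
for row in weights:
    print([round(w, 3) for w in row])
```

Every entry above the diagonal comes out as exactly 0, and each row still sums to 1. In the PyTorch code, a lower-triangular `(seq, seq)` tensor such as `torch.tril(torch.ones(seq, seq))` should play the same role, broadcasting over the batch and head dimensions. One caveat: a row that is masked everywhere would produce NaN after softmax, so padding masks need at least one visible key per query.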
5. Interview Questions
  1. Conceptual: Why does attention scale scores by $\sqrt{d_k}$, and what fails if you don't? (Variance of the dot product grows with $d_k$; unscaled scores saturate softmax and kill gradients.)
  2. Implementation: How do you implement a causal mask, and where in the computation does it go? (Set future positions to $-\infty$ in the score matrix before softmax, so their weights become 0.)
  3. Applied: What's the time and memory complexity of self-attention in sequence length $n$, and why is that a problem for long contexts? ($O(n^2 d)$ time and $O(n^2)$ memory from the $n \times n$ score matrix, which is the motivation for FlashAttention, sparse, and linear-attention variants.)
  4. Systems-level: What does the KV cache do at inference, and why does it make autoregressive decoding far cheaper? (It stores the keys and values of past tokens, so each decoding step projects only the new token and attends over the stored K/V instead of recomputing them for the whole prefix. Per-token cost drops from quadratic in the prefix length, when everything is recomputed, to linear.)
  5. Failure modes: Why use multiple heads instead of one wide head? (Separate subspaces let heads attend to different relations simultaneously; a single softmax can only express one weighting per query.)
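The KV-cache answer in question 4 can be made concrete with a toy decoder. The sketch below is a simplified single-head illustration in plain Python. The `project` function is a hypothetical stand-in for learned projections (it only rescales the input), and `KVCache` and `decode_without_cache` are illustrative names rather than any library's API.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention for one query over the given keys/values."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(dim)]


def project(x, scale):
    """Hypothetical stand-in for a learned projection: just rescales x."""
    return [scale * xi for xi in x]


class KVCache:
    def __init__(self):
        self.keys, self.values = [], []
        self.projections = 0  # number of tokens whose K/V were computed

    def step(self, x):
        # Project only the new token, store it, then attend over everything stored.
        self.keys.append(project(x, 0.5))
        self.values.append(project(x, 2.0))
        self.projections += 1
        return attend(project(x, 1.0), self.keys, self.values)


def decode_without_cache(tokens):
    """Recompute keys and values for the whole prefix at every step."""
    outputs, projections = [], 0
    for t in range(1, len(tokens) + 1):
        prefix = tokens[:t]
        keys = [project(x, 0.5) for x in prefix]
        values = [project(x, 2.0) for x in prefix]
        projections += t
        outputs.append(attend(project(prefix[-1], 1.0), keys, values))
    return outputs, projections


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
cache = KVCache()
cached_outputs = [cache.step(x) for x in tokens]
full_outputs, full_projections = decode_without_cache(tokens)
print(cache.projections, full_projections)
```

Both paths produce the same outputs, but the cached path projects each token once (4 in total here) while the uncached path redoes the whole prefix at every step (1 + 2 + 3 + 4 = 10). That gap grows quadratically with sequence length.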
6. Retrieval Check

Close this page. Derive scaled dot-product attention from scratch: write the formula, state the shapes of $Q$, $K$, $V$, explain the $\sqrt{d_k}$ term, and describe in one sentence how multi-head differs. Then check what you missed against Stage 3.
