Attention Is All You Need (Transformer)
Authors: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin
Year: 2017 | Venue: NeurIPS
Link: arXiv:1706.03762
TL;DR
The paper introduces the Transformer, a sequence-to-sequence architecture that replaces recurrence entirely with self-attention. Every token attends to every other token in a single layer, weighted by learned compatibility scores. Multi-head attention runs several attention functions in parallel so the model captures different relational patterns (syntax, coreference, long-range agreement). Positional encodings inject ordering information because self-attention on its own has no notion of token order (it is permutation-equivariant).
Why This Paper Matters
The Transformer is the architectural backbone of virtually every modern LLM, including GPT, BERT, T5, LLaMA, and Mistral. Serving optimizations such as FlashAttention and KV caching target this attention primitive directly, and speculative decoding targets the autoregressive decoding loop built on top of it. If you understand this paper deeply, you understand the foundation of everything that followed.
Key Concepts Explained Simply
Self-Attention: The Core Idea
Imagine you're reading a sentence and for every word, you ask: "Which other words should I pay attention to in order to understand this word?"
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "What information do I give back?"
Each word creates a Q, K, and V vector. The attention score between word i and word j is the dot product of word i's query with word j's key. High dot product = high relevance.
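As a minimal sketch of the scoring step described above, the snippet below builds queries and keys for three toy tokens and computes every query-key dot product. The embeddings `X` and the identity projections `W_q` and `W_k` are made up for illustration; a trained model would learn these matrices.

```python
import numpy as np

# Toy embeddings for three tokens (hand-picked values, for illustration only).
X = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
])

# Hypothetical projection matrices; identity keeps the arithmetic easy to follow.
W_q = np.eye(2)
W_k = np.eye(2)

Q = X @ W_q  # one query vector per token
K = X @ W_k  # one key vector per token

# scores[i, j] is the dot product of token i's query with token j's key.
scores = Q @ K.T
print(scores)
```

With these values, token 0 scores 1 against itself and against token 2, and 0 against token 1, so it would attend to tokens 0 and 2 and largely ignore token 1.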
Multi-Head Attention
Instead of running attention once, the model runs it h times in parallel (typically h=8 or h=16), each with different learned projections. One head might learn syntactic relationships, another might learn coreference, another might learn positional patterns. The outputs are concatenated and projected.
Positional Encoding
Self-attention is permutation-equivariant: it has no built-in notion of word order. The paper adds sinusoidal position signals to the input embeddings:
- Even dimensions (index \(2i\)): \(\sin(pos / 10000^{2i/d_{\text{model}}})\)
- Odd dimensions (index \(2i+1\)): \(\cos(pos / 10000^{2i/d_{\text{model}}})\)
The authors hypothesized that this lets the model learn to attend by relative position, because for any fixed offset \(k\) the encoding of position \(pos + k\) is a linear function of the encoding of position \(pos\).
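The linear-offset property can be checked numerically. The sketch below is one way to do it, assuming an even `d_model`; the helper names `sinusoidal_pe` and `offset_matrix` are hypothetical and not from the paper. Each sin/cos pair shares a frequency, so shifting by `k` positions amounts to a 2×2 rotation of that pair, and the rotation does not depend on `pos`.

```python
import numpy as np

def sinusoidal_pe(pos, d_model):
    """Encoding of a single position: sin on even indices, cos on odd ones."""
    i = np.arange(d_model // 2)
    freq = 1.0 / (10000.0 ** (2 * i / d_model))
    pe = np.empty(d_model)
    pe[0::2] = np.sin(pos * freq)
    pe[1::2] = np.cos(pos * freq)
    return pe

def offset_matrix(k, d_model):
    """Block-diagonal matrix M_k such that M_k @ PE(pos) equals PE(pos + k)."""
    i = np.arange(d_model // 2)
    freq = 1.0 / (10000.0 ** (2 * i / d_model))
    M = np.zeros((d_model, d_model))
    for j, w in enumerate(freq):
        c, s = np.cos(w * k), np.sin(w * k)
        # Angle-addition identities for the (sin, cos) pair at frequency w.
        M[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[c, s], [-s, c]]
    return M

d_model, k = 8, 3
M = offset_matrix(k, d_model)  # built once, reused for every position
for pos in (0, 5, 40):
    shifted = M @ sinusoidal_pe(pos, d_model)
    assert np.allclose(shifted, sinusoidal_pe(pos + k, d_model))
```

The same matrix `M` works for every starting position, which is what "linear function of the encoding" means here.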
The Math: Explained Step by Step
Scaled Dot-Product Attention
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
\]
Breaking it down:
- \(QK^\top\): Matrix of dot products between every query and every key. Shape: \([n \times n]\) where \(n\) is sequence length. Each entry measures how much token i should attend to token j.
- \(\div \sqrt{d_k}\): Scaling factor. Without it, when \(d_k\) is large (e.g., 64), dot products grow large, pushing softmax into regions with tiny gradients. Dividing by \(\sqrt{d_k}\) keeps values in a "nice" range for softmax.
- \(\mathrm{softmax}\): Converts raw scores into a probability distribution over positions. Each row sums to 1, giving a weighted "where to look" distribution.
- \(\times V\): Weighted sum of value vectors using the attention weights. The output for each position is a mixture of all value vectors, weighted by relevance.
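Multi-Head Attention
\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O
\]
where each head is:
\[
\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)
\]
Each head has its own learned projection matrices: \(W_i^Q\) and \(W_i^K\) of shape \([d_{\text{model}} \times d_k]\), and \(W_i^V\) of shape \([d_{\text{model}} \times d_v]\). The paper sets \(d_k = d_v = d_{\text{model}}/h\). The final projection \(W^O\), of shape \([h d_v \times d_{\text{model}}]\), maps the concatenated heads back to \(d_{\text{model}}\).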
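To make the shapes concrete, here is a rough sketch that uses a separate projection matrix per head, with the paper's base-model sizes (\(d_{\text{model}} = 512\), \(h = 8\), so \(d_k = d_v = 64\)). The random weights, the sequence length of 10, and the helper names `softmax_rows` and `random_proj` are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, h = 10, 512, 8
d_k = d_v = d_model // h  # 64 per head

def softmax_rows(s):
    """Numerically stable softmax over the last axis."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def random_proj(rows, cols):
    """Stand-in for a learned projection matrix (random, untrained)."""
    return rng.standard_normal((rows, cols)) / np.sqrt(rows)

x = rng.standard_normal((n, d_model))  # self-attention: Q = K = V = x

heads = []
for _ in range(h):
    W_q, W_k, W_v = (random_proj(d_model, d) for d in (d_k, d_k, d_v))
    q, k, v = x @ W_q, x @ W_k, x @ W_v          # [n, 64] each
    weights = softmax_rows(q @ k.T / np.sqrt(d_k))  # [n, n]
    heads.append(weights @ v)                     # [n, d_v]

concat = np.concatenate(heads, axis=-1)  # [n, h * d_v] = [n, 512]
W_o = random_proj(h * d_v, d_model)      # [512, 512]
out = concat @ W_o                       # back to [n, d_model]
print(heads[0].shape, concat.shape, out.shape)
```

Implementations commonly fuse the per-head matrices into one \([d_{\text{model}} \times d_{\text{model}}]\) projection and slice it by head, which is mathematically the same computation.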
Complexity
- Time: \(O(n^2 \cdot d)\) — quadratic in sequence length
- Space: \(O(n^2)\) for the attention matrix
- This quadratic cost is the bottleneck that FlashAttention, sparse attention, and linear attention methods target
Python Implementation
```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: [seq_len_q, d_k]
    K: [seq_len_k, d_k]
    V: [seq_len_k, d_v]
    mask: [seq_len_q, seq_len_k]; 0 where attention is blocked
    Returns: (output [seq_len_q, d_v], attn_weights [seq_len_q, seq_len_k])
    """
    d_k = Q.shape[-1]
    scores = (Q @ K.T) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)
    # Stable softmax: subtract the row max before exponentiating
    scores = scores - np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores)
    attn_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    return attn_weights @ V, attn_weights

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o, n_heads, mask=None):
    """
    W_q, W_k, W_v: [d_model, d_model] (split across heads internally)
    W_o: [d_model, d_model]
    """
    d_model = Q.shape[-1]
    d_k = d_model // n_heads
    Q_proj = Q @ W_q
    K_proj = K @ W_k
    V_proj = V @ W_v
    heads = []
    for i in range(n_heads):
        # Each head works on its own d_k-wide slice of the projections
        start, end = i * d_k, (i + 1) * d_k
        head_out, _ = scaled_dot_product_attention(
            Q_proj[:, start:end],
            K_proj[:, start:end],
            V_proj[:, start:end],
            mask
        )
        heads.append(head_out)
    concat = np.concatenate(heads, axis=-1)
    return concat @ W_o

def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encoding from the paper (assumes even d_model)."""
    pe = np.zeros((seq_len, d_model))
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

def causal_mask(seq_len):
    """Lower-triangular mask for autoregressive (decoder) attention."""
    return np.tril(np.ones((seq_len, seq_len)))

def transformer_block(x, W_q, W_k, W_v, W_o, W1, b1, W2, b2,
                      gamma1, beta1, gamma2, beta2, n_heads, mask=None):
    """Single Transformer encoder block: MHA + Add&Norm + FFN + Add&Norm."""
    # Multi-head self-attention
    attn_out = multi_head_attention(x, x, x, W_q, W_k, W_v, W_o, n_heads, mask)
    x = layer_norm(x + attn_out, gamma1, beta1)
    # Position-wise FFN: Linear -> ReLU -> Linear
    ffn_out = np.maximum(0, x @ W1 + b1) @ W2 + b2
    x = layer_norm(x + ffn_out, gamma2, beta2)
    return x

def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalize each row to zero mean and unit variance, then scale and shift."""
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# --- Demo ---
if __name__ == "__main__":
    np.random.seed(42)
    seq_len, d_model, n_heads = 6, 16, 4
    x = np.random.randn(seq_len, d_model)
    x = x + positional_encoding(seq_len, d_model)
    Q = K = V = x
    output, weights = scaled_dot_product_attention(Q, K, V)
    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Attention weights (row 0):", np.round(weights[0], 3))
    print("Weights sum per row:", np.round(weights.sum(axis=-1), 6))
    # With causal mask (decoder-style)
    mask = causal_mask(seq_len)
    output_causal, weights_causal = scaled_dot_product_attention(Q, K, V, mask)
    print("\nCausal attention weights (row 3):", np.round(weights_causal[3], 3))
```
Interview Importance
This is among the most frequently asked-about papers in LLM interviews. Most core concepts in modern LLMs trace back to it.
Difficulty Level: ⭐⭐⭐ (Medium-High)
Interviewers expect you to be able to write the attention formula on a whiteboard, explain every term, and discuss practical implications.
Interview Questions & Answers
Q1: Walk through the attention formula. Why divide by \(\sqrt{d_k}\)?
Answer: If the components of \(q\) and \(k\) are independent with zero mean and unit variance, the dot product \(q \cdot k\) has variance \(d_k\). When \(d_k = 64\) the standard deviation is 8, so raw dot products of ±16 are not unusual. Scores that large push softmax into saturation, with outputs near 0 or 1 and near-zero gradients. Dividing by \(\sqrt{d_k}\) brings the variance back to approximately 1, keeping softmax in a regime where gradients are informative.
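The variance argument is easy to check empirically. The sketch below samples random query and key vectors with standard-normal components (an assumption that matches the argument, not a property of trained models) and compares the variance of raw and scaled dot products; `dot_product_stats` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)

def dot_product_stats(d_k, n_samples=20000):
    """Return (variance of q.k, variance of q.k / sqrt(d_k)) over random pairs."""
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    raw = np.sum(q * k, axis=1)  # one dot product per sampled pair
    return raw.var(), (raw / np.sqrt(d_k)).var()

for d_k in (4, 64, 256):
    raw_var, scaled_var = dot_product_stats(d_k)
    print(f"d_k={d_k:4d}  raw variance={raw_var:8.2f}  scaled variance={scaled_var:.3f}")
```

The raw variance should track \(d_k\) closely, while the scaled variance stays near 1 regardless of \(d_k\).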
Q2: What is the difference between self-attention in an encoder vs. a decoder?
Answer: In an encoder, each token attends to all tokens (bidirectional). In a decoder, each token can attend only to itself and previous tokens: the scores for future positions are set to \(-\infty\) before softmax, so their attention weights are exactly zero. This causal masking ensures the model can't "peek ahead," which is essential for autoregressive generation.
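A small sketch of the \(-\infty\) masking, using uniform raw scores so the effect of the mask is easy to read off. The sequence length of 4 and the variable names are arbitrary choices for illustration.

```python
import numpy as np

n = 4
scores = np.zeros((n, n))  # uniform raw scores, so only the mask shapes the result

# True strictly above the diagonal, i.e. where column j is a future token for row i.
future = np.triu(np.ones((n, n), dtype=bool), k=1)
masked = np.where(future, -np.inf, scores)

# Stable softmax; the diagonal is never masked, so every row has a finite max.
masked = masked - masked.max(axis=-1, keepdims=True)
exp_scores = np.exp(masked)  # exp(-inf) evaluates to exactly 0.0
weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
print(weights)
```

Row i ends up spreading its weight evenly over positions 0 through i, and every future position gets weight 0.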
Q3: Why is attention \(O(n^2)\) and what are the consequences?
Answer: The attention matrix is \(n \times n\) (every token vs. every token). For a 128K context window, that's ~16 billion entries per layer per head. This drives memory costs (storing the matrix) and compute costs (matrix multiplications). Solutions include sparse attention, sliding window attention, linear attention approximations, and FlashAttention (which is exact but IO-aware to avoid materializing the full matrix in HBM).
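The arithmetic behind that figure can be worked out directly. The sketch assumes "128K" means 128,000 tokens and that each entry is stored in 2 bytes (fp16 or bf16); both are assumptions, and `attention_matrix_bytes` is a hypothetical helper.

```python
def attention_matrix_bytes(seq_len, bytes_per_entry=2):
    """Memory for one materialized n x n attention matrix (one layer, one head)."""
    return seq_len * seq_len * bytes_per_entry

n = 128_000                 # assumed meaning of "128K"
entries = n * n             # one score per (query, key) pair
gib = attention_matrix_bytes(n) / 2**30

print(f"{entries:,} entries, about {gib:.1f} GiB at 2 bytes per entry")
```

That is roughly 30 GiB for a single head in a single layer, which is why methods that avoid materializing the full matrix matter at long context lengths. If "128K" is read as 131,072 tokens instead, the entry count rises to about 17.2 billion.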
Q4: Why use multi-head attention instead of single-head with larger dimension?
Answer: Different heads can learn to attend to different types of relationships simultaneously. Empirically, one head might specialize in local syntactic patterns while another captures long-range dependencies. A single head with the same total dimension produces only one attention distribution per position, so it has to average these patterns together. Multi-head attention also costs about the same compute: the per-head dimension is \(d_k = d_{\text{model}}/h\), so the matrix-multiply FLOPs match those of a single head with dimension \(d_{\text{model}}\).
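The compute claim can be checked by counting multiply-accumulate operations in the matrix multiplications. This is a rough sketch: `attention_macs` is a hypothetical helper, it assumes \(d_{\text{model}}\) is divisible by \(h\), and it ignores the softmax itself, whose elementwise work does grow with the number of heads.

```python
def attention_macs(n, d_model, h):
    """Multiply-accumulates in the matmuls of one self-attention layer."""
    d_k = d_model // h
    proj = 3 * h * n * d_model * d_k   # Q, K, V projections, summed over heads
    scores = h * n * n * d_k           # Q K^T for each head
    mix = h * n * n * d_k              # attention weights times V for each head
    out = n * (h * d_k) * d_model      # output projection W^O
    return proj + scores + mix + out

n, d_model = 1024, 512
for h in (1, 8, 16):
    print(h, attention_macs(n, d_model, h))

assert attention_macs(n, d_model, 8) == attention_macs(n, d_model, 1)
```

Every term contains the product `h * d_k`, which always equals `d_model`, so the head count cancels out of the total.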
Q5: How do positional encodings work and why are they needed?
Answer: Self-attention is permutation-equivariant — shuffling the input tokens and then running attention gives the same result as running attention then shuffling outputs. Without position signals, the model can't distinguish "the cat sat on the mat" from "mat the on sat cat the." Sinusoidal encodings add unique position-dependent signals. Later models replaced these with learned positional embeddings or Rotary Position Embeddings (RoPE), which encode relative positions directly into the attention computation.
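Permutation equivariance can be demonstrated in a few lines. The sketch below uses a single head with identity projections (so Q = K = V = x) to keep it short; that simplification, the random input, and the `self_attention` helper name are assumptions for illustration.

```python
import numpy as np

def self_attention(x):
    """Single-head self-attention with identity projections (Q = K = V = x)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 8))  # 5 tokens, no positional signal added
perm = rng.permutation(5)

shuffle_then_attend = self_attention(x[perm])
attend_then_shuffle = self_attention(x)[perm]

# Same vectors either way: attention alone cannot tell the two orderings apart.
assert np.allclose(shuffle_then_attend, attend_then_shuffle)
```

Adding a positional encoding to `x` before shuffling breaks this equality, which is exactly the point of the position signal.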
Q6: In the encoder-decoder architecture, how does cross-attention work?
Answer: In cross-attention, queries come from the decoder and keys/values come from the encoder output. This lets each decoder position "look at" the entire encoded input to decide what source information is relevant for the current generation step. It's the mechanism that connects the "understanding" (encoder) with the "generation" (decoder) in tasks like translation.
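A single-head sketch of cross-attention makes the asymmetry visible in the shapes. The sequence lengths, the random stand-ins for `encoder_out` and `decoder_state`, and the untrained projection matrices are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
src_len, tgt_len, d_model = 7, 4, 16

encoder_out = rng.standard_normal((src_len, d_model))    # encoded source sequence
decoder_state = rng.standard_normal((tgt_len, d_model))  # decoder hidden states

# Random stand-ins for learned projections.
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                 for _ in range(3))

Q = decoder_state @ W_q  # queries come from the decoder
K = encoder_out @ W_k    # keys come from the encoder
V = encoder_out @ W_v    # values come from the encoder

scores = Q @ K.T / np.sqrt(d_model)
scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)

context = weights @ V
print(weights.shape, context.shape)
```

The weight matrix is `[tgt_len, src_len]`: one distribution over source positions per decoder position. No causal mask is applied here, since every decoder position is allowed to see the whole source.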
Connections to Other Papers
- BERT → Uses the encoder side with bidirectional attention
- GPT-2/3 → Uses the decoder side with causal masking
- T5 → Uses the full encoder-decoder architecture
- FlashAttention → Optimizes the exact same computation for GPU memory hierarchy
- Mamba/SSMs → Alternative to attention with linear complexity
Key Takeaways for Quick Review
| Concept | Remember |
|---|---|
| Core formula | \(\mathrm{softmax}(QK^\top / \sqrt{d_k}) \cdot V\) |
| Scaling reason | Prevent softmax saturation with large \(d_k\) |
| Multi-head purpose | Parallel attention patterns (syntax, semantics, position) |
| Positional encoding | Sinusoidal; needed because self-attention has no built-in notion of token order |
| Complexity | \(O(n^2 d)\) time, \(O(n^2)\) space |
| Architecture | Encoder-decoder with 6 layers each in original paper |