BERT: Pre-training of Deep Bidirectional Transformers¶
Authors: Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
Year: 2018 | Venue: NAACL
Link: arXiv:1810.04805
TL;DR¶
BERT pre-trains a bidirectional Transformer encoder using Masked Language Modeling (MLM): randomly mask 15% of input tokens and predict them using both left and right context. A second objective, Next Sentence Prediction (NSP), teaches the model to understand inter-sentence relationships. Fine-tuning adds only a thin task-specific head on top of the pre-trained encoder, yielding strong results on classification, NER, QA, and more.
Why This Paper Matters¶
BERT demonstrated that pre-training + fine-tuning is wildly effective. Before BERT, NLP models were typically trained from scratch per task. BERT changed the paradigm: pre-train once on a large corpus, fine-tune cheaply on any downstream task. BERT-style encoders still power retrieval models, rerankers, classification, and NER in production systems today. The MLM objective is fundamentally different from GPT's causal LM, and interviewers love contrasting the two.
Key Concepts Explained Simply¶
Masked Language Modeling (MLM)¶
Think of it as a fill-in-the-blank exercise. Given "The cat [MASK] on the mat," the model must predict "sat" using context from both sides. This is what makes BERT bidirectional — unlike GPT, which can only look left.
The masking strategy has three cases (for the 15% of tokens selected):
- 80% of the time: Replace with [MASK]
- 10% of the time: Replace with a random token
- 10% of the time: Keep the original token
This prevents the model from learning a shortcut where it only pays attention to [MASK] tokens.
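A quick empirical sanity check of these fractions (a standalone sketch, not from the paper): since only 15% of tokens are selected at all, roughly 15% × 80% = 12% of all tokens end up as [MASK], ~1.5% become random tokens, and ~1.5% are left unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
counts = {"mask": 0, "random": 0, "keep": 0}

for _ in range(n):
    if rng.random() < 0.15:          # token selected for prediction
        r = rng.random()
        if r < 0.8:
            counts["mask"] += 1      # replaced with [MASK]
        elif r < 0.9:
            counts["random"] += 1    # replaced with a random token
        else:
            counts["keep"] += 1      # left unchanged

for k, v in counts.items():
    print(f"{k}: {v / n:.3f}")
# Fractions of ALL tokens come out near: mask ~0.12, random ~0.015, keep ~0.015
```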
Next Sentence Prediction (NSP)¶
Given two sentences A and B, predict whether B actually follows A in the corpus (IsNext) or is randomly sampled (NotNext). This was designed to help tasks like question answering where understanding sentence relationships matters.
Important: NSP was later shown to be unnecessary or even harmful. RoBERTa dropped it entirely and got better results.
Fine-Tuning¶
For downstream tasks, you add a simple output layer:
- Classification: Take the [CLS] token representation → linear layer → softmax
- Token-level tasks (NER): Take each token's representation → linear layer per token
- QA: Predict start and end positions of the answer span
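As a concrete illustration of the classification case, here is a minimal numpy sketch of a [CLS]-based head. The shapes and names (`W`, `b`) are illustrative only; in practice the head is randomly initialized and trained jointly with the encoder during fine-tuning.

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)        # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

# Hypothetical shapes: hidden size 768 (BERT-base), 3-way classification.
d_model, num_classes = 768, 3
rng = np.random.default_rng(0)

cls_vector = rng.standard_normal(d_model)               # [CLS] representation from the encoder
W = rng.standard_normal((d_model, num_classes)) * 0.02  # task-head weights
b = np.zeros(num_classes)

probs = softmax(cls_vector @ W + b)                     # class probabilities
print(probs)
```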
The Math — Explained Step by Step¶
MLM Loss¶
\[
\mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P_\theta(x_i \mid \tilde{\mathbf{x}})
\]
Breaking it down:
- \(\tilde{\mathbf{x}}\): The corrupted input (with masks applied)
- \(\mathcal{M}\): The set of masked positions (about 15% of tokens)
- \(P_\theta(x_i \mid \tilde{\mathbf{x}})\): The model's predicted probability for the original token at position \(i\), given the corrupted input
- The loss is computed only over masked positions — unmasked tokens don't contribute gradients to the MLM head
The key insight: the model sees the entire corrupted sequence (left and right context) when predicting each masked token. This is fundamentally different from GPT's left-to-right factorization.
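One way to make this contrast concrete is to compare the attention masks the two architectures use. A minimal sketch (not from the paper): a `True` entry at `[i, j]` means position `i` may attend to position `j`.

```python
import numpy as np

seq_len = 5

# BERT (encoder): every position may attend to every other position.
bidirectional_mask = np.ones((seq_len, seq_len), dtype=bool)

# GPT (decoder): position i may only attend to positions j <= i.
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))

print(bidirectional_mask.astype(int))
print(causal_mask.astype(int))
```

The strictly lower-triangular structure of the causal mask is exactly the left-to-right factorization; BERT's all-ones mask is what lets every masked position see both sides.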
NSP Loss¶
\[
\mathcal{L}_{\text{NSP}} = -\bigl[\, y \log P_\theta(\text{IsNext} \mid \mathbf{x}) + (1-y)\log\bigl(1 - P_\theta(\text{IsNext} \mid \mathbf{x})\bigr) \,\bigr]
\]
Standard binary cross-entropy: \(y=1\) when B follows A, \(y=0\) when B is random.
Total Pre-training Loss¶
\[
\mathcal{L} = \mathcal{L}_{\text{MLM}} + \mathcal{L}_{\text{NSP}}
\]
Python Implementation¶
```python
import numpy as np


def stable_softmax(logits):
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def mlm_mask_tokens(token_ids, vocab_size, mask_token_id,
                    mask_prob=0.15, seed=None):
    """
    Apply BERT-style masking: 80% [MASK], 10% random, 10% keep.
    Returns masked input and positions that were masked.
    (Real BERT also excludes special tokens like [CLS]/[SEP] from
    masking; omitted here for brevity.)
    """
    if seed is not None:
        np.random.seed(seed)
    masked = token_ids.copy()
    mask_positions = []
    for i in range(len(token_ids)):
        if np.random.random() < mask_prob:
            mask_positions.append(i)
            r = np.random.random()
            if r < 0.8:
                masked[i] = mask_token_id
            elif r < 0.9:
                masked[i] = np.random.randint(0, vocab_size)
            # else: keep original (10%)
    return masked, mask_positions


def mlm_loss(logits_at_masked, true_labels):
    """
    Cross-entropy loss at masked positions only.
    logits_at_masked: [num_masked, vocab_size]
    true_labels: [num_masked] — true token IDs
    """
    probs = stable_softmax(logits_at_masked)
    loss = 0.0
    for i, label in enumerate(true_labels):
        loss -= np.log(probs[i, label] + 1e-12)
    return loss / len(true_labels)


def nsp_loss(cls_logits, is_next):
    """
    Binary cross-entropy for Next Sentence Prediction.
    cls_logits: [2] — logits for [NotNext, IsNext]
    is_next: 0 (NotNext) or 1 (IsNext)
    """
    probs = stable_softmax(cls_logits)
    return -np.log(probs[is_next] + 1e-12)


def bert_embedding(token_ids, segment_ids, position_ids,
                   token_emb, segment_emb, position_emb):
    """
    BERT input = token embedding + segment embedding + position embedding.
    """
    return token_emb[token_ids] + segment_emb[segment_ids] + position_emb[position_ids]


class SimpleBERTDemo:
    """
    Minimal BERT forward pass for understanding the architecture.
    Note: the Transformer encoder layers themselves are omitted —
    embeddings feed the output heads directly, so n_heads is unused here.
    """

    def __init__(self, vocab_size=1000, d_model=64, n_heads=4, max_len=128):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.token_emb = np.random.randn(vocab_size, d_model) * 0.02
        self.segment_emb = np.random.randn(2, d_model) * 0.02
        self.position_emb = np.random.randn(max_len, d_model) * 0.02
        self.mlm_head = np.random.randn(d_model, vocab_size) * 0.02
        self.nsp_head = np.random.randn(d_model, 2) * 0.02

    def forward(self, token_ids, segment_ids, mask_positions):
        seq_len = len(token_ids)
        position_ids = np.arange(seq_len)
        x = bert_embedding(
            token_ids, segment_ids, position_ids,
            self.token_emb, self.segment_emb, self.position_emb
        )
        # MLM logits at masked positions
        mlm_logits = x[mask_positions] @ self.mlm_head
        # NSP logits from [CLS] (position 0)
        nsp_logits = x[0] @ self.nsp_head
        return mlm_logits, nsp_logits


# --- Demo ---
if __name__ == "__main__":
    np.random.seed(42)
    vocab_size = 1000
    mask_token_id = 999
    # 101 = [CLS] and 102 = [SEP] in BERT's vocabulary
    tokens = np.array([101, 45, 200, 67, 88, 102, 33, 55, 102])
    segments = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1])

    masked_tokens, mask_pos = mlm_mask_tokens(
        tokens, vocab_size, mask_token_id, seed=42
    )
    print(f"Original: {tokens}")
    print(f"Masked:   {masked_tokens}")
    print(f"Mask positions: {mask_pos}")

    model = SimpleBERTDemo(vocab_size=vocab_size)
    mlm_logits, nsp_logits = model.forward(masked_tokens, segments, mask_pos)

    true_labels = tokens[mask_pos]
    loss_mlm = mlm_loss(mlm_logits, true_labels)
    loss_nsp = nsp_loss(nsp_logits, is_next=1)

    print(f"\nMLM loss: {loss_mlm:.4f}")
    print(f"NSP loss: {loss_nsp:.4f}")
    print(f"Total loss: {loss_mlm + loss_nsp:.4f}")
```
Interview Importance¶
BERT is a top-5 most-asked paper. It's the canonical encoder-only model and the contrast point against GPT-style decoders.
Difficulty Level: ⭐⭐⭐ (Medium)¶
Interview Questions & Answers¶
Q1: Explain MLM vs. causal language modeling. What are the trade-offs?¶
Answer: MLM masks random tokens and predicts them using bidirectional context — the model sees both left and right. Causal LM (GPT) predicts each token using only left context.
Trade-offs:

- MLM produces better representations for understanding tasks (classification, retrieval) because every position has full context
- Causal LM is naturally suited for generation — you can sample token-by-token
- MLM can't easily generate text because masked positions are predicted independently of each other (the model assumes masked tokens are conditionally independent given the rest)
Q2: Why was NSP criticized, and what did RoBERTa change?¶
Answer: NSP was found to be too easy — the model could often distinguish real vs. random sentence pairs just from topic mismatch rather than learning real discourse structure. RoBERTa showed that removing NSP and training with full-document sequences improved performance. The key changes in RoBERTa: (1) drop NSP, (2) dynamic masking (re-sample masks each epoch), (3) larger batches, (4) more data, (5) longer training.
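The dynamic-masking change is easy to picture in code. A minimal sketch (the function name is my own, not from either paper): static masking fixes the masked positions once at preprocessing time, while dynamic masking re-samples them every time a sequence is seen.

```python
import numpy as np

def sample_mask_positions(seq_len, mask_prob=0.15, rng=None):
    """Pick ~15% of positions to mask; a fresh random sample each call."""
    rng = rng if rng is not None else np.random.default_rng()
    return np.flatnonzero(rng.random(seq_len) < mask_prob)

rng = np.random.default_rng(0)
seq_len = 20

# Static masking (original BERT): positions chosen once at preprocessing
# time, so every epoch trains on the same corrupted copy of the sequence.
static_positions = sample_mask_positions(seq_len, rng=rng)

# Dynamic masking (RoBERTa): positions re-sampled each epoch, so the
# model eventually predicts many different tokens of the same sequence.
for epoch in range(3):
    print(f"epoch {epoch}:", sample_mask_positions(seq_len, rng=rng))
```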
Q3: How would you use a BERT-style model in a RAG pipeline?¶
Answer: Two main roles:
- Retriever: Use BERT as a bi-encoder to embed both queries and documents into the same vector space. Find relevant documents via nearest-neighbor search (e.g., FAISS).
- Reranker: Use BERT as a cross-encoder — concatenate query + document as [CLS] query [SEP] document [SEP] and predict a relevance score. Cross-encoders are more accurate than bi-encoders but can't be pre-computed.
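To make the bi-encoder vs. cross-encoder distinction concrete, here is a sketch with random vectors standing in for BERT embeddings; `cross_encoder_score` is a toy stand-in (token overlap), not a real model.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64

# --- Bi-encoder: query and documents are embedded INDEPENDENTLY. ---
# Document vectors can be precomputed offline and stored in an ANN index.
doc_embs = rng.standard_normal((1000, d))
doc_embs /= np.linalg.norm(doc_embs, axis=1, keepdims=True)

query_emb = rng.standard_normal(d)
query_emb /= np.linalg.norm(query_emb)

scores = doc_embs @ query_emb            # one matrix-vector product at query time
top5 = np.argsort(-scores)[:5]
print("bi-encoder top-5 doc ids:", top5)

# --- Cross-encoder: each (query, doc) pair runs through the model JOINTLY. ---
def cross_encoder_score(query_tokens, doc_tokens):
    # Toy stand-in for a forward pass over "[CLS] query [SEP] doc [SEP]";
    # a real reranker returns a learned relevance logit instead.
    return float(len(set(query_tokens) & set(doc_tokens)))

# More accurate, but nothing can be precomputed, so it is typically
# applied only to the bi-encoder's top-k candidates (reranking).
print(cross_encoder_score(["cat", "mat"], ["the", "cat", "sat"]))
```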
Q4: Why does BERT use three masking strategies (80/10/10) instead of always using [MASK]?¶
Answer: If BERT always replaced selected tokens with [MASK], there would be a train-test mismatch: during fine-tuning, the model never sees [MASK] tokens, but during pre-training it learned to attend heavily to them. The 10% random replacement forces the model to maintain good representations for all tokens (not just masked ones). The 10% keep-original ensures the model doesn't learn "if a token is not [MASK], it must be correct."
Q5: Can BERT generate text? Why or why not?¶
Answer: BERT is not designed for generation. MLM assumes masked positions are conditionally independent — it predicts each mask separately. For generation, you need the probability of a full sequence, which requires either left-to-right factorization (GPT) or iterative refinement. Some work (e.g., mask-predict for machine translation) uses BERT-like models for generation by iteratively masking and predicting, but this is inefficient compared to autoregressive models.
Connections to Other Papers¶
- Transformer → BERT uses the encoder stack
- GPT-2 → Contrasting approach: decoder-only, causal LM
- RoBERTa → "BERT done right" — better training recipe
- ELECTRA → Replaced-token detection instead of MLM (more sample-efficient)
- XLNet → Permutation LM to get bidirectional context without masking
Key Takeaways for Quick Review¶
| Concept | Remember |
|---|---|
| Architecture | Transformer encoder (bidirectional attention) |
| Pre-training | MLM (15% mask → predict) + NSP (dropped by RoBERTa) |
| Masking strategy | 80% [MASK], 10% random, 10% keep |
| Fine-tuning | Add thin task head on top of pre-trained encoder |
| Best for | Classification, NER, retrieval, reranking |
| Not for | Text generation (use GPT for that) |