# XLNet: Generalized Autoregressive Pretraining for Language Understanding
Authors: Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
Year: 2019 | Venue: NeurIPS
Link: arXiv:1906.08237
## TL;DR
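XLNet combines the strengths of autoregressive modeling (GPT) and bidirectional context (BERT) using permutation language modeling. It samples a random factorization order and predicts tokens in that order, so each token is conditioned on the tokens that precede it in the order, which can lie on either side of it in the sentence. The input itself is never shuffled: the order is imposed through attention masks, while positional encodings keep the original positions. A two-stream attention mechanism separates content and query representations so that a token's own content cannot leak into its prediction.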
## Why This Paper Matters
XLNet clarifies the fundamental trade-off between causal and bidirectional factorizations. While it didn't become the default pre-training recipe (RoBERTa + scale won engineering mindshare), it introduced concepts that resurface in diffusion-based sequence models and non-autoregressive generation. It's a great paper for understanding why factorization order matters.
## Key Concepts Explained Simply
### The Problem with BERT's MLM
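BERT replaces selected tokens with [MASK] and predicts each masked token from the unmasked tokens only, independently of the other masked tokens. If you mask both "New" and "York" in "I visited [MASK] [MASK] last summer," BERT predicts "New" and "York" separately and does not model the dependence between them, even though knowing the first word is "New" makes "York" far more likely. This is called the independence assumption, and it means MLM never learns the joint distribution of the masked tokens.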
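To make the independence assumption concrete, here is a toy calculation with made-up probabilities (not from the paper): assume the two masked slots hold "New York" or "Los Angeles" with equal probability. If each slot is predicted only from the unmasked words, the best a predictor can do is the per-slot marginal, so the pair probability becomes a product of marginals; the chain rule keeps the dependence. The function names are hypothetical.

```python
from collections import defaultdict

# Hypothetical joint distribution over the two masked slots (toy numbers).
joint = {("New", "York"): 0.5, ("Los", "Angeles"): 0.5}


def marginals(joint):
    """Per-slot marginals: all that independent per-slot prediction can represent."""
    m1, m2 = defaultdict(float), defaultdict(float)
    for (a, b), p in joint.items():
        m1[a] += p
        m2[b] += p
    return dict(m1), dict(m2)


def independent_prob(pair, m1, m2):
    """Product of per-slot marginals: the MLM-style independence approximation."""
    return m1.get(pair[0], 0.0) * m2.get(pair[1], 0.0)


def autoregressive_prob(pair, joint):
    """Exact chain rule: P(first) * P(second | first)."""
    p_first = sum(p for (a, _), p in joint.items() if a == pair[0])
    if p_first == 0.0:
        return 0.0
    return p_first * (joint.get(pair, 0.0) / p_first)


m1, m2 = marginals(joint)
for pair in [("New", "York"), ("New", "Angeles")]:
    print(pair, independent_prob(pair, m1, m2), autoregressive_prob(pair, joint))
```

The independent approximation gives 0.25 to both "New York" and the impossible "New Angeles", while the chain rule gives 0.5 and 0.0.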
### Permutation Language Modeling
Instead of always predicting left to right, XLNet trains on factorization orders sampled from all possible orderings of the positions. For a 4-token sequence there are 4! = 24 orders. Under \(\pi = (3, 1, 4, 2)\), token 3 is predicted first (no context), then token 1 (seeing token 3), then token 4 (seeing tokens 3 and 1), and finally token 2 (seeing tokens 3, 1, and 4).
This way, every token eventually gets to condition on every other token, but through a proper autoregressive factorization — no independence assumption.
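The counting above can be checked directly. This sketch (the helper `contexts` is a hypothetical name) lists the context each token sees under \(\pi = (3, 1, 4, 2)\), then confirms that across all 24 orders token 2 is conditioned on every subset of the other three tokens.

```python
from itertools import permutations
from math import factorial


def contexts(perm):
    """Map each predicted position to the set of positions it conditions on,
    i.e. the positions that come earlier in the factorization order."""
    return {pos: set(perm[:t]) for t, pos in enumerate(perm)}


n = 4
positions = range(1, n + 1)  # 1-indexed, matching the text
all_perms = list(permutations(positions))
assert len(all_perms) == factorial(n)  # 4! = 24 factorization orders

# The order pi = (3, 1, 4, 2) from the text
for pos, ctx in contexts((3, 1, 4, 2)).items():
    print(f"predict token {pos} given tokens {sorted(ctx)}")

# Across all 24 orders, token 2 is conditioned on every subset of {1, 3, 4}
seen = {frozenset(contexts(p)[2]) for p in all_perms}
print(len(seen))
```

The last line prints 8, the number of subsets of \(\{1, 3, 4\}\), including the empty set and the full bidirectional context.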
### Two-Stream Attention
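A subtle problem: when the model predicts the token at position 3, it must not see that token's content (that would be cheating), but it does need to know that the target is at position 3, because the same context can precede different target positions and should lead to different predictions. A single hidden state cannot meet both needs: if it includes the token's content, the prediction is trivial; if it excludes it, later positions lose that content when they need it as context. XLNet solves this with two representations: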
- Content stream (h): Standard hidden state — knows what token is at this position
- Query stream (g): Knows the position and context from other tokens, but not the token at this position
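The visibility rules of the two streams can be written as boolean masks. This vectorized sketch (the function name `stream_masks` is hypothetical) ranks each position by the step at which it is predicted: the content stream may attend to positions ranked at or before its own, the query stream only to positions ranked strictly before.

```python
import numpy as np


def stream_masks(perm):
    """Return (content_mask, query_mask) as boolean [n, n] arrays.
    Row i = attending position, column j = attended position."""
    n = len(perm)
    rank = np.empty(n, dtype=int)
    rank[np.asarray(perm)] = np.arange(n)     # rank[pos] = step at which pos is predicted
    content = rank[None, :] <= rank[:, None]  # j revealed no later than i (includes i itself)
    query = rank[None, :] < rank[:, None]     # j revealed strictly before i (excludes i)
    return content, query


# pi = (3, 1, 4, 2) from the text, shifted to 0-indexed positions
content, query = stream_masks([2, 0, 3, 1])
print(content.astype(int))
print(query.astype(int))
```

The only difference between the two masks is the diagonal; the row of the first predicted position (index 2) is empty in the query mask.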
## The Math, Explained Step by Step
### Permutation Language Modeling Objective
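\[
\max_{\theta} \; \mathbb{E}_{\pi \sim \mathcal{Z}_n}\left[\sum_{t=1}^{n} \log p_{\theta}\left(x_{\pi(t)} \mid \mathbf{x}_{\pi(1:t-1)}\right)\right]
\]
where \(\mathcal{Z}_n\) is the set of all permutations of the \(n\) positions.
Breaking it down: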
- \(\pi\): A random permutation of positions \(\{1, 2, \ldots, n\}\)
- \(\pi(t)\): The position predicted at step \(t\) in this permutation
- \(\mathbf{x}_{\pi(1:t-1)}\): The tokens at positions \(\pi(1), \ldots, \pi(t-1)\), i.e., those that come before position \(\pi(t)\) in this permutation (not necessarily before it in the sentence)
- The expectation is over random permutations — each training step samples a different order
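Key insight: Different permutations correspond to different factorizations of the same joint distribution, and the parameters \(\theta\) are shared across all of them. Because every subset of the other positions precedes a given position in some permutation, in expectation each token is trained to be predicted from every possible context, including context on both sides.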
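A quick numerical check of the key insight, assuming we know a toy joint distribution exactly (made-up weights over three binary tokens): with exact conditionals computed from that joint, every one of the 3! orders yields the same total log-probability, \(\log p(\mathbf{x})\). XLNet's network only approximates these conditionals, which is why it trains one shared model on many orders. The helper names are hypothetical.

```python
import math
from itertools import permutations, product

# Hypothetical joint distribution over three binary tokens (made-up weights).
values = list(product([0, 1], repeat=3))
weights = [1, 2, 3, 4, 5, 6, 7, 8]
joint = {v: w / sum(weights) for v, w in zip(values, weights)}


def marginal(assign):
    """P(x_i = a for every (i, a) in assign), summing out the other positions."""
    return sum(p for v, p in joint.items() if all(v[i] == a for i, a in assign.items()))


def log_prob_under_order(x, order):
    """sum_t log p(x_order[t] | x_order[:t]) using exact conditionals of `joint`."""
    total, seen = 0.0, {}
    for pos in order:
        cond = marginal({**seen, pos: x[pos]}) / marginal(seen)
        total += math.log(cond)
        seen[pos] = x[pos]
    return total


x = (1, 0, 1)
lps = [log_prob_under_order(x, order) for order in permutations(range(3))]
print([round(lp, 6) for lp in lps])  # six orders, all (numerically) equal
print(round(math.log(joint[x]), 6))
```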
### Two-Stream Equations
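Content stream at layer \(m\) (standard self-attention: position \(\pi(t)\) sees the positions revealed before it plus its own content; \(h_i^{(0)}\) is the embedding of token \(x_i\)):
\[
h_{\pi(t)}^{(m)} = \text{Attention}\left(Q = h_{\pi(t)}^{(m-1)},\; KV = \mathbf{h}_{\pi(1:t)}^{(m-1)};\; \theta\right)
\]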
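Query stream at layer \(m\) (for the token being predicted; no self-content):
\[
g_{\pi(t)}^{(m)} = \text{Attention}\left(Q = g_{\pi(t)}^{(m-1)},\; KV = \mathbf{h}_{\pi(1:t-1)}^{(m-1)};\; \theta\right)
\]
The query stream at position \(\pi(t)\) starts from a shared trainable vector instead of the token embedding, so it carries the target position (through positional encodings) but never the target's content: its keys and values come only from content-stream states at \(\pi(1:t-1)\). The prediction is read from the top layer \(M\): \(p_\theta(X_{\pi(t)} = x \mid \mathbf{x}_{\pi(1:t-1)}) \propto \exp\left(e(x)^\top g_{\pi(t)}^{(M)}\right)\), where \(e(x)\) is the embedding of token \(x\).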
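The two equations can be followed literally, one step of the order at a time. This is a minimal sketch under loud simplifications: a single head, identity Q/K/V projections, no positional encodings, and no residual or feed-forward sublayers; `attend` and `two_stream_layer` are hypothetical names. It checks the property the equations are designed for: changing a position's content leaves that position's query state untouched.

```python
import numpy as np


def attend(q, K, V):
    """Single-query scaled dot-product attention; returns zeros when there are no keys."""
    if K.shape[0] == 0:
        return np.zeros(V.shape[1])
    scores = K @ q / np.sqrt(q.shape[0])
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V


def two_stream_layer(h, g, perm):
    """One attention-only two-stream update, following the equations step by step.
    h, g: [n, d] arrays indexed by sequence position; perm: factorization order.
    Assumes identity Q/K/V projections and no positional encodings."""
    h_new, g_new = np.zeros_like(h), np.zeros_like(g)
    for t, pos in enumerate(perm):
        upto = np.array(perm[: t + 1], dtype=int)  # pi(1:t): includes pos itself
        before = np.array(perm[:t], dtype=int)     # pi(1:t-1): excludes pos
        h_new[pos] = attend(h[pos], h[upto], h[upto])
        g_new[pos] = attend(g[pos], h[before], h[before])
    return h_new, g_new


rng = np.random.default_rng(0)
n, d = 5, 8
perm = [2, 0, 4, 1, 3]
h = rng.normal(size=(n, d))              # layer-0 content stream: stand-in token embeddings
g = np.tile(rng.normal(size=d), (n, 1))  # layer-0 query stream: one shared vector per position

h1, g1 = two_stream_layer(h, g, perm)
h_changed = h.copy()
h_changed[4] += 10.0                     # alter the content of position 4 only
h2, g2 = two_stream_layer(h_changed, g, perm)
print(np.allclose(g1[4], g2[4]))  # query state of position 4 ignores its own content
print(np.allclose(h1[4], h2[4]))  # content state of position 4 does not
```

The first predicted position (index 2) has an empty context here, so its query state is zero; in the full model it could still draw on cached memory from previous segments.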
## Python Implementation
```python
import numpy as np


def stable_softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def sample_permutation(n):
    """Sample a random permutation of [0, 1, ..., n-1]."""
    perm = list(range(n))
    np.random.shuffle(perm)
    return perm


def permutation_attention_mask(perm):
    """
    Content-stream mask for a given permutation (rows = query, cols = key).
    Position perm[t] can attend to positions perm[0], ..., perm[t-1] and itself.
    """
    n = len(perm)
    mask = np.zeros((n, n))
    for t, pos in enumerate(perm):
        for prev in perm[:t]:
            mask[pos][prev] = 1.0
        # Content stream: can also see itself
        mask[pos][pos] = 1.0
    return mask


def query_stream_mask(perm):
    """
    Query stream mask: same as the content mask but WITHOUT self-attention.
    Position perm[t] can attend to perm[0], ..., perm[t-1] but NOT itself.
    """
    n = len(perm)
    mask = np.zeros((n, n))
    for t, pos in enumerate(perm):
        for prev in perm[:t]:
            mask[pos][prev] = 1.0
    return mask


def two_stream_attention(h, g, W_q, W_k, W_v, content_mask, query_mask):
    """
    One single-head, attention-only two-stream update (no residual or FFN).
    h: content stream [n, d]
    g: query stream [n, d]
    Returns updated h and g.
    """
    # Content stream self-attention
    Q_h = h @ W_q
    K_h = h @ W_k
    V_h = h @ W_v
    d_k = K_h.shape[-1]
    scores_h = (Q_h @ K_h.T) / np.sqrt(d_k)
    scores_h = np.where(content_mask == 0, -1e9, scores_h)
    h_new = stable_softmax(scores_h) @ V_h
    # Query stream: queries come from g, keys/values from the content stream
    Q_g = g @ W_q
    scores_g = (Q_g @ K_h.T) / np.sqrt(d_k)
    scores_g = np.where(query_mask == 0, -1e9, scores_g)
    # The first position in the order has no visible keys. Softmax over an
    # all-masked row is uniform, which would leak every position's content
    # (including its own), so multiplying by the mask zeroes those rows.
    g_new = (stable_softmax(scores_g) * query_mask) @ V_h
    return h_new, g_new


def plm_log_prob(tokens, log_prob_fn, perm):
    """
    Compute total log probability under a permutation ordering.
    log_prob_fn(token, context_tokens) -> log probability
    """
    total = 0.0
    for t, idx in enumerate(perm):
        context = [tokens[perm[j]] for j in range(t)]
        total += log_prob_fn(tokens[idx], context)
    return total


def compare_factorizations(tokens, log_prob_fn, n_perms=10):
    """Show how different permutations give different factorizations."""
    results = []
    for _ in range(n_perms):
        perm = sample_permutation(len(tokens))
        lp = plm_log_prob(tokens, log_prob_fn, perm)
        results.append((perm, lp))
    return results


# --- Demo ---
if __name__ == "__main__":
    np.random.seed(42)
    seq_len, d_model = 5, 16
    perm = sample_permutation(seq_len)
    print(f"Permutation: {perm}")
    print(f"Prediction order: {[f'pos {p}' for p in perm]}")
    c_mask = permutation_attention_mask(perm)
    q_mask = query_stream_mask(perm)
    print("\nContent mask (rows=query, cols=key):")
    print(c_mask.astype(int))
    print("\nQuery mask (no self-attention):")
    print(q_mask.astype(int))
    # Two-stream attention demo
    h = np.random.randn(seq_len, d_model)
    g = np.random.randn(seq_len, d_model)
    W_q = np.random.randn(d_model, d_model) * 0.1
    W_k = np.random.randn(d_model, d_model) * 0.1
    W_v = np.random.randn(d_model, d_model) * 0.1
    h_new, g_new = two_stream_attention(h, g, W_q, W_k, W_v, c_mask, q_mask)
    print(f"\nContent stream output shape: {h_new.shape}")
    print(f"Query stream output shape: {g_new.shape}")
    print(f"First predicted position has a zero query output: {np.allclose(g_new[perm[0]], 0.0)}")
```
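The implementation above stops at the attention outputs. A minimal sketch of the training loss follows, assuming the prediction is a softmax over \(e(x)^\top g\) with the token-embedding matrix reused as the output layer. The XLNet paper also predicts only the last tokens of each sampled order ("partial prediction") to ease optimization; the `num_predict` argument is a hypothetical stand-in for that, and `plm_loss` is likewise a hypothetical name.

```python
import numpy as np


def plm_loss(g_out, embed, tokens, perm, num_predict=None):
    """Mean negative log-likelihood of the target tokens under one factorization order.
    g_out: [n, d] final-layer query-stream states; embed: [vocab, d] token embeddings
    (assumed tied with the output layer); tokens: [n] int array; perm: the order.
    num_predict (hypothetical, >= 1) keeps only the last num_predict positions as targets."""
    n = len(perm)
    k = n if num_predict is None else num_predict
    targets = np.asarray(perm[n - k:], dtype=int)
    logits = g_out[targets] @ embed.T                     # [k, vocab]
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(k), tokens[targets]].mean()


rng = np.random.default_rng(0)
n, d, vocab = 5, 8, 10
perm = [2, 0, 4, 1, 3]
tokens = rng.integers(0, vocab, size=n)
embed = rng.normal(size=(vocab, d))
g_out = rng.normal(size=(n, d))

print(plm_loss(g_out, embed, tokens, perm))                 # all 5 positions are targets
print(plm_loss(g_out, embed, tokens, perm, num_predict=2))  # only positions 1 and 3
```

Reading the logits from the query stream rather than the content stream is what keeps the target's own token out of its prediction.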
## Interview Importance
XLNet is less frequently asked directly but understanding it shows depth of knowledge about pre-training objectives and their trade-offs.
### Difficulty Level: ⭐⭐⭐⭐ (Hard)
## Interview Questions & Answers
### Q1: Explain permutation LM vs. standard left-to-right training.
Answer: Standard autoregressive LM always predicts left-to-right: \(P(x_1) \cdot P(x_2|x_1) \cdot P(x_3|x_1,x_2) \cdots\). Permutation LM samples random orderings, so \(x_3\) might be predicted first (no context), then \(x_1\) (given \(x_3\)), etc. Both are valid factorizations of the same joint distribution. The benefit: across many permutations, every token gets to condition on every possible subset of other tokens — achieving bidirectional context within an autoregressive framework.
### Q2: What problem does two-stream attention solve?
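Answer: A standard representation of the context \(\mathbf{x}_{\pi(1:t-1)}\) does not depend on which position is being predicted, so the model cannot tell which target it is predicting. But if the target's token embedding is fed in to mark the position, the model would "see the answer." Two-stream attention separates the two roles:
- Content stream: Contains the actual token embedding (used when this position provides context to other positions)
- Query stream: Contains the target position and the content of preceding tokens, but not its own token (used when this position is being predicted)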
This prevents the model from trivially copying its own token while still allowing it to use positional information.
### Q3: Why did RoBERTa + scale often win over XLNet in practice?
Answer: Several practical factors:
1. XLNet is more complex to implement (two-stream attention, permutation sampling).
2. Training is slower per step due to the additional stream.
3. RoBERTa showed that simply fixing BERT's training recipe (more data, dynamic masking, longer training with larger batches, no NSP) achieved most of the gains with less complexity.
4. The independence assumption in MLM, which XLNet targets, appears to matter less in practice once models and datasets are large.
5. Engineering teams prefer simpler architectures that are easier to debug and optimize.
## Connections to Other Papers
- BERT → XLNet addresses BERT's masked token independence assumption
- GPT-2 → XLNet extends GPT-style autoregressive training with permutations
- RoBERTa → Simpler alternative that achieved competitive results
- Transformer-XL → XLNet builds on its segment-level recurrence and relative positional encodings for long contexts
## Key Takeaways for Quick Review
| Concept | Remember |
|---|---|
| Core idea | Permutation language modeling — random factorization orders |
| Problem solved | BERT's independence assumption for masked tokens |
| Key mechanism | Two-stream attention (content vs. query) |
| Practical outcome | Competitive but complex; RoBERTa won mindshare |
| Interview value | Shows deep understanding of pre-training objectives |