T5: Text-to-Text Transfer Transformer¶
Authors: Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou
Year: 2019 (JMLR 2020) | Venue: JMLR
Link: arXiv:1910.10683
TL;DR¶
T5 unifies every NLP task as a text-to-text problem: both inputs and outputs are strings. Translation becomes "translate English to German: [input]" → "[output]". Classification becomes "classify: [input]" → "positive". The same encoder-decoder Transformer is pre-trained with a span corruption objective (randomly drop spans of tokens, predict the missing pieces) and fine-tuned on downstream tasks using the same architecture and loss.
Why This Paper Matters¶
T5 introduced the text-to-text framework that shaped how we think about multitask learning and instruction following:
- Every task uses the same model, same loss function, same decoding
- Task identity is encoded in a text prefix rather than task-specific heads
- This idea directly influenced FLAN (instruction tuning) and chat models
- The paper is also a massive empirical survey — it systematically compares architectures, objectives, data, and training strategies
Key Concepts Explained Simply¶
Text-to-Text Framework¶
Instead of designing different output layers for different tasks:

- Classification: "sentiment: The movie was great" → "positive"
- Translation: "translate English to German: Hello" → "Hallo"
- Summarization: "summarize: [long text]" → "[short summary]"
- QA: "question: Who wrote Hamlet? context: ..." → "Shakespeare"
The model always does the same thing: takes a text input, produces a text output.
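As a quick sketch, this framing can be written as plain (input string, target string) pairs; the prefixes below mirror the examples in this section (real T5 checkpoints use similar fixed prefixes per task):

```python
# Every task reduces to a (source string, target string) pair.
# These pairs mirror the examples above.
pairs = [
    ("translate English to German: Hello", "Hallo"),
    ("sentiment: The movie was great", "positive"),
    ("summarize: [long text]", "[short summary]"),
    ("question: Who wrote Hamlet? context: ...", "Shakespeare"),
]

# One model, one interface: string in, string out.
for src, tgt in pairs:
    print(f"{src!r} -> {tgt!r}")
```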
Span Corruption (Pre-training Objective)¶
Unlike BERT's single-token masking, T5 masks contiguous spans of tokens:
- Original: "The quick brown fox jumps over the lazy dog"
- Corrupted: "The <X> fox <Y> dog"
- Target: "<X> quick brown <Y> jumps over the lazy"
Where <X>, <Y> are sentinel tokens. The model learns to predict entire spans, which is a stronger training signal than single-token prediction because it requires understanding phrase-level structure.
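To make the sentinel mechanics concrete, here is a minimal sketch that reproduces the worked example above by hand (the span positions are hard-coded here; the real objective samples them randomly):

```python
# Replace two hand-picked spans with T5-style sentinel tokens.
tokens = "The quick brown fox jumps over the lazy dog".split()
spans = [(1, 3), (4, 8)]  # (start, end) of "quick brown" and "jumps over the lazy"

corrupted, target = [], []
pos = 0
for i, (start, end) in enumerate(spans):
    corrupted.extend(tokens[pos:start])       # keep tokens before the span
    corrupted.append(f"<extra_id_{i}>")       # sentinel replaces the span
    target.append(f"<extra_id_{i}>")          # target: sentinel then span content
    target.extend(tokens[start:end])
    pos = end
corrupted.extend(tokens[pos:])                # keep tokens after the last span

print(" ".join(corrupted))  # The <extra_id_0> fox <extra_id_1> dog
print(" ".join(target))     # <extra_id_0> quick brown <extra_id_1> jumps over the lazy
```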
Encoder-Decoder Architecture¶
Unlike GPT (decoder-only) or BERT (encoder-only), T5 uses the full encoder-decoder:

- Encoder: Processes the input with bidirectional attention (sees full context)
- Decoder: Generates the output autoregressively with cross-attention to the encoder
This is naturally suited for seq2seq tasks where input and output can have different lengths.
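A small shape sketch (toy dimensions, random values) shows why this handles differing input and output lengths naturally: cross-attention produces one score per (decoder position, encoder position) pair, so the two lengths never need to match:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8
src_len, tgt_len = 5, 3                           # input and output lengths differ
enc_out = rng.normal(size=(src_len, d_model))     # stand-in for encoder output
dec_states = rng.normal(size=(tgt_len, d_model))  # stand-in for decoder states

# Cross-attention scores: one row per decoder position,
# one column per encoder position.
scores = dec_states @ enc_out.T
print(scores.shape)  # (3, 5)
```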
The Math — Explained Step by Step¶
Span Corruption Objective¶
Given a sequence \(\mathbf{x}\), randomly sample spans \(\mathbf{s} = (s_1, s_2, \ldots)\) to remove, producing corrupted input \(\tilde{\mathbf{x}}\). The model is trained to predict the removed spans:

\[
\mathcal{L}(\theta) = -\sum_{j} \log p_\theta\!\left(s_j \mid \tilde{\mathbf{x}}, \mathbf{s}_{<j}\right)
\]
Breaking it down:
- \(\tilde{\mathbf{x}}\): Input with spans replaced by sentinel tokens (<extra_id_0>, <extra_id_1>, etc.)
- \(\mathbf{s}_{<j}\): Previously predicted tokens in the target (for autoregressive decoding)
- The encoder processes \(\tilde{\mathbf{x}}\) bidirectionally; the decoder generates the target tokens
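Numerically, the objective is just a sum of per-token negative log-likelihoods over the target. A toy sketch with made-up probabilities:

```python
import numpy as np

# Made-up per-token probabilities p(s_j | x_tilde, s_<j) for a 3-token target.
p = np.array([0.7, 0.5, 0.9])

# Span-corruption loss: negative log-likelihood summed over target tokens.
loss = -np.log(p).sum()
print(round(loss, 4))  # 1.1552
```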
Why Span Corruption > Single-Token MLM¶
- MLM (BERT): Each masked token is predicted independently → weaker signal for phrase-level dependencies
- Span corruption (T5): The decoder must predict the entire span sequentially, learning to reconstruct coherent phrases
- Typical span length: mean of 3 tokens, with 15% of tokens corrupted total
Text-to-Text Loss¶
For any downstream task, fine-tuning uses the same seq2seq loss:

\[
\mathcal{L}(\theta) = -\sum_{t} \log p_\theta\!\left(y_t \mid \mathrm{prefix}(\mathbf{x}), \mathbf{y}_{<t}\right)
\]

where \(\mathrm{prefix}(\mathbf{x})\) is the task-prefixed input (e.g., "translate English to German: [x]").
Python Implementation¶
```python
import numpy as np


def stable_softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def span_corrupt(tokens, mask_rate=0.15, mean_span_length=3):
    """
    T5-style span corruption.
    Replaces contiguous spans with sentinel tokens.
    Returns: (corrupted_input, target_sequence)
    """
    n = len(tokens)
    num_masked = max(1, int(n * mask_rate))
    # Determine spans
    mask = [False] * n
    masked_count = 0
    attempts = 0
    while masked_count < num_masked and attempts < 100:  # cap retries to avoid an infinite loop
        attempts += 1
        span_len = max(1, int(np.random.geometric(1.0 / mean_span_length)))
        span_len = min(span_len, num_masked - masked_count)
        start = np.random.randint(0, n - span_len + 1)
        if any(mask[start:start + span_len]):  # skip placements that overlap an existing span
            continue
        for i in range(start, start + span_len):
            mask[i] = True
        masked_count += span_len
    # Build corrupted input and target
    corrupted, target = [], []
    sentinel_idx = 0
    in_span = False
    for i in range(n):
        if mask[i]:
            if not in_span:
                corrupted.append(f"<extra_id_{sentinel_idx}>")
                target.append(f"<extra_id_{sentinel_idx}>")
                sentinel_idx += 1
                in_span = True
            target.append(tokens[i])
        else:
            in_span = False
            corrupted.append(tokens[i])
    return corrupted, target


def text_to_text_format(task_name, input_text):
    """Format any task as text-to-text with a task prefix."""
    return f"{task_name}: {input_text}"


def encoder_forward(x, W_q, W_k, W_v, W_o):
    """Simplified bidirectional encoder (full attention, no masking)."""
    d_k = x.shape[-1]
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = (Q @ K.T) / np.sqrt(d_k)
    attn = stable_softmax(scores)
    return (attn @ V) @ W_o


def decoder_with_cross_attention(y, enc_out, W_self_q, W_self_k, W_self_v,
                                 W_cross_q, W_cross_k, W_cross_v, W_o):
    """Simplified decoder step with causal self-attention and cross-attention."""
    d_k = y.shape[-1]
    seq_len = y.shape[0]
    # Causal self-attention: mask out future positions
    Q, K, V = y @ W_self_q, y @ W_self_k, y @ W_self_v
    scores = (Q @ K.T) / np.sqrt(d_k)
    causal_mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)
    scores = scores + causal_mask
    self_attn = stable_softmax(scores) @ V
    # Cross-attention: queries from decoder, keys/values from encoder
    Q_cross = self_attn @ W_cross_q
    K_cross = enc_out @ W_cross_k
    V_cross = enc_out @ W_cross_v
    cross_scores = (Q_cross @ K_cross.T) / np.sqrt(d_k)
    cross_attn = stable_softmax(cross_scores) @ V_cross
    return cross_attn @ W_o


def seq2seq_loss(logits, target_ids):
    """Cross-entropy loss for sequence-to-sequence output."""
    probs = stable_softmax(logits)
    loss = 0.0
    for t, tid in enumerate(target_ids):
        loss -= np.log(probs[t, tid] + 1e-12)
    return loss / len(target_ids)


# --- Demo ---
if __name__ == "__main__":
    # Span corruption demo
    sentence = "The quick brown fox jumps over the lazy dog".split()
    np.random.seed(42)
    corrupted, target = span_corrupt(sentence)
    print("Original:", " ".join(sentence))
    print("Corrupted:", " ".join(corrupted))
    print("Target:", " ".join(target))

    # Text-to-text formatting (prefixes follow the formats used in the T5 paper)
    tasks = [
        ("translate English to German", "The house is wonderful."),
        ("summarize", "Long article about climate change effects on agriculture..."),
        ("stsb sentence1", "A man is eating food. sentence2: A man is eating pasta."),
        ("cola sentence", "The course is jumping well."),
    ]
    print("\nText-to-text formatting:")
    for task, inp in tasks:
        print(f"  {text_to_text_format(task, inp)}")
```
Interview Importance¶
T5 comes up when contrasting encoder-decoder and decoder-only architectures, and when discussing multitask learning and instruction formatting.
Difficulty Level: ⭐⭐ (Medium)¶
Interview Questions & Answers¶
Q1: Compare encoder-decoder (T5) vs. decoder-only (GPT) for generation tasks.¶
Answer:

- Encoder-decoder: Input is processed bidirectionally by the encoder, then the decoder generates output with cross-attention. Better for fixed-input tasks (translation, summarization) where the input is fully known.
- Decoder-only: Input and output are concatenated in a single sequence with causal masking. More flexible for open-ended generation and simpler to scale (one stack instead of two).
- In practice, decoder-only models dominate at large scale because they're simpler, and the bidirectional-encoding advantage diminishes with enough scale and data.
Q2: How does span corruption differ from BERT's MLM?¶
Answer:

- BERT MLM: Masks individual tokens (15%), predicts each independently
- T5 span corruption: Masks contiguous spans (mean length 3); the decoder predicts them sequentially

Span corruption provides a stronger training signal because:

1. The decoder must predict coherent multi-token phrases (not just isolated words)
2. Each sentinel represents an entire span, so the model learns to reconstruct missing context
3. The autoregressive decoder naturally captures dependencies between tokens within a span
Q3: Why might task prefixes help multitask training without separate model heads?¶
Answer: Task prefixes serve as soft routing signals. The model learns to:

1. Parse the prefix to understand which task is being requested
2. Activate different internal "circuits" or attention patterns based on the prefix
3. Generate output in the appropriate format

This is more flexible than separate heads because:

- New tasks can be added without architectural changes
- The model can share representations across related tasks
- At inference time, any task can be invoked just by changing the text prefix
Q4: What did the T5 paper find when comparing different architectures and objectives?¶
Answer: The paper systematically compared:

- Architectures: Encoder-decoder slightly outperformed decoder-only and prefix-LM on seq2seq tasks
- Objectives: Span corruption outperformed both MLM-style and language-modeling objectives for pre-training
- Data: More diverse, larger pre-training data (the C4 corpus) improved results
- Corruption rate: 15% corruption with mean span length 3 worked best
- Multitask: Pre-training followed by fine-tuning beat multitask training alone
Q5: How does T5 relate to FLAN and instruction tuning?¶
Answer: T5's text-to-text framework directly enables FLAN-style instruction tuning. FLAN takes T5's idea of task prefixes and scales it to thousands of tasks with diverse instruction phrasings. The key progression:

1. T5: "translate English to German: [input]" (fixed prefixes)
2. FLAN: "Please translate the following sentence into German: [input]" (natural-language instructions)
3. Chat models: Free-form instruction following in a conversational format
Connections to Other Papers¶
- Transformer → T5 uses the full encoder-decoder architecture
- BERT → Span corruption extends MLM to contiguous spans
- GPT-2/3 → Contrasting decoder-only approach
- FLAN → Instruction-tunes T5 on a large task mixture
- LLaMA → Decoder-only models eventually dominated, but T5 influenced the instruction format
Key Takeaways for Quick Review¶
| Concept | Remember |
|---|---|
| Framework | Every task is text-to-text (same model, same loss) |
| Architecture | Encoder-decoder Transformer |
| Pre-training | Span corruption (mask spans, predict with decoder) |
| Task format | Task prefix → "translate English to German: [input]" |
| Pre-training data | C4 (Colossal Clean Crawled Corpus) |
| Key finding | Training recipe matters as much as architecture |
| Legacy | Directly influenced FLAN and instruction-tuning paradigm |