T5: Text-to-Text Transfer Transformer

Authors: Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou
Year: 2019 (JMLR 2020)  |  Venue: JMLR
Link: arXiv:1910.10683


TL;DR

T5 unifies every NLP task as a text-to-text problem: both inputs and outputs are strings. Translation becomes "translate English to German: [input]" → "[output]". Classification becomes "classify: [input]" → "positive". The same encoder-decoder Transformer is pre-trained with a span corruption objective (randomly drop spans of tokens, predict the missing pieces) and fine-tuned on downstream tasks using the same architecture and loss.


Why This Paper Matters

T5 introduced the text-to-text framework that shaped how we think about multitask learning and instruction following:

  1. Every task uses the same model, same loss function, same decoding
  2. Task identity is encoded in a text prefix rather than task-specific heads
  3. This idea directly influenced FLAN (instruction tuning) and chat models
  4. The paper is also a massive empirical survey — it systematically compares architectures, objectives, data, and training strategies

Key Concepts Explained Simply

Text-to-Text Framework

Instead of designing different output layers for different tasks:

  • Classification: Input "sentiment: The movie was great" → Output "positive"
  • Translation: Input "translate English to German: Hello" → Output "Hallo"
  • Summarization: Input "summarize: [long text]" → Output "[short summary]"
  • QA: Input "question: Who wrote Hamlet? context: ..." → Output "Shakespeare"

The model always does the same thing: takes a text input, produces a text output.

Span Corruption (Pre-training Objective)

Unlike BERT's single-token masking, T5 masks contiguous spans of tokens:

  • Original: "The quick brown fox jumps over the lazy dog"
  • Corrupted: "The <X> fox <Y> dog"
  • Target: "<X> quick brown <Y> jumps over the lazy"

Where <X>, <Y> are sentinel tokens. The model learns to predict entire spans, which is a stronger training signal than single-token prediction because it requires understanding phrase-level structure.
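The sentinel mechanics above can be sketched deterministically (hardcoded span positions instead of random sampling; the helper name `corrupt_with_spans` is illustrative, not from the T5 codebase):

```python
def corrupt_with_spans(tokens, spans):
    """Replace the given (start, end) token spans with sentinels <X>, <Y>, ...
    and build the matching target sequence, T5-style."""
    sentinels = iter("XYZ")
    corrupted, target = [], []
    pos = 0
    for start, end in spans:  # spans assumed sorted and non-overlapping
        corrupted.extend(tokens[pos:start])
        s = f"<{next(sentinels)}>"
        corrupted.append(s)   # sentinel marks the hole in the input
        target.append(s)      # same sentinel introduces the span in the target
        target.extend(tokens[start:end])
        pos = end
    corrupted.extend(tokens[pos:])
    return corrupted, target

tokens = "The quick brown fox jumps over the lazy dog".split()
corrupted, target = corrupt_with_spans(tokens, [(1, 3), (4, 8)])
print(" ".join(corrupted))  # The <X> fox <Y> dog
print(" ".join(target))     # <X> quick brown <Y> jumps over the lazy
```

This reproduces the example above: each sentinel appears once in the corrupted input and once in the target, where it introduces the span it replaced.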

Encoder-Decoder Architecture

Unlike GPT (decoder-only) or BERT (encoder-only), T5 uses the full encoder-decoder:

  • Encoder: Processes the input with bidirectional attention (sees full context)
  • Decoder: Generates the output autoregressively with cross-attention to the encoder

This is naturally suited for seq2seq tasks where input and output can have different lengths.


The Math — Explained Step by Step

Span Corruption Objective

Given a sequence \(\mathbf{x}\), randomly sample spans to remove, producing the corrupted input \(\tilde{\mathbf{x}}\) and a target sequence \(\mathbf{s} = (s_1, s_2, \ldots)\) of sentinels and removed tokens. The model predicts each target token autoregressively:

\[ \mathcal{L} = -\sum_{j} \log P_\theta(s_j \mid \tilde{\mathbf{x}}, \mathbf{s}_{<j}) \]

Breaking it down:

  1. \(\tilde{\mathbf{x}}\): Input with spans replaced by sentinel tokens (<extra_id_0>, <extra_id_1>, etc.)
  2. \(\mathbf{s}_{<j}\): Previously predicted tokens in the target (for autoregressive decoding)
  3. The encoder processes \(\tilde{\mathbf{x}}\) bidirectionally; the decoder generates the target tokens
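As a toy sanity check of the objective, the loss is just a sum of negative log-probabilities over target tokens (the probabilities below are invented for illustration, not from a real model):

```python
import math

# Hypothetical per-token probabilities P(s_j | x_tilde, s_<j) for a 5-token target
token_probs = [0.9, 0.6, 0.8, 0.7, 0.5]

# L = -sum_j log P(s_j | x_tilde, s_<j)
loss = -sum(math.log(p) for p in token_probs)
print(round(loss, 3))  # ≈ 1.889
```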

Why Span Corruption > Single-Token MLM

  • MLM (BERT): Each masked token is predicted independently → weaker signal for phrase-level dependencies
  • Span corruption (T5): The decoder must predict the entire span sequentially, learning to reconstruct coherent phrases
  • Typical span length: mean of 3 tokens, with 15% of tokens corrupted total
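These two hyperparameters pin down the expected number of spans per example. For instance, at an input length of 512 tokens (the length T5 uses for pre-training; the arithmetic below just applies the stated rates, it is not taken from the paper's code):

```python
seq_len = 512
corruption_rate = 0.15     # fraction of tokens corrupted
mean_span_length = 3       # average tokens per masked span

num_corrupted = seq_len * corruption_rate       # ~77 tokens masked
num_spans = num_corrupted / mean_span_length    # ~26 spans (one sentinel each)
print(round(num_corrupted), round(num_spans))
```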

Text-to-Text Loss

For any downstream task, fine-tuning uses the same seq2seq loss:

\[ \mathcal{L}_{\text{task}} = -\sum_{t=1}^{|y|} \log P_\theta(y_t \mid \text{prefix}(x), y_{<t}) \]

where prefix(x) is the task-prefixed input (e.g., "translate English to German: [x]").
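Because the loss is identical for every task, even classification is just generation of a short target string. A minimal sketch with a tiny made-up vocabulary (the logits are invented to illustrate teacher forcing on the target "positive </s>"):

```python
import numpy as np

vocab = {"positive": 0, "negative": 1, "</s>": 2}

def seq_nll(logits, target_ids):
    """Standard autoregressive cross-entropy: -sum_t log P(y_t | x, y_<t)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -sum(log_probs[t, tid] for t, tid in enumerate(target_ids))

# Decoder logits for the 2-token target "positive </s>"
logits = np.array([[4.0, 0.0, 0.0],    # step 1: model favors "positive"
                   [0.0, 0.0, 4.0]])   # step 2: model favors "</s>"
target = [vocab["positive"], vocab["</s>"]]
loss = seq_nll(logits, target)
print(float(loss))  # small loss: the model agrees with the target
```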


Python Implementation

import numpy as np


def stable_softmax(logits):
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def span_corrupt(tokens, mask_rate=0.15, mean_span_length=3, sentinel_start=32000):
    """
    T5-style span corruption.
    Replaces contiguous spans with sentinel tokens.
    Returns: (corrupted_input, target_sequence)
    """
    n = len(tokens)
    num_masked = max(1, int(n * mask_rate))

    # Determine spans (retry cap guards against pathological mask rates
    # where no non-overlapping placement remains)
    mask = [False] * n
    masked_count = 0
    attempts = 0
    while masked_count < num_masked and attempts < 10 * n:
        attempts += 1
        span_len = max(1, int(np.random.geometric(1.0 / mean_span_length)))
        span_len = min(span_len, num_masked - masked_count)
        start = np.random.randint(0, n - span_len + 1)

        # Skip placements that overlap an already-masked region
        if any(mask[start:start + span_len]):
            continue

        for i in range(start, start + span_len):
            mask[i] = True
        masked_count += span_len

    # Build corrupted input and target
    corrupted, target = [], []
    sentinel_id = sentinel_start
    in_span = False

    for i in range(n):
        if mask[i]:
            if not in_span:
                corrupted.append(f"<extra_id_{sentinel_id - sentinel_start}>")
                target.append(f"<extra_id_{sentinel_id - sentinel_start}>")
                sentinel_id += 1
                in_span = True
            target.append(tokens[i])
        else:
            in_span = False
            corrupted.append(tokens[i])

    return corrupted, target


def text_to_text_format(task_name, input_text):
    """Format any task as text-to-text with a task prefix."""
    return f"{task_name}: {input_text}"


def encoder_forward(x, W_q, W_k, W_v, W_o):
    """Simplified bidirectional encoder (full attention, no masking)."""
    d_k = x.shape[-1]
    Q, K, V = x @ W_q, x @ W_k, x @ W_v

    scores = (Q @ K.T) / np.sqrt(d_k)
    attn = stable_softmax(scores)
    return (attn @ V) @ W_o


def decoder_with_cross_attention(y, enc_out, W_self_q, W_self_k, W_self_v,
                                  W_cross_q, W_cross_k, W_cross_v, W_o):
    """Simplified decoder step with causal self-attention and cross-attention."""
    d_k = y.shape[-1]
    seq_len = y.shape[0]

    # Causal self-attention
    Q, K, V = y @ W_self_q, y @ W_self_k, y @ W_self_v
    scores = (Q @ K.T) / np.sqrt(d_k)
    causal_mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)
    scores = scores + causal_mask
    self_attn = stable_softmax(scores) @ V

    # Cross-attention: queries from decoder, keys/values from encoder
    Q_cross = self_attn @ W_cross_q
    K_cross = enc_out @ W_cross_k
    V_cross = enc_out @ W_cross_v
    cross_scores = (Q_cross @ K_cross.T) / np.sqrt(d_k)
    cross_attn = stable_softmax(cross_scores) @ V_cross

    return cross_attn @ W_o


def seq2seq_loss(logits, target_ids):
    """Cross-entropy loss for sequence-to-sequence output."""
    probs = stable_softmax(logits)
    loss = 0.0
    for t, tid in enumerate(target_ids):
        loss -= np.log(probs[t, tid] + 1e-12)
    return loss / len(target_ids)


# --- Demo ---
if __name__ == "__main__":
    # Span corruption demo
    sentence = "The quick brown fox jumps over the lazy dog".split()
    np.random.seed(42)
    corrupted, target = span_corrupt(sentence)
    print("Original:", " ".join(sentence))
    print("Corrupted:", " ".join(corrupted))
    print("Target:", " ".join(target))

    # Text-to-text formatting
    tasks = [
        ("translate English to German", "The house is wonderful."),
        ("summarize", "Long article about climate change effects on agriculture..."),
        ("stsb sentence1", "A man is eating food. sentence2: A man is eating pasta."),
        ("cola sentence", "The course is jumping well."),
    ]
    print("\nText-to-text formatting:")
    for task, inp in tasks:
        print(f"  {text_to_text_format(task, inp)}")

Interview Importance

T5 comes up when contrasting encoder-decoder vs. decoder-only architectures and when discussing multitask learning and instruction formatting.

Difficulty Level: ⭐⭐ (Medium)


Interview Questions & Answers

Q1: Compare encoder-decoder (T5) vs. decoder-only (GPT) for generation tasks.

Answer:

  • Encoder-decoder: Input is processed bidirectionally by the encoder, then the decoder generates output with cross-attention. Better for fixed-input tasks (translation, summarization) where the input is fully known.
  • Decoder-only: Input and output are concatenated in a single sequence with causal masking. More flexible for open-ended generation and simpler to scale (one stack instead of two).
  • In practice, decoder-only models dominate at large scale because they are simpler, and the advantage of bidirectional encoding diminishes with enough scale and data.

Q2: How does span corruption differ from BERT's MLM?

Answer:

  • BERT MLM: Masks individual tokens (15%) and predicts each independently
  • T5 span corruption: Masks contiguous spans (mean length 3); the decoder predicts them sequentially

Span corruption provides a stronger training signal because:

  1. The decoder must predict coherent multi-token phrases (not just isolated words)
  2. Each sentinel represents an entire span, so the model learns to reconstruct context
  3. The autoregressive decoder naturally captures dependencies between tokens within a span
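The difference can be made concrete with toy probabilities (numbers invented for illustration): under MLM the two masked tokens are scored independently, while a span-corruption decoder conditions the second token on the first.

```python
import math

# Hypothetical probabilities for the masked span "quick brown"
p_quick = 0.5               # P("quick" | context)
p_brown_indep = 0.2         # MLM: P("brown" | context), ignores "quick"
p_brown_given_quick = 0.8   # decoder: P("brown" | context, "quick")

nll_mlm = -(math.log(p_quick) + math.log(p_brown_indep))
nll_span = -(math.log(p_quick) + math.log(p_brown_given_quick))
print(nll_mlm > nll_span)  # True: conditioning on the span so far helps
```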

Q3: Why might task prefixes help multitask training without separate model heads?

Answer: Task prefixes serve as soft routing signals. The model learns to:

  1. Parse the prefix to understand which task is being requested
  2. Activate different internal "circuits" or attention patterns based on the prefix
  3. Generate output in the appropriate format

This is more flexible than separate heads because:

  • New tasks can be added without architectural changes
  • The model can share representations across related tasks
  • At inference time, any task can be invoked just by changing the text prefix

Q4: What did the T5 paper find when comparing different architectures and objectives?

Answer: The paper systematically compared:

  • Architectures: Encoder-decoder slightly outperformed decoder-only and prefix-LM on seq2seq tasks
  • Objectives: Span corruption outperformed both MLM and language modeling for pre-training
  • Data: More diverse, larger pre-training data (the C4 corpus) improved results
  • Corruption rate: 15% corruption with mean span length 3 worked best
  • Multitask: Pre-training followed by fine-tuning beat multitask training alone

Q5: How does T5 relate to FLAN and instruction tuning?

Answer: T5's text-to-text framework directly enables FLAN-style instruction tuning. FLAN takes T5's idea of task prefixes and scales it to thousands of tasks with diverse instruction phrasings. The key progression:

  1. T5: "translate English to German: [input]" (fixed prefixes)
  2. FLAN: "Please translate the following sentence into German: [input]" (natural language instructions)
  3. Chat models: Free-form instruction following with conversational format


Connections to Other Papers

  • Transformer → T5 uses the full encoder-decoder architecture
  • BERT → Span corruption extends MLM to contiguous spans
  • GPT-2/3 → Contrasting decoder-only approach
  • FLAN → Instruction-tunes T5 on a large task mixture
  • LLaMA → Decoder-only models eventually dominated, but T5 influenced the instruction format

Key Takeaways for Quick Review

Concept            Remember
-----------------  ----------------------------------------------------------
Framework          Every task is text-to-text (same model, same loss)
Architecture       Encoder-decoder Transformer
Pre-training       Span corruption (mask spans, predict with decoder)
Task format        Task prefix → "translate English to German: [input]"
Pre-training data  C4 (Colossal Clean Crawled Corpus)
Key finding        Training recipe matters as much as architecture
Legacy             Directly influenced FLAN and the instruction-tuning paradigm