Chinchilla: Training Compute-Optimal Large Language Models
Authors: Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, et al.
Year: 2022 | Venue: arXiv
Link: arXiv:2203.15556
TL;DR
Chinchilla revisits scaling laws and demonstrates that most large language models are undertrained: for a fixed compute budget, smaller models trained on more data consistently outperform larger models trained on less data. The paper fits empirical laws relating loss to parameters \(N\) and tokens \(D\), and derives compute-optimal allocation ratios — showing that parameters and data should scale roughly equally.
Why This Paper Matters
Chinchilla fundamentally changed how the industry thinks about training budgets:
- Before Chinchilla: "Make the model as large as possible" (GPT-3, PaLM)
- After Chinchilla: "Balance model size with data" (LLaMA, Mistral)
- The 70B model trained on 1.4T tokens (Chinchilla) outperformed the 280B Gopher trained on 300B tokens
- Directly shaped open-source LLM training (LLaMA adopted the data-heavy recipe, training at or beyond the Chinchilla-optimal ratio)
- "Chinchilla-optimal" became a standard term in model training discussions
Key Concepts Explained Simply
The Core Insight
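If you have a fixed compute budget (say, $10M worth of GPU time), should you:
- (A) Train a huge 500B model on ~200B tokens, or
- (B) Train a smaller 70B model on 1.4T tokens?

Both options cost roughly the same compute (about \(6 \times 10^{23}\) FLOPs under the \(C \approx 6ND\) approximation).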
Chinchilla showed that (B) is better. The larger model is undertrained — it hasn't seen enough data to fully utilize its capacity. The smaller model, having seen more diverse data, generalizes better.
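The comparison can be checked numerically. The sketch below assumes the parametric loss form and fitted constants used in the implementation later on this page; `parametric_loss` and `training_flops` are illustrative names, and the 500B model is given as many tokens as option (B)'s budget allows.

```python
def parametric_loss(n_params, n_tokens, A=406.4, alpha=0.34,
                    B=410.7, beta=0.28, L_inf=1.69):
    """Chinchilla-style parametric loss: A/N^alpha + B/D^beta + L_inf."""
    return A / n_params**alpha + B / n_tokens**beta + L_inf


def training_flops(n_params, n_tokens):
    """Approximate training compute, C ≈ 6ND."""
    return 6 * n_params * n_tokens


# Option (B): 70B parameters on 1.4T tokens defines the budget.
budget = training_flops(70e9, 1.4e12)

# Option (A): a 500B model gets whatever tokens the same budget affords.
big_n = 500e9
big_d = budget / (6 * big_n)

loss_big = parametric_loss(big_n, big_d)
loss_small = parametric_loss(70e9, 1.4e12)

print(f"Budget: {budget:.2e} FLOPs")
print(f"(A) 500B params, {big_d/1e9:.0f}B tokens -> predicted loss {loss_big:.3f}")
print(f"(B)  70B params, 1400B tokens -> predicted loss {loss_small:.3f}")
```

Under these assumed constants the smaller model's predicted loss is lower, because the data term dominates for the token-starved 500B model.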
The Rule of Thumb
For compute-optimal training, tokens should scale linearly with parameters:
- ~20 tokens per parameter (Chinchilla's finding)
- GPT-3 (175B params, 300B tokens) → only ~1.7 tokens per parameter → severely undertrained
- Chinchilla (70B params, 1.4T tokens) → 20 tokens per parameter → compute-optimal
When Chinchilla Breaks Down
The scaling laws assume:
- Fresh, unique data for each token (no repeated epochs)
- Data quality is consistent as you scale
- Task distribution at evaluation matches training

In practice, running out of high-quality data means you may need to choose a larger model than the compute-optimal ratio suggests.
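The Math: Explained Step by Step
Parametric Loss Function
The paper models the final training loss as a function of parameters \(N\) and training tokens \(D\):
\[ L(N, D) = \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}} + L_{\infty} \]
Breaking it down: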
- \(A / N^{\alpha}\): Loss from insufficient model capacity — decreases with more parameters
- \(B / D^{\beta}\): Loss from insufficient data — decreases with more training tokens
- \(L_{\infty}\): Irreducible loss — the entropy of natural language, can't be reduced by any model
- \(\alpha \approx 0.34, \beta \approx 0.28\): Fitted exponents — both show diminishing returns
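To make "diminishing returns" concrete, the sketch below uses only the exponents quoted above and asks two questions: how much does each reducible term shrink when its input doubles, and how much must the input grow to halve the term? The variable names are illustrative.

```python
alpha, beta = 0.34, 0.28

# A / N^alpha is multiplied by 2^(-alpha) when N doubles (same idea for D).
capacity_factor = 2 ** -alpha
data_factor = 2 ** -beta

# To halve a term c / x^p, x must grow by a factor of 2^(1/p).
n_scale_to_halve = 2 ** (1 / alpha)
d_scale_to_halve = 2 ** (1 / beta)

print(f"Doubling N multiplies the capacity term by {capacity_factor:.3f}")
print(f"Doubling D multiplies the data term by     {data_factor:.3f}")
print(f"Halving the capacity term needs {n_scale_to_halve:.1f}x more parameters")
print(f"Halving the data term needs     {d_scale_to_halve:.1f}x more tokens")
```

Doubling either input removes only about a fifth of the corresponding term, which is why neither parameters nor data alone is an efficient place to spend an entire budget.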
Compute Constraint
For a Transformer with \(N\) parameters trained on \(D\) tokens:
\[ C \approx 6ND \]
This is the approximate training FLOPs budget. Given fixed \(C\), we want to find the \(N^*\) and \(D^*\) that minimize \(L(N, D)\) subject to \(C = 6ND\).
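As a quick arithmetic check of the \(6ND\) approximation, the sketch below plugs in the Gopher and Chinchilla sizes quoted earlier on this page. The function name `training_flops` is illustrative; the factor of 6 is the usual estimate of roughly 2 FLOPs per parameter per token for the forward pass and twice that for the backward pass.

```python
def training_flops(n_params, n_tokens):
    """Approximate training compute, C ≈ 6ND."""
    return 6 * n_params * n_tokens


gopher = training_flops(280e9, 300e9)        # 280B params, 300B tokens
chinchilla = training_flops(70e9, 1.4e12)    # 70B params, 1.4T tokens

print(f"Gopher:     {gopher:.2e} FLOPs")
print(f"Chinchilla: {chinchilla:.2e} FLOPs")
print(f"Ratio:      {chinchilla / gopher:.2f}")
```

The two budgets land within about 17% of each other, which is what makes the Gopher versus Chinchilla comparison a roughly equal-compute one.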
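Optimal Allocation
Taking the Lagrangian of this constrained problem and solving gives power laws in compute:
\[ N^* \propto C^{a}, \qquad D^* \propto C^{b} \]
where \(a = \beta / (\alpha + \beta)\) and \(b = \alpha / (\alpha + \beta)\), so \(a + b = 1\). With the exponents above this gives \(a \approx 0.45\) and \(b \approx 0.55\); the paper's other estimation approaches give values close to 0.5 for both. This means:
- Parameters and tokens should scale roughly equally with compute
- Double your budget → increase both N and D by ~\(\sqrt{2}\)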
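The Lagrangian step can be written out explicitly. This is a sketch of the standard derivation, with \(K = C/6\) and \(G\) as shorthand introduced here for convenience:

\[
\begin{aligned}
\mathcal{L}(N, D, \lambda) &= \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}} + L_{\infty} + \lambda\,(ND - K) \\
\frac{\partial \mathcal{L}}{\partial N} = 0 \;&\Rightarrow\; \alpha A\, N^{-\alpha} = \lambda N D \\
\frac{\partial \mathcal{L}}{\partial D} = 0 \;&\Rightarrow\; \beta B\, D^{-\beta} = \lambda N D \\
\text{so}\quad \alpha A\, N^{-\alpha} &= \beta B\, D^{-\beta} \\
\text{substituting } D = K/N: \quad N^{\alpha+\beta} &= \frac{\alpha A}{\beta B}\, K^{\beta} \\
N^* = G\, K^{\beta/(\alpha+\beta)}, \qquad D^* &= G^{-1} K^{\alpha/(\alpha+\beta)}, \qquad G = \left(\frac{\alpha A}{\beta B}\right)^{1/(\alpha+\beta)}
\end{aligned}
\]

Reading off the exponents gives \(a = \beta/(\alpha+\beta)\) and \(b = \alpha/(\alpha+\beta)\). With \(\alpha \approx 0.34\) and \(\beta \approx 0.28\) these come out near 0.45 and 0.55, close to but not exactly the equal split of 0.5 each.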
The 20× Rule
From the empirical fits:
\[ \frac{D^*}{N^*} \approx 20 \]
For compute-optimal training, you need roughly 20 tokens per parameter.
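Combining the rule of thumb \(D = 20N\) with \(C = 6ND\) gives \(C = 120 N^2\), so a budget can be turned into a model size directly. A minimal sketch, where `chinchilla_sizing` is a hypothetical helper name and 20 is the rule-of-thumb ratio rather than an exact constant:

```python
import math


def chinchilla_sizing(compute_flops, tokens_per_param=20.0):
    """Solve C = 6 * N * (r * N) for N, then set D = r * N."""
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


n, d = chinchilla_sizing(5.88e23)
print(f"N ≈ {n/1e9:.1f}B parameters, D ≈ {d/1e12:.2f}T tokens")
```

Feeding in Chinchilla's own budget of about \(5.88 \times 10^{23}\) FLOPs recovers the 70B parameter, 1.4T token configuration.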
Python Implementation
import numpy as np
from scipy.optimize import minimize_scalar


def chinchilla_loss(N, D, A=406.4, alpha=0.34, B=410.7, beta=0.28, L_inf=1.69):
    """
    Parametric loss model from Chinchilla.

    N: number of parameters
    D: number of training tokens
    """
    return A / (N ** alpha) + B / (D ** beta) + L_inf


def compute_flops(N, D):
    """Approximate training FLOPs: C ≈ 6ND."""
    return 6 * N * D


def optimal_allocation(C, A=406.4, alpha=0.34, B=410.7, beta=0.28):
    """
    Find the optimal N and D for a given compute budget C.

    Minimizes L(N, D) subject to 6*N*D = C.
    """
    def loss_for_N(log_N):
        # Search in log space so the optimizer covers many orders of magnitude.
        N = np.exp(log_N)
        D = C / (6 * N)
        if D <= 0:
            return 1e10
        return chinchilla_loss(N, D, A, alpha, B, beta)

    # Search over reasonable model sizes
    min_N = 1e6    # 1M params
    max_N = C / 6  # Maximum N (1 token)
    result = minimize_scalar(
        loss_for_N,
        bounds=(np.log(min_N), np.log(max_N)),
        method='bounded'
    )
    N_opt = np.exp(result.x)
    D_opt = C / (6 * N_opt)
    return N_opt, D_opt, result.fun


def tokens_per_parameter(N, D):
    """Compute the data-to-parameter ratio."""
    return D / N


def is_chinchilla_optimal(N, D, target_ratio=20, tolerance=0.5):
    """Check if a model is approximately Chinchilla-optimal."""
    ratio = D / N
    return abs(ratio - target_ratio) / target_ratio < tolerance


def compare_models():
    """Compare real models against Chinchilla-optimal ratios."""
    models = [
        ("GPT-3", 175e9, 300e9),
        ("Gopher", 280e9, 300e9),
        ("Chinchilla", 70e9, 1.4e12),
        ("LLaMA-7B", 7e9, 1.0e12),
        ("LLaMA-13B", 13e9, 1.0e12),
        ("LLaMA-65B", 65e9, 1.4e12),
        ("Mistral-7B", 7e9, 8e12),  # estimated; token count not officially published
    ]
    print(f"{'Model':<15} {'Params':>10} {'Tokens':>10} {'Tok/Param':>10} "
          f"{'FLOPs':>10} {'Optimal?':>10}")
    print("-" * 70)
    for name, N, D in models:
        ratio = tokens_per_parameter(N, D)
        optimal = is_chinchilla_optimal(N, D)
        flops = compute_flops(N, D)
        print(f"{name:<15} {N/1e9:>9.0f}B {D/1e9:>9.0f}B {ratio:>10.1f} "
              f"{flops:>10.2e} {'✓' if optimal else '✗':>10}")


def scaling_curve(compute_budgets):
    """Show optimal N and D for various compute budgets."""
    print(f"\n{'Compute (FLOPs)':>18} {'Optimal N':>12} {'Optimal D':>12} {'Tok/Param':>10} {'Loss':>8}")
    print("-" * 65)
    for C in compute_budgets:
        N_opt, D_opt, loss = optimal_allocation(C)
        ratio = D_opt / N_opt
        print(f"{C:>18.2e} {N_opt/1e9:>10.1f}B {D_opt/1e9:>10.1f}B {ratio:>10.1f} {loss:>8.3f}")


def data_constrained_analysis():
    """What happens when you run out of unique data?"""
    print("\n--- Data-Constrained Scenario ---")
    N = 70e9
    unique_tokens = 1e12
    for epochs in [1, 2, 4, 8]:
        effective_D = unique_tokens * epochs
        loss = chinchilla_loss(N, effective_D)
        # Illustrative heuristic (not from the paper): repeated tokens are worth
        # less than fresh ones, so add a small penalty that grows with the
        # number of epochs. log(1) = 0, so a single epoch is not penalized.
        penalty = 0.05 * np.log(epochs)
        adjusted_loss = loss + penalty
        print(f"  Epochs: {epochs}, Effective D: {effective_D/1e12:.0f}T, "
              f"Loss: {loss:.3f}, Adjusted (repeat penalty): {adjusted_loss:.3f}")


# --- Demo ---
if __name__ == "__main__":
    compare_models()
    scaling_curve([1e18, 1e19, 1e20, 1e21, 1e22, 1e23, 1e24])
    data_constrained_analysis()

    # Visualize the trade-off
    print("\n--- Fixed Compute Trade-off (C = 6e21 FLOPs) ---")
    C = 6e21
    print(f"{'N (params)':>15} {'D (tokens)':>15} {'Loss':>8}")
    print("-" * 40)
    for N in [1e9, 5e9, 10e9, 50e9, 100e9, 500e9]:
        D = C / (6 * N)
        if D > 0:
            loss = chinchilla_loss(N, D)
            print(f"{N/1e9:>13.1f}B {D/1e9:>13.1f}B {loss:>8.3f}")
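The numerical search in `optimal_allocation` can be cross-checked against the closed-form solution of the same constrained problem. The sketch below assumes the same fitted constants as the code above; `closed_form_optimum` and `parametric_loss` are names introduced here for illustration.

```python
def parametric_loss(N, D, A=406.4, alpha=0.34, B=410.7, beta=0.28, L_inf=1.69):
    """Same functional form as chinchilla_loss above."""
    return A / N**alpha + B / D**beta + L_inf


def closed_form_optimum(C, A=406.4, alpha=0.34, B=410.7, beta=0.28):
    """Analytic minimizer of A/N^alpha + B/D^beta subject to 6*N*D = C."""
    K = C / 6.0
    G = (alpha * A / (beta * B)) ** (1.0 / (alpha + beta))
    N_opt = G * K ** (beta / (alpha + beta))
    D_opt = K ** (alpha / (alpha + beta)) / G
    return N_opt, D_opt


for C in [1e21, 1e23, 1e25]:
    N_opt, D_opt = closed_form_optimum(C)
    print(f"C={C:.0e}  N*={N_opt/1e9:.2f}B  D*={D_opt/1e9:.1f}B  "
          f"tokens/param={D_opt/N_opt:.1f}")
```

With these particular constants the implied tokens-per-parameter ratio is not pinned at 20: it drifts upward with compute because the data exponent \(b\) is slightly larger than \(a\). The flat 20× figure is therefore best read as a rule of thumb rather than something this specific fit reproduces exactly.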
Interview Importance
Chinchilla is a must-know paper. "Chinchilla-optimal" comes up constantly when discussing model training decisions.
Difficulty Level: ⭐⭐⭐ (Medium)
Interview Questions & Answers
Q1: State the Chinchilla insight in one sentence.
Answer: For a fixed compute budget, you should train a smaller model on more data rather than a larger model on less data, because parameters and training tokens should scale roughly equally with compute (approximately 20 tokens per parameter for optimal training).
Q2: What assumptions break scaling laws?
Answer:
1. Data saturation: When you run out of unique, high-quality data and must repeat epochs, the effective value of additional tokens diminishes
2. Repeated epochs: Scaling laws assume each token is seen once; repeated data provides diminishing returns
3. Data quality variation: Scaling laws assume uniform data quality; in practice, scraping more data means lower quality
4. Task-specific evaluation: Scaling laws predict average loss, but specific task performance may not follow smooth trends
5. Distribution shift: Training on web text may not improve performance on specialized domains proportionally
Q3: How would you decide whether to increase data vs. model width under a fixed budget?
Answer:
1. Compute the current tokens-per-parameter ratio
2. If ratio < 20: You're undertrained → invest in more data
3. If ratio > 20: You might benefit from a larger model
4. Check data availability: If you've exhausted high-quality data, a larger model on the same data may still help (but with diminishing returns)
5. Consider inference cost: A smaller, well-trained model is cheaper to serve than a larger undertrained one
6. Run small-scale experiments: Train models at 1/100th scale with different N/D ratios and extrapolate
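The first three steps of this checklist can be sketched as a small helper. `allocation_hint` is a hypothetical function, and the 50% tolerance band around 20 tokens per parameter is an assumption chosen for illustration, not a threshold from the paper.

```python
def allocation_hint(n_params, n_tokens, target_ratio=20.0, tolerance=0.5):
    """Return (tokens-per-parameter ratio, coarse recommendation)."""
    ratio = n_tokens / n_params
    if ratio < target_ratio * (1 - tolerance):
        return ratio, "undertrained: prioritize more tokens"
    if ratio > target_ratio * (1 + tolerance):
        return ratio, "past the rule of thumb: a larger model may help"
    return ratio, "near compute-optimal"


for name, n, d in [("GPT-3", 175e9, 300e9),
                   ("Chinchilla", 70e9, 1.4e12),
                   ("LLaMA-7B", 7e9, 1.0e12)]:
    ratio, hint = allocation_hint(n, d)
    print(f"{name:<12} {ratio:>7.1f} tokens/param -> {hint}")
```

A ratio well above 20 is not necessarily a mistake: as step 5 notes, deliberately training a small model on extra data can be the right call when serving cost matters more than training cost.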
Q4: How does Chinchilla relate to LLaMA's training strategy?
Answer: LLaMA directly applied Chinchilla's insights:
- LLaMA-7B was trained on 1T tokens (~143 tokens per parameter, far beyond Chinchilla-optimal)
- LLaMA-65B was trained on 1.4T tokens (~21.5 tokens per parameter, close to optimal)
- The result: LLaMA-13B matched GPT-3 (175B) on many benchmarks because it was properly trained despite being 13× smaller
- LLaMA showed you can "overtrain" small models (go beyond 20 tokens/param) if you want smaller, cheaper-to-serve models
Connections to Other Papers
- GPT-3 → Chinchilla showed GPT-3 was undertrained
- PaLM → Also likely undertrained (540B params, ~780B tokens)
- LLaMA → Applied Chinchilla insights for efficient open models
- Mistral → Pushed even further: 7B model trained on massive data
Key Takeaways for Quick Review
| Concept | Remember |
|---|---|
| Core insight | Smaller models + more data > larger models + less data |
| Optimal ratio | ~20 tokens per parameter |
| Loss formula | \(L(N,D) = A/N^{\alpha} + B/D^{\beta} + L_{\infty}\) |
| FLOPs rule | C ≈ 6ND |
| Equal scaling | N and D should scale equally with compute |
| Chinchilla model | 70B params on 1.4T tokens beat 280B Gopher |
| Practical impact | Shaped LLaMA, Mistral, and open-model training |