How much model, how much data, for your compute?

Think first
You have $10 million to train a language model. You can spend it on a bigger model, on more training data, or on longer training. If you double your budget, should you double the model size, double the data, or something else entirely?

For years the industry believed bigger models were the answer. OpenAI's 2020 scaling laws (Kaplan et al.) said: 10x more compute should buy ~5.5x more parameters and only ~1.8x more data. This reasoning produced GPT-3 (175B parameters trained on ~300B tokens).

In 2022 DeepMind's Chinchilla paper (Hoffmann et al.) showed this was wrong. By training 400+ models at different (size, data) trade-offs, they found compute-optimal training requires scaling parameters and data roughly equally. 10x more compute -> ~sqrt(10) ~ 3.16x more parameters AND ~3.16x more tokens.

Key Insight

Chinchilla (70B params, 1.4T tokens) matched the much larger Gopher (280B params, 300B tokens) while using the same total compute. GPT-3 was dramatically undertrained -- it had enough capacity for much more data than it ever saw.

The Chinchilla Law: C ~ 6 N D

The core scaling law models loss as a sum of power laws in model size N and dataset size D: L(N, D) = E + A / N^a + B / D^b, an irreducible entropy term plus one penalty for finite capacity and one for finite data. To get the best possible loss at a fixed compute budget C, pick N and D so that the two penalty terms stay in balance, with neither dominating the loss floor.

Compute (FLOPs) ~ 6 * N * D
    where N = parameters, D = training tokens

Chinchilla recipe:  D ~ 20 * N     (20 tokens per parameter is the sweet spot)

This is why every modern lab now quotes its models in "tokens per parameter". GPT-3 had ~1.7 tokens/param (way undertrained). Chinchilla had 20 (compute-optimal). LLaMA 2 7B was trained on ~285 tokens/param (heavily overtrained on purpose -- see below).

Common misconception

"Chinchilla-optimal means you should always train at 20 tokens per parameter." No. Chinchilla-optimal minimizes training cost for a given loss. If inference cost matters (it usually does in production), you should deliberately overtrain a smaller model. Meta's LLaMA series pushes this to the extreme for deployability.

Interactive: Chinchilla Compute Budget Calculator


Choose a compute budget in FLOPs. The calculator returns the Chinchilla-optimal model size and number of training tokens.

Example: a budget of 1.0e22 total FLOPs yields roughly 9.1B optimal parameters (N) and roughly 183B optimal training tokens (D), i.e., 20 tokens per parameter, under the C ~ 6 N D recipe.
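The calculator's logic reduces to a few lines under the simple approximations C ~ 6 N D and D ~ 20 N from above (the paper's fitted coefficients give slightly different numbers):

```python
import math

def chinchilla_optimal(flops: float) -> tuple[float, float]:
    """Compute-optimal (params, tokens) under C = 6*N*D with D = 20*N.

    Substituting D = 20*N into C = 6*N*D gives C = 120*N**2,
    so N = sqrt(C / 120) and D = 20*N.
    """
    n = math.sqrt(flops / 120)
    d = 20 * n
    return n, d

n, d = chinchilla_optimal(1.0e22)
print(f"N ~ {n / 1e9:.1f}B params, D ~ {d / 1e9:.0f}B tokens")
# prints: N ~ 9.1B params, D ~ 183B tokens
```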

Post-Chinchilla: The Overtraining Era

Chinchilla-optimal minimizes training cost. But in production, inference cost dominates -- you pay the model's running cost for every user query, potentially billions of times. Serving a 70B model at 20 tokens/param is vastly more expensive than serving a 7B model at 285 tokens/param, even if the 70B has slightly lower training loss.

Meta's LLaMA family embraced this trade-off. LLaMA 3 8B was trained on ~15 trillion tokens, about 1,875 tokens per parameter and almost 100x the Chinchilla recipe. Training cost went up, but inference cost dropped through the floor.

Real-world rule of thumb

If you'll serve this model a million times, a 10x increase in training cost is worth it if it yields even a 10% reduction in inference cost. That's why small "overtrained" models dominate production today.
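The arithmetic behind this rule of thumb can be sketched with hypothetical numbers (the dollar figures below are made up purely for illustration, not real training or serving costs):

```python
def total_cost(train_cost: float, cost_per_query: float, queries: float) -> float:
    """Lifetime cost = one-time training cost + per-query inference cost."""
    return train_cost + cost_per_query * queries

# Hypothetical comparison at 1 billion queries:
# a big Chinchilla-optimal model vs a smaller, overtrained one.
big = total_cost(train_cost=2e6, cost_per_query=0.010, queries=1e9)    # $12.0M
small = total_cost(train_cost=4e6, cost_per_query=0.001, queries=1e9)  # $5.0M
```

The small model's 2x training bill is dwarfed by the inference savings once query volume is high enough.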

The Optimizer Lineage: SGD -> AdamW

Training a 100B-parameter model doesn't work with vanilla gradient descent. Modern LLMs all use variants of AdamW, the end point of a short lineage: SGD with momentum keeps a running average of gradients; RMSProp keeps a running average of squared gradients for per-parameter step sizes; Adam combines both; and AdamW decouples weight decay from the adaptive step.

Interactive: AdamW State Visualizer

Simulate a few optimization steps and watch the first moment (m) and second moment (v) evolve.

# Standard AdamW hyperparameters for LLM training
beta1, beta2  = 0.9, 0.999    # exponential decay rates for m and v
epsilon       = 1e-8          # numerical stability
weight_decay  = 0.1           # decoupled, applied directly to weights
learning_rate = 3e-4          # peak LR; linear warmup, then cosine decay
gradient_clip = 1.0           # clip by global norm
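As a minimal sketch of what these hyperparameters do, here is one AdamW update on a single scalar parameter (real implementations vectorize this over every tensor in the model):

```python
import math

def adamw_step(w, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.1):
    """One AdamW update on a scalar parameter w with gradient g.

    m and v are the running first/second moments; t is the 1-indexed
    step count used for bias correction.
    """
    m = beta1 * m + (1 - beta1) * g          # first moment: gradient average
    v = beta2 * v + (1 - beta2) * g * g      # second moment: squared-gradient average
    m_hat = m / (1 - beta1 ** t)             # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)  # adaptive step
    w = w - lr * weight_decay * w            # decoupled weight decay (the "W")
    return w, m, v
```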

Learning Rate Schedules Revisited

Module 7 introduced warmup + cosine decay. Here we go deeper. The formulas:

# Linear warmup, steps 0..warmup_steps
lr(t) = peak_lr * (t / warmup_steps)

# Cosine decay, steps warmup_steps..total_steps
progress = (t - warmup_steps) / (total_steps - warmup_steps)
lr(t) = min_lr + 0.5 * (peak_lr - min_lr) * (1 + cos(pi * progress))
Interactive: LR Schedule Plotter (Warmup + Cosine)
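A runnable version of the formulas above; the values for warmup_steps, total_steps, and min_lr are illustrative defaults, not from any specific training run:

```python
import math

def lr_at(t, peak_lr=3e-4, min_lr=3e-5, warmup_steps=2000, total_steps=100_000):
    """Linear warmup to peak_lr, then cosine decay down to min_lr."""
    if t < warmup_steps:
        return peak_lr * t / warmup_steps            # linear ramp from 0
    progress = (t - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

At t=0 the LR is 0, at the end of warmup it hits the peak, and it lands exactly on min_lr at the final step.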

LR rules of thumb for LLMs:

Bigger models want smaller peak LRs: their loss landscapes are sharper, so large steps destabilize training. GPT-3, for example, used a peak LR of 6e-4 for its 125M-parameter variant but only 6e-5 at 175B.

Gradient Clipping: The Safety Net

Occasionally a batch contains a weird example that produces an enormous gradient. Without intervention, one bad batch can blow up the entire run. Gradient clipping rescales the gradient whenever its norm exceeds a threshold (typically 1.0):

g_norm = sqrt(sum of squares of all gradient elements)
if g_norm > max_norm:
    g = g * (max_norm / g_norm)   # preserves direction, caps magnitude
Interactive: Gradient Clipping Demo

Example readout: a raw gradient norm of 0.80 is below the max of 1.0, so the gradient passes through unchanged (status OK).
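The clipping rule above as a minimal self-contained function (the name clip_by_global_norm is mine; most frameworks ship an equivalent utility):

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale the gradient vector if its global L2 norm exceeds max_norm.

    Returns the (possibly rescaled) gradients and the raw norm.
    Direction is preserved; only the magnitude is capped.
    """
    g_norm = math.sqrt(sum(g * g for g in grads))
    if g_norm > max_norm:
        scale = max_norm / g_norm
        return [g * scale for g in grads], g_norm
    return list(grads), g_norm
```

A gradient of [3.0, 4.0] has norm 5.0 and gets rescaled to [0.6, 0.8] (norm 1.0); a gradient of [0.3, 0.4] (norm 0.5) passes through untouched.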

Precision: FP32, FP16, BF16

Modern training uses mixed precision: most math in low precision (16 bit) for speed, but critical values (master weights, loss accumulation) in FP32.

A 70B parameter model in BF16 takes 140 GB just for weights. Add gradients (140 GB) and AdamW optimizer states (560 GB in FP32 for momentum + variance) and you are already at ~840 GB before you include activations or a training batch. This is why parallelism across many GPUs is mandatory (covered in Module 14).
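The memory arithmetic above can be reproduced with a small helper (assuming BF16 weights and gradients at 2 bytes each, FP32 AdamW moments at 4 bytes each, and 1 GB = 1e9 bytes to match the figures quoted):

```python
def training_memory_gb(params_billion: float) -> dict:
    """Rough per-component memory for mixed-precision AdamW training.

    Excludes activations, the training batch, and any FP32 master
    copy of the weights, so this is a lower bound.
    """
    n = params_billion * 1e9
    gb = 1e9  # 1 GB = 1e9 bytes here
    return {
        "weights_bf16": 2 * n / gb,       # 2 bytes per parameter
        "grads_bf16":   2 * n / gb,       # 2 bytes per parameter
        "adamw_fp32":   (4 + 4) * n / gb, # m + v, 4 bytes each in FP32
    }

mem = training_memory_gb(70)
# weights 140 GB + grads 140 GB + optimizer states 560 GB = 840 GB
```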

Batch Size and Training Dynamics

LLM batch sizes are measured in tokens (not sequences). Typical: 2-4 million tokens per update.
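Concretely, a target of ~4M tokens per update with a 4096-token context (an illustrative choice, not from a specific model) works out to about a thousand sequences per optimizer step:

```python
# A "4M-token batch" in practice: tokens per update = sequences * seq_len.
seq_len = 4096                 # context length (illustrative)
target_tokens = 4_000_000      # desired tokens per optimizer step
sequences_per_batch = target_tokens // seq_len
tokens_per_update = sequences_per_batch * seq_len
# -> 976 sequences of 4096 tokens ~ 4.0M tokens per update
```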

What training runs actually look like

OPT-175B's training logbook (publicly released by Meta) documents 2 months of training with dozens of crashes, loss spikes, hardware failures, and manual rollbacks to previous checkpoints. BLOOM-176B had similar experiences. Frontier training is less "run a script" and more "firefight in slow motion."

Check Your Understanding

1. What does the Chinchilla paper say about compute-optimal training?
Answer: Model size and training tokens should scale roughly equally (about 20 tokens per parameter).
2. Why do modern labs overtrain smaller models (e.g., LLaMA 3 at ~1,875 tokens/param)?
Answer: Because inference cost dominates production cost, and smaller well-trained models are vastly cheaper to serve.
3. What does AdamW fix compared to vanilla Adam?
Answer: It decouples weight decay from the adaptive learning rate.
4. Why use BF16 instead of FP16 for LLM training?
Answer: BF16 has the same exponent range as FP32, so there is no underflow and no loss scaling needed.
5. What is gradient clipping and why use it?
Answer: Rescaling the gradient whenever its global norm exceeds a threshold, to prevent a single bad batch from blowing up the run.

Teach It Back

Explain to a friend: What do scaling laws tell us about how to allocate compute between model size and data? Why does everyone use AdamW with warmup + cosine decay and gradient clipping? And why do production models like LLaMA 3 deliberately overtrain well past Chinchilla-optimal?

An AI tutor will compare your explanation against the course material.


Flashcards (click to flip)

What is the Chinchilla scaling law in one sentence?
For compute-optimal training, model parameters N and training tokens D should scale roughly equally, with about D ~ 20 N. 10x more compute buys ~3.16x more params AND ~3.16x more data.
Why is GPT-3 considered "undertrained"?
It saw only ~1.7 tokens per parameter (300B tokens on 175B params). Chinchilla would recommend ~3.5T tokens for a 175B model. GPT-3 had far more capacity than its training data could fill.
Chinchilla-optimal vs overtrained: which is actually better?
Depends on inference volume. Chinchilla minimizes training cost. But if you serve the model a billion times, a smaller overtrained model wins on total cost. That is why LLaMA 3 8B was trained on ~15T tokens.
Standard AdamW hyperparameters for LLMs?
beta1=0.9, beta2=0.999 (momentum and RMSProp decay rates), eps=1e-8, weight_decay=0.1, LR=3e-4 with warmup+cosine, gradient clip=1.0, BF16 precision.
FP32 vs FP16 vs BF16?
FP32: 32 bits, safe but slow. FP16: 16 bits, narrow range, needs loss scaling. BF16: 16 bits with FP32-sized exponent -- no underflow, no loss scaling, now standard.
What is gradient clipping?
Rescale the gradient vector whenever its global L2 norm exceeds a threshold (typically 1.0). Preserves direction, caps magnitude. Protects against a single bad batch destabilizing the run.