Generate from a checkpoint¶
A DCP checkpoint + a tokenizer + a prompt → scripts/generate.py produces text. This page covers the CLI, the underlying generate() call, and the KV-cache mechanics so you can build your own generation loops when the script isn't enough.
Quick start¶
uv run python scripts/generate.py configs/train/7b.toml \
--checkpoint.load_path=checkpoints/7b/step_50000 \
--data.tokenizer_path=meta-llama/Llama-2-7b-hf \
--prompt "Once upon a time" \
--max_tokens 256 \
--temperature 0.8 \
--top_p 0.9
Model: 6738M params on cuda:0 (torch.bfloat16)
Loaded checkpoint: checkpoints/7b/step_50000
--- Prompt ---
Once upon a time
--- Generated (256 tokens) ---
, in a small village nestled between two mountains, there lived …
CLI reference¶
| Flag | Default | Purpose |
|---|---|---|
| config (positional) | — | TOML config (required — used for model architecture) |
| --checkpoint.load_path | from config | DCP directory to load (step_N or …) |
| --data.tokenizer_path | from config | HuggingFace hub ID or local dir |
| --prompt | | Input text |
| --max_tokens | | Max new tokens to generate |
| --temperature | | Sampling temperature (0 = greedy) |
| --top_k | | Top-k filtering (0 = disabled) |
| --top_p | | Nucleus sampling threshold (1.0 = disabled) |
| --interactive | | REPL mode — enter prompts, see output |
| | | Target device |
| | | Parameter dtype (also: …) |
The script uses argparse.parse_known_args() — flags it recognizes
(the table above) are consumed as script arguments, and everything
else (e.g. --checkpoint.load_path=..., --model.dim=...) is fed to
load_config as a config override, just like scripts/train.py.
Positional config must come first; the script flags and config
overrides can otherwise appear in any order.
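For intuition, a minimal sketch of the same pattern — the flag names mirror the table above, but the code is illustrative, not the script's actual source:

import argparse

# Hedged sketch of the flag-splitting pattern used by scripts/generate.py.
parser = argparse.ArgumentParser()
parser.add_argument("config")                     # positional TOML path — must come first
parser.add_argument("--prompt")
parser.add_argument("--max_tokens", type=int)
args, overrides = parser.parse_known_args()

# `args` holds the script flags; `overrides` keeps everything argparse didn't
# recognize, e.g. ["--checkpoint.load_path=...", "--model.dim=..."], which is
# then handed to load_config as dotted-key config overrides.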
Interactive mode¶
uv run python scripts/generate.py configs/train/7b.toml \
--checkpoint.load_path=checkpoints/7b/step_50000 \
--data.tokenizer_path=meta-llama/Llama-2-7b-hf \
--interactive
A small REPL loop: enter a prompt, see generation, loop. Each call re-runs prefill from scratch — there’s no conversation state between prompts. Handy for qualitative sanity checks while training is ongoing.
How it loads a DCP checkpoint without dist.init_process_group¶
DCP files are multi-rank by default, but
torch.distributed.checkpoint.load supports single-process loading
without initializing a process group:
# scripts/generate.py (excerpt)
import torch.distributed.checkpoint as dcp

# Load into a state dict keyed the same way the training run saved it.
state_dict = {"model": model.state_dict()}
dcp.load(state_dict, checkpoint_id=str(ckpt_path))   # no process group needed
model.load_state_dict(state_dict["model"])
This is why generate.py is single-GPU even for models that were
trained with FSDP=N and TP=M — DCP handles the resharding on the read
side, loading the full unsharded model onto one device. For a 70B
model that means ~140 GB of bf16 parameters — run on a node with
enough memory, or use FSDP via a multi-GPU loader (not covered by
scripts/generate.py).
The generate() function¶
from kempnerforge.model.generate import generate
output_ids = generate(
    model,
    prompt_tokens,          # (batch, prompt_len)
    max_new_tokens,
    temperature=1.0,        # keyword-only arguments from here on
    top_k=0,
    top_p=1.0,
    eos_token_id=None,      # stop once every batch entry has emitted this token
)                           # returns (batch, prompt_len + generated_len)
generate() runs under @torch.no_grad(), flips the model to .eval(), and restores the previous training mode on exit. It does not call dist.init_process_group — it works on raw tensors, so you can use it both single-GPU (as scripts/generate.py does) and from within a training script after FSDP has summoned the full parameters (more complex).
Sampling details¶
Temperature → top-k → top-p → multinomial. All in
sample():
if temperature == 0:
    return logits.argmax(dim=-1)                              # greedy
logits = logits / temperature
if top_k > 0:
    threshold = logits.topk(top_k, dim=-1).values[:, -1:]
    logits = logits.masked_fill(logits < threshold, float("-inf"))
if top_p < 1.0:
    sorted_logits, idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    # drop tokens whose preceding cumulative probability already reaches top_p
    mask = (probs.cumsum(dim=-1) - probs) >= top_p
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    logits = logits.scatter(-1, idx, sorted_logits)           # back to original order
probs = logits.softmax(dim=-1)
return torch.multinomial(probs, 1).squeeze(-1)
sample() is exported, so you can write your own decoding loop — reusing sample() as-is or swapping in a custom sampler (contrastive, typical, etc.) — while keeping the model + KV-cache wiring.
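For example (a hedged usage sketch — the keyword names mirror the snippet above, so double-check the actual signature in kempnerforge.model.generate):

import torch
from kempnerforge.model.generate import sample

logits = torch.randn(2, 32_000)        # (batch, vocab) — last-position logits
next_ids = sample(logits, temperature=0.8, top_k=50, top_p=0.9)
print(next_ids.shape)                  # torch.Size([2]) — one sampled token id per row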
KV cache¶
Without a cache, every new token re-runs the forward pass over the entire growing sequence — positions get recomputed 1, 2, …, N times, O(N²) token passes for work that should be O(N). The KVCache stores per-layer keys and values so each new token needs only one attention pass over the cached history.
Layout¶
KVCache(
batch_size,
max_seq_len, # pre-allocated up to this length
n_kv_heads, # GQA key heads (not query heads)
head_dim,
dtype,
device,
)
Two pre-allocated tensors of shape (batch, n_kv_heads, max_seq_len, head_dim). Keys are stored after RoPE but before GQA expansion — this shrinks the cache by a factor of n_heads / n_kv_heads (4× for the default Llama-style config).
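For a rough sense of the footprint, here is a back-of-the-envelope helper — the config numbers are made-up Llama-style values for illustration, not a specific model's:

def kv_cache_bytes(batch, n_layers, n_kv_heads, head_dim, max_seq_len, bytes_per_el=2):
    # Two buffers (K and V) per layer, each (batch, n_kv_heads, max_seq_len, head_dim).
    return 2 * n_layers * batch * n_kv_heads * max_seq_len * head_dim * bytes_per_el

# Illustrative: 32 layers, 8 KV heads, head_dim 128, 4k context, bf16, batch 1.
print(kv_cache_bytes(1, 32, 8, 128, 4096) / 2**20)   # 512.0 MiB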
Update¶
def update(self, k_new, v_new):
    """Append k_new/v_new at position seq_len; return views of the filled prefix (k_all, v_all)."""
    end = self.seq_len + k_new.shape[2]
    self.k[:, :, self.seq_len:end] = k_new
    self.v[:, :, self.seq_len:end] = v_new
    self.seq_len = end
    return self.k[:, :, :end], self.v[:, :, :end]
update returns slices, not copies — the returned tensors alias
into the pre-allocated buffer. Safe inside @torch.no_grad().
Prefill vs decode¶
generate() allocates one KVCache per transformer layer, then:
Prefill: one forward pass with the full prompt. Each layer's attention calls kv_cache.update(k, v) with the prompt's K/V tensors, filling positions [0, prompt_len).
Decode loop: max_new_tokens steps, each a forward pass on a single-token input. Attention's k_new, v_new are shape (batch, n_kv_heads, 1, head_dim); update appends at position seq_len and advances it.
If prompt_len + max_new_tokens > model.config.max_seq_len the call
raises — the cache is pre-allocated up to that bound.
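In outline, the two phases look like the sketch below. The way the forward pass receives the caches (a kv_caches keyword here) is a hypothetical placeholder for illustration — see the actual implementation for the real wiring:

# Hedged sketch of the prefill/decode structure inside generate().
caches = [KVCache(batch, max_seq_len, n_kv_heads, head_dim, dtype, device)
          for _ in range(n_layers)]

logits = model(prompt_tokens, kv_caches=caches)           # prefill: fills [0, prompt_len)
next_tok = sample(logits[:, -1], temperature, top_k, top_p)

generated = [next_tok]
for _ in range(max_new_tokens - 1):
    logits = model(next_tok[:, None], kv_caches=caches)   # decode: single-token input
    next_tok = sample(logits[:, -1], temperature, top_k, top_p)
    generated.append(next_tok)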
Batch generation from Python¶
scripts/generate.py tokenizes one prompt. For many prompts at
once, call generate() directly:
import torch
from transformers import AutoTokenizer
from kempnerforge.model.generate import generate

# `model` is assumed to be built and loaded as in the DCP section above.
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
prompts = ["Hello, my name is", "The best way to learn Python is"]

# Pad to the same length (left-pad if you want all prompts to end
# together; right-pad is fine if you just want them to start together).
tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(prompts, return_tensors="pt", padding=True).input_ids
batch = batch.to(device)

output = generate(model, batch, max_new_tokens=50, temperature=0.8, top_p=0.9)
# output: (batch=2, prompt_len + 50)
for row in output:
    print(tokenizer.decode(row, skip_special_tokens=True))
Two caveats:
Left-pad or right-pad consistently. The KV cache assumes position 0 is the start of every sequence in the batch, so if you right-pad, generation will treat the pad tokens as a real prefix. For Llama-style tokenizers this is usually fine (pad is excluded from attention via the mask), but for greedy comparison runs, left-pad so the end positions line up.
eos_token_id stops generation only once every sequence in the batch has emitted EOS. Sequences that hit EOS early keep generating garbage until the last row catches up; filter post-hoc (see the sketch below).
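Building on the batch example above, a post-hoc filter can simply cut each row at its own first EOS within the generated region:

# Sketch: decode only up to each row's first EOS after the prompt.
prompt_len = batch.shape[1]
eos = tokenizer.eos_token_id

for row in output:
    gen = row[prompt_len:].tolist()
    if eos in gen:
        gen = gen[:gen.index(eos)]
    print(tokenizer.decode(gen, skip_special_tokens=True))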
What scripts/generate.py doesn’t do¶
No beam search — only sampling or argmax-greedy.
No distributed inference — single-GPU only. For 70B+ you need to write your own FSDP-wrapped inference loop.
No cache reuse across calls in interactive mode — each prompt re-prefills from scratch.
No streaming output — it generates all max_new_tokens before printing. For streaming, loop sample() yourself and print after each token (see the sketch below).
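A minimal streaming loop, as a hedged sketch: it reuses model, tokenizer, and device from the batch example above, re-runs the full forward pass every step instead of using the KV cache, and assumes model(ids) returns logits of shape (batch, seq, vocab) and that sample() takes the keyword arguments shown earlier:

import torch
from kempnerforge.model.generate import sample

ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to(device)

with torch.no_grad():
    for _ in range(64):
        logits = model(ids)                          # (1, seq, vocab) — assumed shape
        next_id = sample(logits[:, -1], temperature=0.8, top_k=0, top_p=0.9)
        ids = torch.cat([ids, next_id[:, None]], dim=-1)
        print(tokenizer.decode(next_id[0].item()), end="", flush=True)
print()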
See also¶
Training § Generation — the generate() internals reference.
Training § Generation § KV cache — KVCache class docs.
End-to-end training run § Generate from the checkpoint — the quickstart version of this page.
Checkpointing § DCP model format — why single-process DCP load works without init_process_group.
scripts/generate.py — the script this page documents.