Generation¶
Autoregressive decoding for research and debugging, implemented in
kempnerforge/model/generate.py.
Single-GPU only — not a production serving path.
sample¶
from kempnerforge.model.generate import sample
next_token = sample(logits, temperature=1.0, top_k=0, top_p=1.0)
# (batch, vocab_size) -> (batch,)
Applies, in order:
1. temperature == 0 → logits.argmax(dim=-1) (greedy; short-circuit, none of the filters below run).
2. Temperature scaling → logits / temperature.
3. Top-k filtering → keep the top_k largest values per batch row, mask the rest with -inf. top_k=0 disables.
4. Top-p (nucleus) filtering → sort descending, keep the smallest prefix whose probabilities sum to < top_p. top_p=1.0 disables.
5. Sample → torch.multinomial(probs, num_samples=1).squeeze(-1).
The order matters: top-k and top-p both operate on temperature-scaled logits, not raw ones.
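As a reference for that ordering, here is a minimal sketch of the same pipeline in plain PyTorch. It is illustrative only, not the kempnerforge implementation; sample_sketch and its exact masking details are stand-ins.

import torch

def sample_sketch(logits, temperature=1.0, top_k=0, top_p=1.0):
    # Greedy short-circuit: no filtering at temperature == 0.
    if temperature == 0:
        return logits.argmax(dim=-1)
    logits = logits / temperature                      # temperature scaling
    if top_k > 0:                                      # top-k: mask everything below the k-th largest logit
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:                                    # top-p: mask the tail of the sorted distribution
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop token i once the probability mass *before* it already reaches top_p,
        # so the highest-probability token is always kept.
        drop = (sorted_probs.cumsum(dim=-1) - sorted_probs) >= top_p
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)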
You can call sample() standalone for custom decode loops — it’s the
one-shot primitive behind generate().
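For instance, a bare-bones decode loop on top of sample() could look like the sketch below. The model call convention assumed here (token ids in, (batch, seq_len, vocab_size) logits out, no KV caches) is an assumption, so adapt it to your forward pass.

import torch
from kempnerforge.model.generate import sample

@torch.no_grad()
def simple_decode(model, tokens, steps=32):
    # Assumes model(tokens) returns (batch, seq_len, vocab_size) logits without KV caches,
    # so every step re-runs the full prefix (slow, but fine for quick experiments).
    for _ in range(steps):
        logits = model(tokens)[:, -1, :]                       # last-position logits
        next_token = sample(logits, temperature=0.8, top_k=50, top_p=0.95)
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
    return tokens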
generate¶
from kempnerforge.model.generate import generate
@torch.no_grad()
def generate(
model: Transformer,
prompt_tokens: torch.Tensor, # (batch, prompt_len)
max_new_tokens: int,
*,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 1.0,
eos_token_id: int | None = None,
) -> torch.Tensor: # (batch, total_len)
What it does:
1. Saves was_training, switches to model.eval().
2. Validates prompt_len + max_new_tokens ≤ model.config.max_seq_len; raises ValueError otherwise.
3. Allocates one KVCache per transformer layer, sized for the full sequence (batch_size, total_len, n_kv_heads, head_dim), matching the model's parameter dtype.
4. Prefill: forwards the whole prompt through the model with the KV caches; grabs the last-position logits.
5. Decode loop: for max_new_tokens iterations:
   - sample() to pick the next token.
   - If eos_token_id is set, OR it into a per-row done mask; break early if done.all().
   - Forward the single sampled token (next_token.unsqueeze(1)) with the same KV caches to get the next logits.
6. Restores model.train() if was_training.
7. Returns torch.cat([prompt_tokens, generated], dim=1).
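Putting that together, a typical call looks like the snippet below; the prompt tensor is a stand-in for real tokenizer output, and model is assumed to be an already-loaded kempnerforge Transformer.

import torch
from kempnerforge.model.generate import generate

prompt_tokens = torch.tensor([[1, 3923, 374, 279]])  # stand-in token ids; (batch=1, prompt_len=4)

out = generate(
    model,                 # already-loaded Transformer; eval/train mode is handled internally
    prompt_tokens,
    max_new_tokens=32,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=None,     # set to your tokenizer's EOS id to allow early stop
)
# out is (batch, prompt_len + n_generated); the prompt comes back verbatim as a prefix.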
KV cache¶
Imported from
kempnerforge.model.attention.
generate() allocates one cache per layer and passes them through the
model via model(..., kv_caches=kv_caches).
Size: batch_size × total_len × n_kv_heads × head_dim × 2 (K+V) × dtype_bytes
per layer, times n_layers. For a 7B Llama-3 at batch=1, total_len=8192, in bf16 that's ~2 GB of KV cache: it easily fits on a single H200, but watch the budget for larger batches or longer contexts.
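The same arithmetic as a small helper; the function name is just for illustration, and the config values come from your own model:

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, batch_size, total_len, dtype_bytes=2):
    # Per layer: K and V, each (batch_size, total_len, n_kv_heads, head_dim), dtype_bytes per element.
    per_layer = batch_size * total_len * n_kv_heads * head_dim * 2 * dtype_bytes
    return per_layer * n_layers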
Stop criteria¶
Two:
- max_new_tokens: hard limit, always enforced.
- eos_token_id: optional. When provided, generation stops early only when every row in the batch has emitted EOS. A partial-batch early stop would require padding, which generate() doesn't implement. If that matters, decode batches of size 1 (as sketched below) or post-process the output yourself.
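A minimal sketch of the batch-of-1 workaround, assuming eos_id comes from your tokenizer and prompt_tokens is a (batch, prompt_len) tensor:

# Decode each prompt separately so EOS on one row stops only that row's loop.
outputs = [
    generate(model, row.unsqueeze(0), max_new_tokens=64, eos_token_id=eos_id)
    for row in prompt_tokens
]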
Standalone CLI: scripts/generate.py¶
uv run python scripts/generate.py configs/train/7b.toml \
--checkpoint.load_path /path/to/checkpoint \
--prompt "The capital of France is" \
--max_tokens 64 \
--temperature 0.7 \
--top_p 0.9
Interactive REPL:
uv run python scripts/generate.py configs/train/7b.toml \
--checkpoint.load_path /path/to/checkpoint \
--interactive
The CLI loads the model via DCP (not through the full distributed
stack — no FSDP2, no init_distributed), tokenizes the prompt with
the HF tokenizer specified in data.tokenizer_path, calls
generate(), decodes with skip_special_tokens=True, and prints the
generated suffix.
Args (beyond config / overrides):
| Flag | Default | Meaning |
|---|---|---|
| --prompt | (required unless --interactive) | Prompt string |
| --max_tokens | | Max new tokens |
| --temperature | | Sampling temperature; 0 = greedy |
| --top_k | | Top-k cutoff; 0 disables |
| --top_p | | Nucleus threshold; 1.0 disables |
| --interactive | | REPL mode |
| | auto | |
Limitations¶
- Single-GPU only: no FSDP / TP / PP model support in generate(). For distributed inference, export to HF and use vLLM or similar.
- No KV-cache reuse across calls: each generate() call allocates fresh caches. Good for research, wasteful for serving.
- No speculative decoding, no beam search, no repetition penalty: this is a minimal sampler for evaluating training progress, not a generation stack. If you need any of those, fork generate.py.
See also¶
- Training loop § Periodic work: training-time eval doesn't use generate(), only loss. Use generate() for qualitative spot-checks between runs.
- Checkpointing: how scripts/generate.py loads the DCP checkpoint it decodes from.
- Model § Attention paths: the KV-cache path vs the training path through attention.