kempnerforge.model.generate¶
Autoregressive text generation with KV-cache.
Supports greedy decoding, top-k, top-p (nucleus) sampling, and temperature scaling. Designed for single-GPU research/debug use, not production serving.
Functions

sample(logits[, temperature, top_k, top_p]) – Sample next token from logits.
generate(model, prompt_tokens, max_new_tokens, *[, temperature, top_k, top_p, eos_token_id]) – Generate tokens autoregressively with KV-cache.
- kempnerforge.model.generate.sample(logits, temperature=1.0, top_k=0, top_p=1.0)[source]¶
Sample next token from logits.
Applies temperature scaling, top-k filtering, and nucleus (top-p) sampling, in that order. Can be called independently for custom generation loops; see the usage sketch after this entry.
- Parameters:
logits (torch.Tensor) – (batch, vocab_size) unnormalized log-probabilities.
temperature (float) – Sampling temperature. 0 = greedy.
top_k (int) – Keep only top-k tokens. 0 = no filtering.
top_p (float) – Nucleus sampling threshold. 1.0 = no filtering.
- Returns:
Token ids, shape (batch,).
- Return type:
torch.Tensor
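Example (a minimal sketch for a custom loop; the batch size, vocabulary size, and sampling settings below are illustrative, not defaults of this module):

import torch
from kempnerforge.model.generate import sample

# Hypothetical logits for a batch of 2 sequences over a 32k-token vocabulary.
logits = torch.randn(2, 32_000)

# Greedy decoding: temperature=0 selects the argmax token per sequence.
greedy_ids = sample(logits, temperature=0)

# Temperature 0.8, keep the 50 highest-probability tokens, then apply
# nucleus filtering at p=0.9 before sampling.
sampled_ids = sample(logits, temperature=0.8, top_k=50, top_p=0.9)

print(greedy_ids.shape)   # torch.Size([2])
print(sampled_ids.shape)  # torch.Size([2])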
- kempnerforge.model.generate.generate(model, prompt_tokens, max_new_tokens, *, temperature=1.0, top_k=0, top_p=1.0, eos_token_id=None)¶
Generate tokens autoregressively with KV-cache.
- Parameters:
model (Transformer) – Transformer model (set to eval mode during generation).
prompt_tokens (torch.Tensor) – Input token ids, shape (batch, prompt_len).
max_new_tokens (int) – Maximum number of new tokens to generate.
temperature (float) – Sampling temperature. 0 = greedy decoding.
top_k (int) – Top-k filtering. 0 = disabled.
top_p (float) – Nucleus sampling threshold. 1.0 = disabled.
eos_token_id (int | None) – Stop early once every sequence in the batch has produced this token; None disables early stopping.
- Returns:
Full sequence (prompt + generated), shape (batch, total_len).
- Return type:
torch.Tensor
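Example (a minimal sketch; the prompt ids, sampling settings, and eos id of 2 are placeholders, and the model is assumed to be an already-constructed, already-loaded Transformer):

import torch
from kempnerforge.model.generate import generate

def complete(model, prompt_tokens: torch.Tensor) -> torch.Tensor:
    # Sample up to 64 new tokens; generation stops early once every
    # sequence in the batch has emitted the (placeholder) eos id 2.
    return generate(
        model,
        prompt_tokens,
        max_new_tokens=64,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        eos_token_id=2,
    )

# Hypothetical prompt: one sequence of four token ids.
# out = complete(model, torch.tensor([[1, 15, 942, 7]]))
# out.shape -> (1, 4 + n_generated), with n_generated <= 64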