kempnerforge.training.loss

Loss function registry for KempnerForge.

Registers loss functions and provides build_loss_fn() to compose them with config-driven options (chunk size, z-loss). Follows the same builder pattern as build_optimizer.
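
A typical call site, as a minimal sketch — the config object and its fields (chunk_size, z_loss_weight) are assumptions here, not the documented schema:

    import torch
    from types import SimpleNamespace

    from kempnerforge.training.loss import build_loss_fn

    # Hypothetical config object; the real training config may use
    # different field names or nesting.
    cfg = SimpleNamespace(chunk_size=4096, z_loss_weight=1e-4)
    loss_fn = build_loss_fn(cfg)  # config in, callable out

    logits = torch.randn(2, 8, 32, requires_grad=True)  # (batch, seq, vocab)
    labels = torch.randint(0, 32, (2, 8))
    labels[:, 0] = -100  # packed-sequence boundary tokens are ignored
    loss = loss_fn(logits, labels)
    loss.backward()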

Functions

build_loss_fn(config)

Build a composed loss function from training config.

chunked_cross_entropy_loss(logits, labels[, ...])

Cross-entropy computed in token-dimension chunks.

cross_entropy_loss(logits, labels)

Standard cross-entropy loss for language modeling.

z_loss(logits, weight)

Logit magnitude regularizer (PaLM / Gemini).

kempnerforge.training.loss.cross_entropy_loss(logits, labels)[source]

Standard cross-entropy loss for language modeling.

Uses ignore_index=-100 so packed-sequence boundary tokens (labeled -100 by the dataset) are excluded from the loss. When no -100 labels are present, this has zero overhead.

Parameters:
  • logits (torch.Tensor) – (batch, seq, vocab) or (tokens, vocab).

  • labels (torch.Tensor) – (batch, seq) or (tokens,).

Return type:

torch.Tensor
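
The behavior above amounts to the following sketch (the actual implementation may differ):

    import torch
    import torch.nn.functional as F

    def cross_entropy_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Flatten (batch, seq, vocab) -> (tokens, vocab); F.cross_entropy's
        # ignore_index=-100 drops boundary tokens from the mean.
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
        )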

kempnerforge.training.loss.chunked_cross_entropy_loss(logits, labels, chunk_size=4096)[source]

Cross-entropy computed in token-dimension chunks.

Chunks along the token dimension and uses PyTorch’s fused CE kernel per chunk, avoiding an explicit float32 materialization of the full logit tensor. For Llama-3 8B (vocab=128K, batch=4, seq=4096), the manual logsumexp path would create a ~8 GiB float32 copy (4 · 4096 · 131072 elements × 4 bytes); this implementation avoids that entirely.

Uses ignore_index=-100 so packed-sequence boundary tokens are excluded. When no -100 labels are present, this has zero overhead.

Note: the input logit tensor (B*S, V) is still fully materialized by the model’s output head before reaching this function. For deeper savings (never materializing the full logit tensor), the output projection itself must be chunked in the model forward pass — a future enhancement.

Parameters:
  • logits (torch.Tensor) – (batch, seq, vocab) or (tokens, vocab).

  • labels (torch.Tensor) – (batch, seq) or (tokens,).

  • chunk_size (int) – Number of tokens per chunk.

Return type:

torch.Tensor
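
A minimal sketch of the chunking strategy — accumulation dtype and empty-chunk handling are choices of this sketch, not necessarily the shipped code:

    import torch
    import torch.nn.functional as F

    def chunked_cross_entropy_loss(
        logits: torch.Tensor, labels: torch.Tensor, chunk_size: int = 4096
    ) -> torch.Tensor:
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        total = flat_logits.new_zeros((), dtype=torch.float32)
        n_valid = 0
        for start in range(0, flat_labels.numel(), chunk_size):
            chunk_logits = flat_logits[start : start + chunk_size]
            chunk_labels = flat_labels[start : start + chunk_size]
            valid = int((chunk_labels != -100).sum())
            if valid == 0:
                continue  # chunk contains only boundary tokens
            # reduction="sum" lets per-chunk losses combine into an exact
            # mean over valid tokens; the fused kernel upcasts internally,
            # so only one (chunk_size, vocab) slice is live at a time.
            total = total + F.cross_entropy(
                chunk_logits, chunk_labels, ignore_index=-100, reduction="sum"
            ).float()
            n_valid += valid
        return total / max(n_valid, 1)

The per-chunk sum plus final division is what keeps the result identical to unchunked cross-entropy regardless of how -100 labels are distributed across chunks.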

kempnerforge.training.loss.z_loss(logits, weight)[source]

Logit magnitude regularizer (PaLM / Gemini).

Penalizes large logit magnitudes to prevent logit drift that causes NaN/divergence in long training runs. Negligible compute cost.

Formula: weight * mean(logsumexp(logits, dim=-1) ** 2)

Parameters:
  • logits (torch.Tensor) – Model output logits, shape (batch, seq, vocab) or (tokens, vocab).

  • weight (float) – Regularization weight (PaLM uses 1e-4).

Returns:

Scalar z-loss term to add to the main loss.

Return type:

torch.Tensor
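
The formula translates directly to code; upcasting to float32 before the logsumexp is a stability choice of this sketch rather than a documented detail:

    import torch

    def z_loss(logits: torch.Tensor, weight: float) -> torch.Tensor:
        # log Z = logsumexp over the vocab dim; penalizing (log Z)**2 pulls
        # the softmax normalizer toward 1 and keeps logit magnitudes bounded.
        log_z = torch.logsumexp(logits.float(), dim=-1)
        return weight * log_z.pow(2).mean()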

kempnerforge.training.loss.build_loss_fn(config)[source]

Build a composed loss function from training config.

Follows the build_optimizer pattern: config in, callable out. Binds chunk_size for chunked CE and composes z-loss, so the caller gets a clean (logits, labels) -> Tensor interface.

Return type:

Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
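
In outline, the builder could look like this sketch, composing the two functions documented above (the config field names are assumptions):

    from typing import Callable

    import torch

    from kempnerforge.training.loss import chunked_cross_entropy_loss, z_loss

    def build_loss_fn(config) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        # Hypothetical field names; the real schema may differ.
        chunk_size = getattr(config, "chunk_size", 4096)
        z_weight = getattr(config, "z_loss_weight", 0.0)

        def loss_fn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
            loss = chunked_cross_entropy_loss(logits, labels, chunk_size=chunk_size)
            if z_weight > 0:
                loss = loss + z_loss(logits, z_weight)
            return loss

        return loss_fn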