Model¶
This page walks through the forward pass block by block, with pointers to the
code that implements each piece. KempnerForge’s Transformer is a pre-norm Llama-style
decoder: token embedding → N transformer blocks → final RMSNorm →
output head. All components live under
kempnerforge/model/.
High-level flow¶
tokens (batch, seq_len)
↓ TokenEmbedding
hidden (batch, seq_len, dim)
↓ for each of n_layers TransformerBlocks:
↓ x = x + Attention(RMSNorm(x), cos, sin)
↓ x = x + MLP(RMSNorm(x)) (SwiGLUMLP or MoEMLP)
↓ final RMSNorm
↓ OutputHead (nn.Linear, bias=False)
logits (batch, seq_len, vocab_size)
Pre-norm means the normalization (RMSNorm here) is applied before attention and MLP, not after.
Implemented in
kempnerforge/model/transformer.py.
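As a rough sketch, assuming the attribute names described under “Transformer assembly” below and omitting KV caches, doc_ids, and pipeline-parallel stages (where the embedding or head may be None):

```python
import torch

def forward_sketch(model, tokens: torch.Tensor) -> torch.Tensor:
    # tokens: (batch, seq_len) token ids; illustrative only, not the real signature
    x = model.token_embedding(tokens)                    # (batch, seq_len, dim)
    cos = model._rope_cos[: tokens.shape[1]]             # RoPE tables sliced to seq_len
    sin = model._rope_sin[: tokens.shape[1]]
    for block in model.layers.values():                  # nn.ModuleDict keyed by str(i)
        x = x + block.attention(block.attention_norm(x), cos, sin)
        x = x + block.mlp(block.mlp_norm(x))
    x = model.norm(x)                                    # final RMSNorm
    return model.output_head(x)                          # (batch, seq_len, vocab_size)
```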
Token embedding¶
TokenEmbedding wraps nn.Embedding(vocab_size, dim). It is optional —
None on pipeline-parallel middle stages that only receive hidden states.
kempnerforge/model/embedding.py.
RoPE (rotary position embedding)¶
Positions are injected inside each attention block, not added to the
embedding. precompute_rope_frequencies(head_dim, max_seq_len, theta)
returns two tables of shape (max_seq_len, head_dim // 2):
- cos — cosines of position * freq
- sin — sines of the same
apply_rope(x, cos, sin) splits x along head_dim into two halves and
rotates them:
[x1, x2] → [x1·cos − x2·sin, x2·cos + x1·sin]
The rotation uses real-valued sin/cos (not complex arithmetic) so
DTensor metadata survives under SequenceParallel — a comment in
kempnerforge/model/position.py
records that .float() stripped the DTensor wrapper in an earlier
version.
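A minimal sketch of the two helpers under those shape conventions — the theta default and dtype handling here are assumptions, and the real position.py additionally preserves DTensor metadata:

```python
import torch

def precompute_rope_frequencies(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # freqs[j] = theta ** (-2j / head_dim) for j in [0, head_dim // 2)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)          # (max_seq_len, head_dim // 2)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., seq_len, head_dim); split head_dim into two halves and rotate them
    x1, x2 = x.chunk(2, dim=-1)
    # real-valued rotation: [x1, x2] -> [x1*cos - x2*sin, x2*cos + x1*sin]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
```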
During generation, Transformer.forward slices cos/sin starting at
kv_caches[0].seq_len so incremental tokens get the correct absolute
positions.
Attention¶
Attention in
kempnerforge/model/attention.py
implements grouped-query attention (GQA).
Projections¶
Four nn.Linear(..., bias=False) projections:
| Name | Shape |
|---|---|
| Q projection | dim → n_heads * head_dim |
| K projection | dim → n_kv_heads * head_dim |
| V projection | dim → n_kv_heads * head_dim |
| Output projection (o_proj) | n_heads * head_dim → dim |
GQA is configured by the ratio n_heads / n_kv_heads:
- n_kv_heads == n_heads → multi-head attention (MHA)
- n_kv_heads == 1 → multi-query attention (MQA)
- anything in between → GQA (configs/model/llama_7b.toml uses n_heads=32, n_kv_heads=8)
After Q/K/V are computed, KV heads are expanded to n_heads via
repeat_interleave(n_rep, dim=1) so SDPA sees aligned shapes.
QK-Norm¶
When qk_norm=True, per-head RMSNorm(head_dim) is applied to Q and K
before RoPE. This stabilizes attention logits at scale, as in Gemma and DeepSeek-V3.
Three attention paths¶
1. Packed sequences (doc_ids passed in): a block-diagonal causal mask isolates documents within one packed sequence. It is built from doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1) intersected with the causal triangle, then passed to F.scaled_dot_product_attention(..., attn_mask=...). A mask-construction sketch follows this list.
2. Standard causal (training and prefill, no doc_ids): SDPA with is_causal=True.
3. Single-token decode (seq_len == 1 with a KV cache): no mask — the query attends to all cached positions. is_causal=True here would incorrectly restrict attention to only the first key.
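A minimal sketch of the packed-sequence path, assuming doc_ids of shape (batch, seq_len) and Q/K/V already shaped (batch, n_heads, seq_len, head_dim):

```python
import torch
import torch.nn.functional as F

def packed_causal_attention(q, k, v, doc_ids):
    # doc_ids: (batch, seq_len) integer document labels within each packed sequence
    seq_len = doc_ids.shape[1]
    # True where query position i and key position j belong to the same document
    same_doc = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)        # (batch, seq, seq)
    # Intersect with the lower-triangular causal mask
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=doc_ids.device)
    )
    mask = (same_doc & causal).unsqueeze(1)                        # (batch, 1, seq, seq)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```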
A fourth, slower path fires only when capture_attention_weights=True:
_attention_with_weights computes softmax(Q·Kᵀ / √d) manually so
attention weights can be extracted for interpretability. Don’t enable it
for training runs.
KV cache placement¶
KV cache update happens after RoPE, before GQA expansion. That
ordering is intentional: cached keys already carry their rotary positions,
and the cache stores n_kv_heads tensors, not the expanded n_heads
copies.
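Sketched in code, reusing apply_rope from the RoPE sketch above (the cache’s update method name here is hypothetical):

```python
def prepare_qkv(q, k, v, cos, sin, kv_cache, n_heads: int, n_kv_heads: int):
    # q: (batch, n_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)
    q = apply_rope(q, cos, sin)
    k = apply_rope(k, cos, sin)          # keys carry their rotary positions *before* caching
    if kv_cache is not None:
        k, v = kv_cache.update(k, v)     # hypothetical API: cache stores n_kv_heads tensors
    n_rep = n_heads // n_kv_heads        # expand KV heads only after the cache update
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    return q, k, v                       # ready for whichever SDPA path applies
```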
SDPA backend¶
sdpa_backend (default "auto") selects flash / mem-efficient / cudnn /
math via a context manager. "auto" lets PyTorch pick.
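One way such a selector can be written with torch.nn.attention.sdpa_kernel — the string-to-backend mapping below is an assumption, not the exact KempnerForge code:

```python
from contextlib import nullcontext
from torch.nn.attention import SDPBackend, sdpa_kernel

_BACKENDS = {
    "flash": SDPBackend.FLASH_ATTENTION,
    "mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn": SDPBackend.CUDNN_ATTENTION,
    "math": SDPBackend.MATH,
}

def sdpa_context(sdpa_backend: str = "auto"):
    # "auto" imposes no restriction: PyTorch picks the fastest available kernel
    if sdpa_backend == "auto":
        return nullcontext()
    return sdpa_kernel(_BACKENDS[sdpa_backend])
```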
MLP¶
Two dense variants, keyed by activation:
- SwiGLUMLP (activation="silu", Llama-style) — three linears: gate_proj + up_proj → SiLU(gate) * up → down_proj.
- StandardMLP ("gelu" or "relu") — two linears: up_proj → activation → down_proj.
kempnerforge/model/mlp.py.
The helper build_mlp(dim, hidden_dim, activation) looks the variant up
in the "mlp" registry.
MoE variant¶
When ModelConfig.is_moe is True, layers where
(layer_idx + 1) % moe_frequency == 0 swap SwiGLUMLP for MoEMLP (see
kempnerforge/model/moe.py).
moe_frequency=1 makes every layer MoE; moe_frequency=2 interleaves
dense and MoE layers, with layer 0 staying dense. See
docs/moe/ for the full MoE stack.
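The placement rule in one line (helper name is illustrative):

```python
def is_moe_layer(layer_idx: int, moe_frequency: int) -> bool:
    # e.g. moe_frequency=2, n_layers=4 -> layers 1 and 3 are MoE; layers 0 and 2 stay dense
    return (layer_idx + 1) % moe_frequency == 0
```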
RMSNorm¶
RMSNorm(dim, eps=1e-5) in
kempnerforge/model/norm.py.
A single learned scale vector (self.weight, init 1s), no bias, no mean
subtraction:
x.float() * rsqrt(mean(x², dim=-1, keepdim=True) + eps) * weight → cast back
The float32 cast-and-cast-back is deliberate — it prevents the variance statistic from overflowing in bf16 for long sequences.
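A minimal sketch of that forward (class name illustrative):

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # single learned scale, no bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # compute the statistic in float32 so bf16 doesn't overflow on long sequences
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (normed * self.weight).type_as(x)      # cast back to the input dtype
```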
LayerNorm is also registered under the "norm" registry key
"layernorm" if you want it; everything else in the codebase assumes
"rmsnorm".
Output head¶
OutputHead is nn.Linear(dim, vocab_size, bias=False). Optional, like
TokenEmbedding — None on pipeline-parallel non-final stages.
When ModelConfig.tie_embeddings=True (and both layers exist), the head
shares its weight with the embedding via OutputHead.tie_weights(emb)
(self.proj.weight = emb.embedding.weight).
TransformerBlock¶
One block’s forward is five lines:
x = x + self.attention(self.attention_norm(x), rope_cos, rope_sin,
                       kv_cache=kv_cache, doc_ids=doc_ids)
x = x + self.mlp(self.mlp_norm(x))
return x
Pre-norm + residual, both for attention and for MLP. No
drop_path/stochastic_depth; no learnable scale on the residual branch.
Transformer assembly¶
Transformer owns:
- token_embedding (optional)
- layers: nn.ModuleDict[str, TransformerBlock] — keyed by str(i) instead of nn.ModuleList so checkpoint FQNs are stable under DCP
- norm — final RMSNorm
- output_head (optional)
- _rope_cos, _rope_sin — precomputed frequency tables stored as plain attributes (not buffers) so model.to(bf16) doesn’t cast them to bf16
Extra helpers on Transformer:
- init_weights_and_freqs() — called after model.to_empty(device=...) to fill parameter values after meta-device materialization
- set_moe_step(step, max_steps) — forwards the current step to every SigmoidTopKRouter for the adaptive bias schedule
- get_moe_aux_loss() — sums the per-layer MoE aux loss; returns 0 for dense models
- get_expert_counts() — returns {layer_idx: counts} for MoE layers; empty for dense models
The whole class is registered under the "model" registry key
"transformer", which is what the training loop asks for.
Weight initialization¶
init_weights(model, config) in
kempnerforge/model/init.py
applies the GPT-2 / Llama convention:
- 2-D parameters (Linear weights, Embedding) get normal(0, config.init_std) — default 0.02.
- Parameters whose name ends in o_proj.weight or down_proj.weight get scaled by 1 / sqrt(2 * n_layers). Without this, the residual stream’s variance grows linearly with depth.
- Biases and norm weights are left at their defaults (zeros and ones).
Meta-device parameters are skipped —
init_weights_and_freqs() runs after materialization.
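A sketch of the convention — the helper name and exact iteration details are assumptions:

```python
import math
import torch.nn as nn

def init_weights_sketch(model: nn.Module, init_std: float, n_layers: int) -> None:
    residual_scale = 1.0 / math.sqrt(2 * n_layers)
    for name, param in model.named_parameters():
        if param.is_meta:
            continue                       # filled later by init_weights_and_freqs()
        if param.dim() == 2:               # Linear / Embedding weights
            std = init_std
            if name.endswith(("o_proj.weight", "down_proj.weight")):
                std *= residual_scale      # damp residual-stream variance growth with depth
            nn.init.normal_(param, mean=0.0, std=std)
        # biases and norm weights keep their default zeros / ones
```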
Registry keys¶
Swap pieces without touching the transformer code by registering a new builder under the right category:
| Category | Existing keys | Builder |
|---|---|---|
| "model" | "transformer" | Transformer |
| "norm" | "rmsnorm", "layernorm" | RMSNorm / LayerNorm |
| "mlp" | activation names looked up by build_mlp ("silu", "gelu", "relu") | SwiGLUMLP / StandardMLP |
See Configuration for the step-by-step pattern for registering new components.
Where to read next¶
Parallelism order — how TP, EP, FP8, AC, and FSDP2 are layered on top of this module tree.
Data flow — the training-loop path that feeds this forward pass.