Source code for kempnerforge.model.init

"""Weight initialization strategies for KempnerForge models."""

from __future__ import annotations

import math

import torch.nn as nn

from kempnerforge.config.schema import ModelConfig


[docs] def init_weights(model: nn.Module, config: ModelConfig) -> None: """Apply standard initialization to all parameters in a model. Strategy (following GPT-2/Llama conventions): - Linear layers: normal(0, 0.02) - Embedding layers: normal(0, 0.02) - Residual output projections (o_proj, down_proj): scaled by 1/sqrt(2 * n_layers) - Norm layers: weight=1 (already default) """ std = config.init_std residual_scale = 1.0 / math.sqrt(2.0 * config.n_layers) for name, param in model.named_parameters(): if param.is_meta: continue if param.dim() < 2: # Bias and norm parameters: leave at default (zeros / ones) continue # Residual projections get scaled init to prevent signal growth if name.endswith(("o_proj.weight", "down_proj.weight")): nn.init.normal_(param, mean=0.0, std=std * residual_scale) else: nn.init.normal_(param, mean=0.0, std=std)