# Source code for kempnerforge.model.init

"""Weight initialization strategies for KempnerForge models."""

from __future__ import annotations

import math

import torch.nn as nn

from kempnerforge.config.schema import ModelConfig
from kempnerforge.model.moe import MoEMLP
from kempnerforge.model.mot import MoTAttention, MoTBlock


def init_weights(model: nn.Module, config: ModelConfig) -> None:
    """Initialize every parameter of *model* in place.

    Follows GPT-2/Llama conventions:

    - Linear and embedding weights: ``normal(0, config.init_std)``.
    - Residual output projections (``o_proj``, ``down_proj``): std scaled
      by ``1/sqrt(2 * n_layers)`` to keep the residual stream from growing.
    - Cross-attention block residuals: zero-initialized, so each block is
      an identity mapping at construction (warm-start friendly).
    - MoT per-modality residual projections: zero-initialized via a
      second module-level pass (their FQNs escape suffix matching).
    - Bias / norm parameters (dim < 2): left at their defaults.
    """
    base_std = config.init_std
    scale = 1.0 / math.sqrt(2.0 * config.n_layers)
    residual_suffixes = ("o_proj.weight", "down_proj.weight")

    for name, param in model.named_parameters():
        # Skip meta tensors (not materialized) and 1-D params
        # (biases, norm weights keep their zeros/ones defaults).
        if param.is_meta or param.dim() < 2:
            continue

        is_residual = name.endswith(residual_suffixes)
        if is_residual and name.startswith("cross_attention_layers."):
            # Zero-init makes the cross-attention block an identity at
            # construction; training learns a non-zero contribution.
            nn.init.zeros_(param)
        elif is_residual:
            # Scaled init on residual-stream outputs to prevent signal
            # growth across depth.
            nn.init.normal_(param, mean=0.0, std=base_std * scale)
        else:
            nn.init.normal_(param, mean=0.0, std=base_std)

    # Second pass for MoT modules: per-modality ModuleDict nesting puts
    # the modality name *after* o_proj/down_proj in the FQN, so the
    # endswith() matching above never sees these weights. Zero them here
    # for the identity-at-construction residual.
    for submodule in model.modules():
        if isinstance(submodule, MoTAttention):
            for modality in submodule.modalities:
                nn.init.zeros_(submodule.o_proj[modality].weight)  # type: ignore[reportArgumentType]
        elif isinstance(submodule, MoTBlock):
            for modality in submodule.modalities:
                mlp = submodule.mlp[modality]
                # MoE experts manage their own init; plain MLPs without a
                # down_proj have nothing to zero.
                if isinstance(mlp, MoEMLP) or not hasattr(mlp, "down_proj"):
                    continue
                nn.init.zeros_(mlp.down_proj.weight)  # type: ignore[union-attr]