Source code for kempnerforge.model.norm

"""Normalization layers for KempnerForge models."""

from __future__ import annotations

import torch
import torch.nn as nn

from kempnerforge.config.registry import registry


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (Llama-style).

    Simpler and faster than LayerNorm — no mean subtraction, no bias.
    Each vector along the last dimension is divided by its root mean square
    and then multiplied by a learnable scale vector.
    """

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        """Create the layer with a scale vector initialised to ones.

        Args:
            dim: Length of the learnable scale vector; it is multiplied
                against the last dimension of the input in ``forward``.
            eps: Constant added to the mean square before the reciprocal
                square root, keeping the result finite for all-zero input.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the scale.

        Args:
            x: Input tensor; statistics are taken over the last dimension.

        Returns:
            A tensor with the same dtype as ``x``.
        """
        dtype = x.dtype
        # Compute in float32 for numerical stability with low-precision
        # inputs, but never downcast: float64 inputs stay in float64 so they
        # are not silently reduced to float32 accuracy.
        compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32
        x = x.to(compute_dtype)
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Scale in the compute dtype, then cast back to the caller's dtype.
        return (norm * self.weight).to(dtype)
def _build_rmsnorm(dim: int, eps: float = 1e-5) -> RMSNorm:
    """Construct an ``RMSNorm`` layer for the given width and epsilon."""
    layer = RMSNorm(dim=dim, eps=eps)
    return layer


def _build_layernorm(dim: int, eps: float = 1e-5) -> nn.LayerNorm:
    """Construct a ``torch.nn.LayerNorm`` layer for the given width and epsilon."""
    layer = nn.LayerNorm(normalized_shape=dim, eps=eps)
    return layer


# Expose both builders under the "norm" category so they can be looked up
# by name through the registry.
registry.register("norm", "rmsnorm", _build_rmsnorm)
registry.register("norm", "layernorm", _build_layernorm)
def build_norm(norm_type: str, dim: int, eps: float = 1e-5) -> nn.Module:
    """Build a normalization layer by name.

    Args:
        norm_type: Name passed to the registry lookup in the ``"norm"``
            category.
        dim: Forwarded positionally to the registered builder.
        eps: Forwarded to the registered builder as the ``eps`` keyword.

    Returns:
        Whatever the registered builder returns for ``(dim, eps=eps)``.
    """
    # NOTE(review): behaviour for an unregistered name is determined by
    # ``registry.get`` — confirm against the registry implementation.
    return registry.get("norm", norm_type)(dim, eps=eps)