# Source code for kempnerforge.model.mlp
"""Feed-forward network implementations for KempnerForge models."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from kempnerforge.config.registry import registry
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward network (Llama-style).

    Architecture: gate_proj + up_proj -> SiLU(gate) * up -> down_proj.
    Uses 3 weight matrices instead of 2, with SiLU gating.
    """

    def __init__(self, dim: int, hidden_dim: int) -> None:
        """Initialize the three projection layers.

        Args:
            dim: Model (input/output) dimension.
            hidden_dim: Inner feed-forward dimension.
        """
        super().__init__()
        # All projections are bias-free, following the Llama convention.
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SwiGLU: down_proj(SiLU(gate_proj(x)) * up_proj(x)).

        Args:
            x: Input tensor whose last dimension is ``dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class StandardMLP(nn.Module):
    """Standard two-layer MLP with configurable activation.

    Architecture: linear -> activation -> linear.
    """

    def __init__(self, dim: int, hidden_dim: int, activation: str = "gelu") -> None:
        """Initialize the two projection layers and select the activation.

        Args:
            dim: Model (input/output) dimension.
            hidden_dim: Inner feed-forward dimension.
            activation: One of ``"gelu"``, ``"relu"``, or ``"silu"``.

        Raises:
            ValueError: If ``activation`` is not a recognized name.
        """
        super().__init__()
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        # Validate eagerly so a bad config fails at construction time,
        # not on the first forward pass.
        activations = {"gelu": F.gelu, "relu": F.relu, "silu": F.silu}
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation!r}")
        self._activation = activations[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply down_proj(activation(up_proj(x))).

        Args:
            x: Input tensor whose last dimension is ``dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return self.down_proj(self._activation(self.up_proj(x)))
def _build_swiglu(dim: int, hidden_dim: int) -> SwiGLUMLP:
    """Registry builder for the SwiGLU (3-matrix, SiLU-gated) MLP."""
    return SwiGLUMLP(dim, hidden_dim)


def _build_standard_gelu(dim: int, hidden_dim: int) -> StandardMLP:
    """Registry builder for the standard 2-matrix MLP with GELU."""
    return StandardMLP(dim, hidden_dim, activation="gelu")


def _build_standard_relu(dim: int, hidden_dim: int) -> StandardMLP:
    """Registry builder for the standard 2-matrix MLP with ReLU."""
    return StandardMLP(dim, hidden_dim, activation="relu")


# Register all MLP variants under the "mlp" category at import time so
# build_mlp (and any config-driven lookup) can resolve them by key.
registry.register("mlp", "swiglu", _build_swiglu)
registry.register("mlp", "standard_gelu", _build_standard_gelu)
registry.register("mlp", "standard_relu", _build_standard_relu)
# Map activation config names to registry keys.
_ACTIVATION_TO_MLP = {"silu": "swiglu", "gelu": "standard_gelu", "relu": "standard_relu"}


def build_mlp(dim: int, hidden_dim: int, activation: str = "silu") -> nn.Module:
    """Build an MLP by activation name.

    SiLU activation uses SwiGLU (3 matrices); others use standard MLP
    (2 matrices).

    Args:
        dim: Model (input/output) dimension.
        hidden_dim: Inner feed-forward dimension.
        activation: Activation config name (``"silu"``, ``"gelu"``,
            ``"relu"``) or a raw registry key.

    Returns:
        The constructed ``nn.Module``.
    """
    # Unrecognized names fall through unchanged, so callers may pass a
    # registry key directly for custom-registered MLP variants; the
    # registry lookup is then responsible for rejecting unknown keys.
    key = _ACTIVATION_TO_MLP.get(activation, activation)
    builder = registry.get("mlp", key)
    return builder(dim, hidden_dim)