Source code for kempnerforge.model.embedding
"""Token embedding and output head for KempnerForge models."""
from __future__ import annotations
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
"""Token embedding layer.
Can be disabled (returns input unchanged) for pipeline parallelism
middle stages where the embedding lives on a different stage.
"""

    def __init__(self, vocab_size: int, dim: int, enabled: bool = True) -> None:
        super().__init__()
        self.enabled = enabled  # flag name is assumed; see class docstring
        self.embedding = nn.Embedding(vocab_size, dim) if enabled else None

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed token ids to vectors.

        Args:
            tokens: Integer tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, dim), or the input unchanged
            when the layer is disabled.
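
        Example (illustrative sketch added here for documentation; the
        sizes are arbitrary)::

            >>> emb = TokenEmbedding(vocab_size=1000, dim=64)
            >>> tokens = torch.randint(0, 1000, (2, 8))  # (batch, seq_len)
            >>> emb(tokens).shape
            torch.Size([2, 8, 64])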
"""
return self.embedding(tokens)


class OutputHead(nn.Module):
    """Linear output projection from the hidden dimension to the vocab size.

    Produces logits (no softmax is applied). The projection weight can
    optionally be shared with a :class:`TokenEmbedding` via :meth:`tie_weights`.
    """

    def __init__(self, dim: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project hidden states to logits.

        Args:
            x: Tensor of shape (batch, seq_len, dim).

        Returns:
            Logits tensor of shape (batch, seq_len, vocab_size).
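
        Example (illustrative sketch added here for documentation; the
        sizes are arbitrary)::

            >>> head = OutputHead(dim=64, vocab_size=1000)
            >>> x = torch.randn(2, 8, 64)
            >>> head(x).shape
            torch.Size([2, 8, 1000])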
"""
return self.proj(x)

    def tie_weights(self, embedding: TokenEmbedding) -> None:
        """Share the output projection weight with the embedding layer.

        The shapes agree because ``nn.Linear`` stores its weight as
        ``(out_features, in_features)`` = ``(vocab_size, dim)``, matching
        ``nn.Embedding``'s ``(vocab_size, dim)`` table.
        """
        self.proj.weight = embedding.embedding.weight
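

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the library): a quick
# smoke test of both modules and of ``tie_weights``. The ``enabled=False``
# pass-through exercised at the end relies on the flag assumed above, not on
# confirmed KempnerForge API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, dim = 100, 32
    embedding = TokenEmbedding(vocab_size, dim)
    head = OutputHead(dim, vocab_size)
    head.tie_weights(embedding)

    # After tying, both modules hold the very same parameter tensor.
    assert head.proj.weight is embedding.embedding.weight

    tokens = torch.randint(0, vocab_size, (2, 8))  # (batch, seq_len)
    hidden = embedding(tokens)                     # (2, 8, dim)
    logits = head(hidden)                          # (2, 8, vocab_size)
    print(tuple(logits.shape))                     # (2, 8, 100)

    # A disabled embedding acts as identity (pipeline-parallel middle stage).
    middle = TokenEmbedding(vocab_size, dim, enabled=False)
    assert torch.equal(middle(hidden), hidden)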