Source code for kempnerforge.model.position
"""Rotary Position Embedding (RoPE) for KempnerForge models.
Uses real-valued sin/cos rotation (not complex arithmetic) for
compatibility with DTensor and SequenceParallel.
"""
from __future__ import annotations
import torch
[docs]
def precompute_rope_frequencies(
    head_dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cosine and sine lookup tables used by RoPE.

    Every position ``p`` in ``[0, max_seq_len)`` is paired with every
    inverse frequency ``theta^{-2i/head_dim}`` (``i`` in
    ``[0, head_dim / 2)``); the tables hold the cosine and sine of the
    resulting angle ``p * theta^{-2i/head_dim}``.

    Args:
        head_dim: Dimension per attention head (must be even).
        max_seq_len: Number of positions to tabulate.
        theta: Base of the geometric frequency progression
            (10000.0 for standard RoPE).
        device: Device on which the tables are created.

    Returns:
        Tuple ``(cos, sin)``; each tensor has shape
        ``(max_seq_len, head_dim // 2)``.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"

    # Even indices 0, 2, ..., head_dim - 2: one entry per rotated pair.
    pair_index = torch.arange(0, head_dim, 2, device=device).float()
    exponent = pair_index / head_dim
    inv_freq = 1.0 / (theta ** exponent)

    # Angle table: row = position, column = frequency.
    steps = torch.arange(max_seq_len, device=device)
    angles = torch.outer(steps, inv_freq)

    cos_table = angles.cos()
    sin_table = angles.sin()
    return cos_table, sin_table
[docs]
def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate ``x`` by the given RoPE tables using real arithmetic only.

    The last dimension of ``x`` is cut into a lower and an upper half;
    element ``j`` of the lower half is paired with element ``j`` of the
    upper half and the pair is rotated by the angle whose cosine/sine
    are supplied in ``cos`` / ``sin``.

    Args:
        x: Input tensor of shape (..., seq_len, head_dim).
        cos: Cosine table, shape (seq_len, head_dim // 2).
        sin: Sine table, shape (seq_len, head_dim // 2).

    Returns:
        Tensor with RoPE applied, same shape and dtype as ``x``.
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]

    # Prepend singleton axes so the (seq_len, half) tables broadcast over
    # every leading dimension of x. Both tables use cos's shape.
    target_shape = (1,) * (x.ndim - 2) + tuple(cos.shape)

    # The tables are cast to x's dtype (bf16) rather than promoting x to
    # float32, because .float() strips DTensor metadata needed for
    # SequenceParallel.
    cos = cos.view(target_shape).to(x.dtype)
    sin = sin.view(target_shape).to(x.dtype)

    # 2-D rotation of each (lower, upper) pair.
    rotated_lower = lower * cos - upper * sin
    rotated_upper = upper * cos + lower * sin
    return torch.cat((rotated_lower, rotated_upper), dim=-1)