"""Multi-head attention with Grouped-Query Attention (GQA) support.
GQA is the general case:
- n_kv_heads == n_heads → standard Multi-Head Attention (MHA)
- n_kv_heads == 1 → Multi-Query Attention (MQA)
- 1 < n_kv_heads < n_heads → Grouped-Query Attention (GQA)
"""
from __future__ import annotations
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from kempnerforge.model.norm import RMSNorm
from kempnerforge.model.position import apply_rope
_SDPA_BACKENDS = {
"flash": SDPBackend.FLASH_ATTENTION,
"efficient": SDPBackend.EFFICIENT_ATTENTION,
"cudnn": SDPBackend.CUDNN_ATTENTION,
"math": SDPBackend.MATH,
}
class KVCache:
    """Pre-allocated KV cache for autoregressive generation.

    Stores key and value tensors for all previous positions, enabling
    incremental decoding without recomputing attention over the full sequence.
    Keys are stored after RoPE application but before GQA expansion.
    """

    def __init__(
        self,
        batch_size: int,
        max_seq_len: int,
        n_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """Allocate zero-filled key/value buffers.

        Both buffers have shape (batch_size, n_kv_heads, max_seq_len, head_dim);
        the sequence axis (dim 2) is filled incrementally by ``update``.
        """
        shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        # Number of sequence positions written so far.
        self.seq_len = 0

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value entries and return full cached tensors.

        Args:
            k_new: New keys, shape (batch, n_kv_heads, new_seq_len, head_dim).
            v_new: New values, shape (batch, n_kv_heads, new_seq_len, head_dim).

        Returns:
            Tuple of (all_keys, all_values), each
            (batch, n_kv_heads, total_seq_len, head_dim). These are views into
            the cache buffers, not copies.

        Raises:
            RuntimeError: If appending ``new_seq_len`` positions would exceed the
                cache capacity. The cache is left unmodified in that case.
        """
        new_len = k_new.shape[2]
        end = self.seq_len + new_len
        capacity = self.k.shape[2]
        # Must be checked explicitly: for a single-token step into a full cache
        # the (.., 1, ..) k_new would silently broadcast into the empty slice
        # [capacity:capacity+1], dropping the new entry while seq_len kept growing.
        if end > capacity:
            raise RuntimeError(
                f"KVCache overflow: cannot append {new_len} position(s) at offset "
                f"{self.seq_len}; capacity is max_seq_len={capacity}."
            )
        self.k[:, :, self.seq_len : end] = k_new
        self.v[:, :, self.seq_len : end] = v_new
        self.seq_len = end
        return self.k[:, :, :end], self.v[:, :, :end]
[docs]
class Attention(nn.Module):
"""Grouped-Query Attention with RoPE and SDPA."""
[docs]
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
head_dim: int | None = None,
qk_norm: bool = False,
sdpa_backend: str = "auto",
) -> None:
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim or (dim // n_heads)
self.n_rep = n_heads // n_kv_heads # GQA repetition factor
self.sdpa_backend = sdpa_backend
# Projections
self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * self.head_dim, dim, bias=False)
# Per-head QK normalization (Gemma, DeepSeek-V3)
self.q_norm = RMSNorm(self.head_dim) if qk_norm else None
self.k_norm = RMSNorm(self.head_dim) if qk_norm else None
# Attention weight capture (analysis only — not for training)
self.capture_attention_weights = False
self.last_attention_weights: torch.Tensor | None = None
def _sdpa_context(self):
"""Return a context manager that forces a specific SDPA backend.
When sdpa_backend="auto", returns a no-op and lets PyTorch's
heuristics select the backend.
"""
if self.sdpa_backend == "auto":
return contextlib.nullcontext()
return sdpa_kernel(_SDPA_BACKENDS[self.sdpa_backend])
[docs]
def forward(
self,
x: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
*,
kv_cache: KVCache | None = None,
doc_ids: torch.Tensor | None = None,
key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass.
Args:
x: Input tensor of shape (batch, seq_len, dim).
rope_cos: RoPE cosine frequencies, shape (seq_len, head_dim // 2).
rope_sin: RoPE sine frequencies, shape (seq_len, head_dim // 2).
kv_cache: Optional KV cache for incremental generation.
doc_ids: Optional per-token document IDs for packed sequences,
shape (batch, seq_len). When provided, constructs a block-diagonal
causal mask so tokens only attend within their document.
key_padding_mask: Optional per-key validity mask, shape
(batch, seq_len); ``True`` = attend, ``False`` = drop (e.g. the
visual tokens of padded video frames). Combined with the causal
(and doc) mask; fully-masked query rows are unmasked to keep
softmax finite.
Returns:
Output tensor of shape (batch, seq_len, dim).
"""
batch, seq_len, _ = x.shape
# Project to Q, K, V
# Use -1 for head count so the view works under tensor parallelism
# (ColwiseParallel shards the output features, reducing the local head count)
q = self.q_proj(x).view(batch, seq_len, -1, self.head_dim)
k = self.k_proj(x).view(batch, seq_len, -1, self.head_dim)
v = self.v_proj(x).view(batch, seq_len, -1, self.head_dim)
# QK-Norm: normalize Q and K per-head before RoPE (stabilizes attention logits)
if self.q_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k) # type: ignore[reportOptionalCall]
# Transpose to (batch, heads, seq_len, head_dim) for SDPA
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Apply RoPE to Q and K
q = apply_rope(q, rope_cos, rope_sin)
k = apply_rope(k, rope_cos, rope_sin)
# Update KV cache (after RoPE, before GQA expansion)
if kv_cache is not None:
k, v = kv_cache.update(k, v)
# Expand KV heads for GQA: (batch, n_kv_heads, seq, dim) → (batch, n_heads, seq, dim)
if self.n_rep > 1:
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
# Attention masking strategy:
# 1. Packed sequences (doc_ids provided): block-diagonal causal mask that
# isolates documents from each other within a packed sequence.
# 2. Standard causal: needed for training and prefill (Q seq == K seq).
# 3. Single-token decode (seq_len=1 with KV cache): no mask needed — the
# query attends to all cached positions (is_causal=True would incorrectly
# restrict attention to only the first key position).
if self.capture_attention_weights:
# Manual attention for weight extraction (analysis only, not for training)
out, attn_weights = self._attention_with_weights(
q, k, v, seq_len, doc_ids, kv_cache, key_padding_mask
)
self.last_attention_weights = attn_weights.detach().cpu()
elif doc_ids is not None or key_padding_mask is not None:
# An explicit attn_mask is not a FlashAttention-2 shape, so SDPA falls
# back to the mem-efficient/math kernel here. The image-prefix video
# arches (Joint-Decoder/MoMa here, MoT in mot.py) always pass a
# key_padding_mask (all-True when unpadded), so their self-attention
# always takes this branch -- losing FA2 and materializing a (B, 1, S, S)
# mask even for fully-decoded clips. (Cross-Attention sets no
# key_padding_mask on this text self-attention -- it masks padded image
# K/V in the cross-attention blocks instead -- so it keeps FA2 here.)
# Deliberate: always-masking is torch.compile / DP-friendly (one graph,
# no host sync). Recovering FA2 for unpadded batches (or moving to
# FlexAttention) is a follow-up.
#
# Asserts no kv_cache: neither doc_ids (packed training) nor
# key_padding_mask (VLM video) co-occurs with decode today, and this
# branch's full-sequence causal mask would mis-handle a cached
# (seq_len=1) decode rather than attend to all cached positions.
assert kv_cache is None, (
"doc_ids / key_padding_mask are not supported with kv_cache decode "
"(would build an incorrect causal mask)."
)
seq_len_kv = k.shape[2]
# Explicit bool mask: causal, AND same-document (doc_ids), AND valid
# keys (key_padding_mask, e.g. dropping padded video frames' tokens).
causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).tril()
attn_mask = causal.unsqueeze(0).unsqueeze(0) # (1, 1, S, S_kv)
if doc_ids is not None:
doc_mask = (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1) # (B,1,S,S)
attn_mask = attn_mask & doc_mask
if key_padding_mask is not None:
attn_mask = attn_mask & key_padding_mask.view(batch, 1, 1, seq_len_kv)
# NaN guard: a query row with no reachable valid key (e.g. the leading
# positions of an all-padded / undecodable clip) would softmax over all
# -inf -> NaN. Unmask such rows; their outputs are discarded (trimmed by
# output_slice, or the clip's labels are all -100).
attn_mask = attn_mask | ~attn_mask.any(dim=-1, keepdim=True)
with self._sdpa_context():
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
else:
is_causal = kv_cache is None or seq_len > 1
with self._sdpa_context():
out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
# Reshape back: (batch, n_heads, seq_len, head_dim) → (batch, seq_len, dim)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
return self.o_proj(out)
def _attention_with_weights(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len: int,
doc_ids: torch.Tensor | None,
kv_cache: KVCache | None,
key_padding_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute attention output and weights manually (for analysis).
SDPA does not return attention weights, so this path computes
softmax(QK^T / sqrt(d)) explicitly. Only use for interpretability —
it is slower and uses more memory than the fused SDPA path.
Returns:
Tuple of (output, attention_weights).
output: (batch, n_heads, seq_len, head_dim)
attention_weights: (batch, n_heads, seq_q, seq_k)
"""
scale = self.head_dim**-0.5
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
seq_len_kv = k.shape[2]
if doc_ids is not None or key_padding_mask is not None:
causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).tril()
valid = causal.unsqueeze(0).unsqueeze(0) # (1, 1, S, S_kv)
if doc_ids is not None:
valid = valid & (doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)).unsqueeze(1)
if key_padding_mask is not None:
valid = valid & key_padding_mask.view(q.shape[0], 1, 1, seq_len_kv)
valid = valid | ~valid.any(dim=-1, keepdim=True) # NaN guard (see forward)
attn = attn.masked_fill(~valid, float("-inf"))
elif kv_cache is None or seq_len > 1:
causal = torch.ones(seq_len, seq_len_kv, dtype=torch.bool, device=q.device).triu(
diagonal=1
)
attn = attn.masked_fill(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
attn_weights = F.softmax(attn, dim=-1)
out = torch.matmul(attn_weights, v)
return out, attn_weights