Source code for kempnerforge.model.cross_attention

"""Cross-attention block for VLM Cross-Attention architecture.

Text queries cross-attend to image keys/values. Inserted between
regular ``TransformerBlock``s at a configurable cadence; the residual
stream itself carries text only, with image K/V flowing in as a side
channel via ``ModalityContext.image_features``.

Differences from ``Attention``:

- No causal mask on the image axis. Image tokens have no temporal
  order that aligns with text positions, so the full image K/V set is
  visible from every text position.
- No RoPE. RoPE encodes relative position along a single axis;
  cross-attention spans two axes (text Q positions vs image K/V
  positions) with no shared coordinate, so RoPE is dropped on both
  sides. The text axis already has RoPE applied inside each
  preceding ``TransformerBlock``; cross-attention just queries off
  the resulting hidden state.
- Output projection ``o_proj`` and the block's MLP output projection
  are zero-initialized so the block starts as identity. This matches
  Llama-3-V's warm-start: cross-attention blocks added to an existing
  text-only checkpoint contribute nothing to the output at step 0, so
  the text model's behavior is preserved, and they learn a non-zero
  contribution from there.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from kempnerforge.model.mlp import build_mlp
from kempnerforge.model.norm import build_norm


class CrossAttention(nn.Module):
    """Text Q × image K/V cross-attention with optional GQA."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int | None = None,
    ) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim or (dim // n_heads)
        self.n_rep = n_heads // n_kv_heads  # GQA repetition factor

        self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, dim, bias=False)

        # Zero-init output projection so the block is identity at construction.
        nn.init.zeros_(self.o_proj.weight)

    def forward(
        self,
        x: torch.Tensor,
        image_features: torch.Tensor,
        image_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Text hidden state, shape ``(batch, seq_len, dim)``.
            image_features: Image K/V source, shape
                ``(batch, num_image_tokens, dim)``.
            image_mask: Optional bool mask, shape ``(batch, num_image_tokens)``;
                ``True`` = attend, ``False`` = mask out. ``None`` = all image
                tokens attended to.

        Returns:
            Output tensor of shape ``(batch, seq_len, dim)``.
        """
        batch, seq_len, _ = x.shape
        num_image_tokens = image_features.shape[1]

        # Q from text, K/V from image features. Use -1 for head count so the
        # view works under tensor parallelism (ColwiseParallel shards out_features).
        q = self.q_proj(x).view(batch, seq_len, -1, self.head_dim)
        k = self.k_proj(image_features).view(batch, num_image_tokens, -1, self.head_dim)
        v = self.v_proj(image_features).view(batch, num_image_tokens, -1, self.head_dim)

        # Transpose to (batch, heads, seq, head_dim) for SDPA.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand K/V heads for GQA.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        # Build SDPA attn_mask from image_mask if present.
        # image_mask: (B, N) bool, True = attend. SDPA accepts a bool mask
        # broadcastable to (B, n_heads, S_q, S_kv); shape (B, 1, 1, N) does
        # the right thing across heads and text Q positions.
        attn_mask = None
        if image_mask is not None:
            attn_mask = image_mask.view(batch, 1, 1, num_image_tokens)

        # Cross-attention: no causal mask on the image axis.
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(out)
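

# Illustrative usage sketch (not part of the library source): demonstrates the
# expected shapes and the identity-at-construction property described in the
# module docstring. The sizes and names below are assumptions chosen for the
# example, not values from the codebase.
#
#     attn = CrossAttention(dim=512, n_heads=8, n_kv_heads=2)
#     x = torch.randn(2, 16, 512)                  # (batch, seq_len, dim) text hidden state
#     img = torch.randn(2, 64, 512)                # (batch, num_image_tokens, dim)
#     mask = torch.ones(2, 64, dtype=torch.bool)   # True = attend
#     out = attn(x, img, mask)                     # (2, 16, 512)
#     # o_proj is zero-initialized, so the module returns all zeros right after
#     # construction; the surrounding residual add (x + out) is then an identity map.
#     assert torch.all(out == 0)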


class CrossAttentionBlock(nn.Module):
    """Pre-norm wrapper: ``CrossAttention`` + residual + MLP + residual.

    Mirrors ``TransformerBlock``'s outer shape so the freeze + FSDP + DCP
    plumbing applies uniformly. The MLP's output projection is also
    zero-initialized so the whole block is identity at construction.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_hidden_dim: int,
        norm_type: str = "rmsnorm",
        activation: str = "silu",
    ) -> None:
        super().__init__()
        self.attn_norm = build_norm(norm_type, dim)
        self.attn = CrossAttention(dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads)
        self.mlp_norm = build_norm(norm_type, dim)
        self.mlp = build_mlp(dim=dim, hidden_dim=ffn_hidden_dim, activation=activation)

        # Zero-init MLP output projection. SwiGLU uses down_proj; StandardMLP
        # also uses down_proj. Both are nn.Linear, so set the weight to zero.
        # type ignore: build_mlp returns nn.Module statically; both concrete
        # subclasses (SwiGLUMLP, StandardMLP) expose .down_proj.weight, but
        # pyright cannot narrow through build_mlp's return type.
        nn.init.zeros_(self.mlp.down_proj.weight)  # type: ignore[union-attr]

    def forward(
        self,
        x: torch.Tensor,
        image_features: torch.Tensor,
        image_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), image_features, image_mask)
        x = x + self.mlp(self.mlp_norm(x))
        return x
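

# Illustrative integration sketch (not part of the library source): how a
# ``CrossAttentionBlock`` might be interleaved with text-only decoder blocks at
# a configurable cadence, per the module docstring. The ``TransformerBlock``
# call signature, the cadence value, and the loop structure are assumptions
# made for illustration only.
#
#     cadence = 4  # one CrossAttentionBlock after every 4 text blocks (assumed)
#     ca_blocks = {
#         i: CrossAttentionBlock(dim=512, n_heads=8, n_kv_heads=2, ffn_hidden_dim=2048)
#         for i in range(cadence - 1, n_layers, cadence)
#     }
#
#     def decoder_forward(x, image_features, image_mask=None):
#         for i, block in enumerate(text_blocks):   # regular TransformerBlocks
#             x = block(x)                          # residual stream carries text only
#             if i in ca_blocks:                    # image K/V enter as a side channel
#                 x = ca_blocks[i](x, image_features, image_mask)
#         return x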