kempnerforge.model.cross_attention

Cross-attention block for the VLM Cross-Attention architecture.

Text queries cross-attend to image keys/values. Inserted between regular TransformerBlock layers at a configurable cadence; the residual stream itself carries text only, with image K/V flowing in as a side channel via ModalityContext.image_features.
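
A minimal sketch of that cadence, assuming a hypothetical cross_attn_every knob, a simplified TransformerBlock signature, and image features passed as a bare tensor rather than through ModalityContext (the real wiring may differ):

    import torch.nn as nn

    # Import paths and TransformerBlock's signature are assumptions for this sketch.
    from kempnerforge.model.cross_attention import CrossAttentionBlock
    from kempnerforge.model.transformer import TransformerBlock

    class DecoderStack(nn.Module):
        """Hypothetical stack: one CrossAttentionBlock after every N text blocks."""

        def __init__(self, n_layers, cross_attn_every, dim, n_heads, n_kv_heads, ffn_hidden_dim):
            super().__init__()
            self.layers = nn.ModuleList()
            for i in range(n_layers):
                self.layers.append(TransformerBlock(dim, n_heads, n_kv_heads, ffn_hidden_dim))
                if (i + 1) % cross_attn_every == 0:
                    self.layers.append(CrossAttentionBlock(dim, n_heads, n_kv_heads, ffn_hidden_dim))

        def forward(self, x, image_features, image_mask=None):
            for layer in self.layers:
                if isinstance(layer, CrossAttentionBlock):
                    # Side channel: image K/V enter here; the residual stream stays text-only.
                    x = layer(x, image_features, image_mask)
                else:
                    x = layer(x)  # regular causal self-attention block (RoPE applied inside)
            return x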

Differences from Attention:

  • No causal mask on the image axis. Image tokens have no temporal order that aligns with text positions, so the full image K/V set is visible from every text position.

  • No RoPE. RoPE encodes relative position along a single axis; cross-attention spans two axes (text Q positions vs image K/V positions) with no shared coordinate, so RoPE is dropped on both sides. The text axis already has RoPE applied inside each preceding TransformerBlock; cross-attention just queries off the resulting hidden state.

  • Output projection o_proj and the block’s MLP output projection are zero-initialized so the block starts as identity (see the sketch below). This matches Llama-3-V’s warm-start recipe: cross-attention blocks added to an existing text-only checkpoint contribute nothing to the model’s output at step 0, leaving the pretrained behavior intact, and learn a non-zero contribution from there.
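
A minimal sketch of the attention math these three points imply, assuming standard projection shapes; it is illustrative, not the module’s actual implementation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CrossAttentionSketch(nn.Module):
        """Illustrative cross-attention mirroring the three differences above."""

        def __init__(self, dim, n_heads, n_kv_heads):
            super().__init__()
            self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
            self.head_dim = dim // n_heads  # assumed default when head_dim is None
            self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
            self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
            self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
            self.o_proj = nn.Linear(n_heads * self.head_dim, dim, bias=False)
            nn.init.zeros_(self.o_proj.weight)  # identity at step 0 (third bullet)

        def forward(self, x, image_features, image_mask=None):
            b, s, _ = x.shape
            t = image_features.shape[1]
            # No RoPE on either side (second bullet): raw projections only.
            q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(image_features).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(image_features).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
            if self.n_kv_heads != self.n_heads:  # GQA: K/V heads shared across query groups
                rep = self.n_heads // self.n_kv_heads
                k, v = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
            attn_mask = None
            if image_mask is not None:
                # (batch, num_image_tokens) bool -> broadcast over heads and query positions
                attn_mask = image_mask[:, None, None, :]
            # is_causal=False: every text position sees all image tokens (first bullet).
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
            return self.o_proj(out.transpose(1, 2).reshape(b, s, -1))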

Classes

CrossAttention
    Text Q × image K/V cross-attention with optional GQA.

CrossAttentionBlock
    Pre-norm wrapper: CrossAttention + residual + MLP + residual.

class kempnerforge.model.cross_attention.CrossAttention[source]

Bases: Module

Text Q × image K/V cross-attention with optional GQA.

__init__(dim, n_heads, n_kv_heads, head_dim=None)[source]
Parameters:
  • dim (int) – Hidden dimension of the text residual stream.

  • n_heads (int) – Number of query heads.

  • n_kv_heads (int) – Number of key/value heads; fewer K/V heads than query heads enables GQA.

  • head_dim (int | None) – Optional per-head dimension override.

Return type:
None

forward(x, image_features, image_mask=None)[source]

Forward pass.

Parameters:
  • x (torch.Tensor) – Text hidden state, shape (batch, seq_len, dim).

  • image_features (torch.Tensor) – Image K/V source, shape (batch, num_image_tokens, dim).

  • image_mask (torch.Tensor | None) – Optional boolean mask, shape (batch, num_image_tokens); True = attend, False = mask out. If None, all image tokens are attended.

Returns:
Output tensor of shape (batch, seq_len, dim).

Return type:
torch.Tensor
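
A usage sketch with illustrative shapes; the constructor call assumes the documented signature:

    import torch
    from kempnerforge.model.cross_attention import CrossAttention

    attn = CrossAttention(dim=512, n_heads=8, n_kv_heads=2)
    x = torch.randn(2, 16, 512)                # (batch, seq_len, dim) text hidden state
    img = torch.randn(2, 64, 512)              # (batch, num_image_tokens, dim)
    mask = torch.ones(2, 64, dtype=torch.bool)
    mask[1, 32:] = False                       # second sample: ignore padded image tokens
    out = attn(x, img, image_mask=mask)
    assert out.shape == (2, 16, 512)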

class kempnerforge.model.cross_attention.CrossAttentionBlock[source]

Bases: Module

Pre-norm wrapper: CrossAttention + residual + MLP + residual.

Mirrors TransformerBlock’s outer shape so the freeze + FSDP + DCP plumbing applies uniformly. The MLP’s output projection is also zero-initialized so the whole block is identity at construction.
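
A quick check of the identity-at-construction property, with illustrative dimensions:

    import torch
    from kempnerforge.model.cross_attention import CrossAttentionBlock

    block = CrossAttentionBlock(dim=512, n_heads=8, n_kv_heads=2, ffn_hidden_dim=2048)
    x = torch.randn(2, 16, 512)
    img = torch.randn(2, 64, 512)
    # Both residual branches end in a zero-initialized projection, so the
    # block should return x unchanged before any training step.
    assert torch.allclose(block(x, img), x)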

__init__(dim, n_heads, n_kv_heads, ffn_hidden_dim, norm_type='rmsnorm', activation='silu')[source]
Parameters:
  • dim (int) – Hidden dimension of the text residual stream.

  • n_heads (int) – Number of query heads.

  • n_kv_heads (int) – Number of key/value heads; fewer K/V heads than query heads enables GQA.

  • ffn_hidden_dim (int) – Hidden dimension of the MLP.

  • norm_type (str) – Normalization layer type for the pre-norms.

  • activation (str) – Activation function used in the MLP.

Return type:
None

forward(x, image_features, image_mask=None)[source]
Parameters:
  • x (torch.Tensor) – Text hidden state, shape (batch, seq_len, dim).

  • image_features (torch.Tensor) – Image K/V source, shape (batch, num_image_tokens, dim).

  • image_mask (torch.Tensor | None) – Optional boolean mask, shape (batch, num_image_tokens); True = attend, False = mask out. If None, all image tokens are attended.

Returns:
Output tensor of shape (batch, seq_len, dim).

Return type:
torch.Tensor
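
And a hedged look at how the warm-start proceeds from there; parameter names depend on the implementation, so this just scans named_parameters:

    import torch
    from kempnerforge.model.cross_attention import CrossAttentionBlock

    block = CrossAttentionBlock(dim=512, n_heads=8, n_kv_heads=2, ffn_hidden_dim=2048)
    x = torch.randn(2, 16, 512)
    img = torch.randn(2, 64, 512)
    block(x, img).sum().backward()
    # At step 0 only the zero-initialized output projections receive non-zero
    # gradient; the projections feeding them unblock once those weights move
    # off zero, which is how the block grows a non-zero contribution.
    for name, p in block.named_parameters():
        got_grad = p.grad is not None and p.grad.abs().sum().item() > 0
        print(f"{name}: grad nonzero = {got_grad}")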