kempnerforge.model.cross_attention¶
Cross-attention block for VLM Cross-Attention architecture.
Text queries cross-attend to image keys/values. Inserted between
regular TransformerBlocks at a configurable cadence; the residual
stream itself carries text only, with image K/V flowing in as a side
channel via ModalityContext.image_features.
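For orientation, here is a minimal sketch of how such blocks might be interleaved and driven. The helper names, the cadence value, and passing image_features directly (rather than through ModalityContext) are assumptions for illustration, not the kempnerforge API.

    import torch.nn as nn
    from kempnerforge.model.cross_attention import CrossAttentionBlock

    def insert_cross_attention(text_layers, make_ca_block, every=4):
        """Hypothetical helper: append a cross-attention block after every `every`-th text block."""
        layers = nn.ModuleList()
        for i, blk in enumerate(text_layers):
            layers.append(blk)
            if (i + 1) % every == 0:
                layers.append(make_ca_block())
        return layers

    def run_stack(layers, x, image_features, image_mask=None):
        """Text-only residual stream; image K/V enter only at the cross-attention blocks."""
        for layer in layers:
            if isinstance(layer, CrossAttentionBlock):
                x = layer(x, image_features, image_mask)
            else:
                x = layer(x)
        return x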
Differences from Attention:
- No causal mask on the image axis. Image tokens have no temporal order that aligns with text positions, so the full image K/V set is visible from every text position.
- No RoPE. RoPE encodes relative position along a single axis; cross-attention spans two axes (text Q positions vs. image K/V positions) with no shared coordinate, so RoPE is dropped on both sides. The text axis already has RoPE applied inside each preceding TransformerBlock; cross-attention just queries off the resulting hidden state.
- The output projection o_proj and the block’s MLP output projection are zero-initialized, so the block starts as the identity. This matches Llama-3-V’s warm start: a cross-attention block added to an existing text-only checkpoint leaves the model’s output unchanged at step 0 and learns a non-zero contribution from there.
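The combined effect of these differences can be seen in a minimal, self-contained sketch (plain PyTorch, no GQA). It illustrates the pattern described above and is not the kempnerforge implementation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyCrossAttention(nn.Module):
        """Illustrative only: no causal mask, no RoPE, zero-initialized output projection."""

        def __init__(self, dim: int, n_heads: int):
            super().__init__()
            self.n_heads, self.head_dim = n_heads, dim // n_heads
            self.q_proj = nn.Linear(dim, dim, bias=False)  # text queries
            self.k_proj = nn.Linear(dim, dim, bias=False)  # image keys
            self.v_proj = nn.Linear(dim, dim, bias=False)  # image values
            self.o_proj = nn.Linear(dim, dim, bias=False)
            nn.init.zeros_(self.o_proj.weight)             # block contributes nothing at step 0

        def forward(self, x, image_features, image_mask=None):
            B, T, _ = x.shape
            S = image_features.shape[1]
            q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(image_features).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(image_features).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
            attn_mask = None
            if image_mask is not None:
                # (batch, num_image_tokens) -> broadcast over heads and text positions
                attn_mask = image_mask[:, None, None, :]
            # is_causal=False: every text position sees the full image K/V set; no RoPE on either side.
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
            return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))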
Classes
CrossAttention – Text Q × image K/V cross-attention with optional GQA.
CrossAttentionBlock – Pre-norm wrapper: CrossAttention + residual + MLP + residual.
- class kempnerforge.model.cross_attention.CrossAttention[source]¶
Bases: Module
Text Q × image K/V cross-attention with optional GQA.
- forward(x, image_features, image_mask=None)[source]¶
Forward pass.
- Parameters:
  x (torch.Tensor) – Text hidden state, shape (batch, seq_len, dim).
  image_features (torch.Tensor) – Image K/V source, shape (batch, num_image_tokens, dim).
  image_mask (torch.Tensor | None) – Optional bool mask, shape (batch, num_image_tokens); True = attend, False = mask out. None = all image tokens are attended to.
- Returns:
  Output tensor of shape (batch, seq_len, dim).
- Return type:
  torch.Tensor
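A usage sketch for the documented forward signature. The constructor arguments and concrete sizes are assumptions, since they are not specified on this page:

    import torch
    from kempnerforge.model.cross_attention import CrossAttention

    attn = CrossAttention(dim=4096, n_heads=32, n_kv_heads=8)  # constructor args assumed

    x = torch.randn(2, 128, 4096)                      # (batch, seq_len, dim)
    image_features = torch.randn(2, 576, 4096)         # (batch, num_image_tokens, dim)
    image_mask = torch.ones(2, 576, dtype=torch.bool)  # True = attend

    out = attn(x, image_features, image_mask=image_mask)
    assert out.shape == (2, 128, 4096)                 # (batch, seq_len, dim)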
- class kempnerforge.model.cross_attention.CrossAttentionBlock[source]¶
Bases: Module
Pre-norm wrapper: CrossAttention + residual + MLP + residual.
Mirrors TransformerBlock’s outer shape so the freeze + FSDP + DCP plumbing applies uniformly. The MLP’s output projection is also zero-initialized so the whole block is the identity at construction.
- forward(x, image_features, image_mask=None)[source]¶
- Parameters:
x (torch.Tensor)
image_features (torch.Tensor)
image_mask (torch.Tensor | None)
- Return type:
  torch.Tensor
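A usage sketch for the block; constructor arguments are assumptions. Because both output projections are zero-initialized, a freshly constructed block should return its input unchanged:

    import torch
    from kempnerforge.model.cross_attention import CrossAttentionBlock

    block = CrossAttentionBlock(dim=4096, n_heads=32)  # constructor args assumed

    x = torch.randn(2, 128, 4096)                      # (batch, seq_len, dim)
    image_features = torch.randn(2, 576, 4096)         # (batch, num_image_tokens, dim)

    out = block(x, image_features)
    assert torch.allclose(out, x)                      # identity at construction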