kempnerforge.model.attention
Multi-head attention with Grouped-Query Attention (GQA) support.
GQA is the general case:
- n_kv_heads == n_heads → standard Multi-Head Attention (MHA)
- n_kv_heads == 1 → Multi-Query Attention (MQA)
- 1 < n_kv_heads < n_heads → Grouped-Query Attention (GQA)
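The three cases above can be sketched as a small classifier. This is only an illustration: `attention_variant` is a hypothetical helper, not part of `kempnerforge`, and it assumes that `n_heads` must be divisible by `n_kv_heads` so every KV head serves the same number of query heads.

```python
def attention_variant(n_heads: int, n_kv_heads: int) -> tuple[str, int]:
    """Classify a head configuration.

    Returns (variant, group_size), where group_size is the number of
    query heads that share one KV head.
    """
    if n_kv_heads < 1 or n_kv_heads > n_heads:
        raise ValueError("require 1 <= n_kv_heads <= n_heads")
    if n_heads % n_kv_heads != 0:
        # Assumption: query heads are split evenly across KV heads.
        raise ValueError("n_heads must be divisible by n_kv_heads")

    group_size = n_heads // n_kv_heads
    if n_kv_heads == n_heads:
        variant = "MHA"  # every query head has its own KV head
    elif n_kv_heads == 1:
        variant = "MQA"  # all query heads share a single KV head
    else:
        variant = "GQA"  # query heads share KV heads in groups
    return variant, group_size


print(attention_variant(32, 8))
```

With 32 query heads and 8 KV heads, each KV head is shared by a group of 4 query heads.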
Classes
Attention | Grouped-Query Attention with RoPE and SDPA.
KVCache | Pre-allocated KV cache for autoregressive generation.
- class kempnerforge.model.attention.KVCache
Bases: object
Pre-allocated KV cache for autoregressive generation.
Stores key and value tensors for all previous positions, so incremental decoding does not have to recompute keys and values for the earlier part of the sequence. Keys are stored after RoPE application but before GQA expansion.
- __init__(batch_size, max_seq_len, n_kv_heads, head_dim, dtype, device)
- Parameters:
batch_size (int)
max_seq_len (int)
n_kv_heads (int)
head_dim (int)
dtype (torch.dtype)
device (torch.device)
- Return type:
None
- update(k_new, v_new)
Append new key/value entries and return full cached tensors.
- Parameters:
k_new (torch.Tensor) – New keys, shape (batch, n_kv_heads, new_seq_len, head_dim).
v_new (torch.Tensor) – New values, shape (batch, n_kv_heads, new_seq_len, head_dim).
- Returns:
Tuple of (all_keys, all_values), each (batch, n_kv_heads, total_seq_len, head_dim).
- Return type:
tuple[torch.Tensor, torch.Tensor]
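A minimal sketch of how a pre-allocated cache with this interface could behave. `ToyKVCache` and its `seq_len` attribute are hypothetical and written only to mirror the documented signatures and shapes; the actual `KVCache` internals may differ. The final lines illustrate what "before GQA expansion" means: the cache holds `n_kv_heads` heads, and repeating them to match the query heads happens afterwards.

```python
import torch


class ToyKVCache:
    """Illustrative pre-allocated KV cache; not the kempnerforge implementation."""

    def __init__(self, batch_size, max_seq_len, n_kv_heads, head_dim, dtype, device):
        shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
        # Allocate the full buffers once so decoding never reallocates.
        self.k = torch.zeros(shape, dtype=dtype, device=device)
        self.v = torch.zeros(shape, dtype=dtype, device=device)
        self.seq_len = 0  # positions filled so far (hypothetical attribute)

    def update(self, k_new, v_new):
        new_len = k_new.shape[2]
        end = self.seq_len + new_len
        if end > self.k.shape[2]:
            raise ValueError("KV cache overflow: exceeds max_seq_len")
        # Write the new entries into the next free slots.
        self.k[:, :, self.seq_len:end] = k_new
        self.v[:, :, self.seq_len:end] = v_new
        self.seq_len = end
        # Return views over the filled prefix only.
        return self.k[:, :, :end], self.v[:, :, :end]


cache = ToyKVCache(batch_size=1, max_seq_len=16, n_kv_heads=2, head_dim=4,
                   dtype=torch.float32, device=torch.device("cpu"))

# Prefill with 3 positions, then decode a single new token.
k_all, v_all = cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))
k_all, v_all = cache.update(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4))

# GQA expansion after the cache: repeat each KV head for its query group.
n_heads, n_kv_heads = 8, 2
k_expanded = k_all.repeat_interleave(n_heads // n_kv_heads, dim=1)
```

Caching the unexpanded heads keeps memory proportional to `n_kv_heads` rather than `n_heads`, which is the main saving GQA offers during generation.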
- class kempnerforge.model.attention.Attention
Bases: Module
Grouped-Query Attention with RoPE and SDPA.
- forward(x, rope_cos, rope_sin, *, kv_cache=None, doc_ids=None)
Forward pass.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch, seq_len, dim).
rope_cos (torch.Tensor) – RoPE cosine frequencies, shape (seq_len, head_dim // 2).
rope_sin (torch.Tensor) – RoPE sine frequencies, shape (seq_len, head_dim // 2).
kv_cache (KVCache | None) – Optional KV cache for incremental generation.
doc_ids (torch.Tensor | None) – Optional per-token document IDs for packed sequences, shape (batch, seq_len). When provided, constructs a block-diagonal causal mask so tokens only attend within their document.
- Returns:
Output tensor of shape (batch, seq_len, dim).
- Return type:
torch.Tensor
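The block-diagonal causal mask described for `doc_ids` can be sketched as follows. `block_causal_mask` is a hypothetical helper shown for illustration, not the function `Attention.forward` uses internally. It assumes a boolean mask convention where `True` means "may attend", which is what `torch.nn.functional.scaled_dot_product_attention` expects for boolean `attn_mask` values.

```python
import torch
import torch.nn.functional as F


def block_causal_mask(doc_ids: torch.Tensor) -> torch.Tensor:
    """Build a (batch, 1, seq_len, seq_len) boolean mask from per-token document IDs."""
    seq_len = doc_ids.shape[1]
    # same_doc[b, i, j] is True when tokens i and j belong to the same document.
    same_doc = doc_ids.unsqueeze(2) == doc_ids.unsqueeze(1)
    # causal[i, j] is True when key position j is not after query position i.
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=doc_ids.device)
    )
    # Add a singleton head dimension so the mask broadcasts over all heads.
    return (same_doc & causal).unsqueeze(1)


# Two packed documents of two tokens each.
doc_ids = torch.tensor([[0, 0, 1, 1]])
mask = block_causal_mask(doc_ids)

# Toy tensors of shape (batch, heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

In this example the third token starts a new document, so it can attend only to itself, and its output equals its own value vector.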