Archetypal Sparse Autoencoders (TopK & Jump)¶
Archetypal SAEs combine the archetypal dictionary constraint with familiar sparse encoders:

- RATopKSAE: TopK selection with archetypal atoms.
- RAJumpSAE: JumpReLU selection with archetypal atoms.
Dictionary atoms stay close to convex combinations of the provided data points (controlled by delta and an optional learnable multiplier), which stabilizes training and improves interpretability. This is the SAE form of the Archetypal SAE idea.
Basic Usage¶
```python
import torch
from overcomplete.sae import RATopKSAE, RAJumpSAE

points = torch.randn(2_000, 768)  # e.g. k-means centroids or sampled activations

ra_topk = RATopKSAE(
    input_shape=768,
    nb_concepts=10_000,
    points=points,
    top_k=20,
    delta=1.0,            # relaxation radius
    use_multiplier=True   # learnable scaling of the archetypal hull
)

ra_jump = RAJumpSAE(
    input_shape=768,
    nb_concepts=10_000,
    points=points,
    bandwidth=1e-3,
    delta=1.5
)
```
Tips:
- Provide reasonably diverse points (e.g., k-means cluster centers) for stable archetypes.
- use_multiplier lets atoms scale beyond the convex hull; set it to False to keep atoms closer to the hull.
- All standard training utilities (train_sae, custom losses) work unchanged.
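The relaxed archetypal constraint can be sketched numerically: each atom is a convex combination of the provided points plus a relaxation term bounded by delta. This is a conceptual sketch under an assumed parametrization (row-softmax weights, a per-atom norm clamp), not the library's actual internals:

```python
import torch

# Conceptual sketch (assumed parametrization, not the library's internals):
# each dictionary atom is a convex combination of data points, plus a
# relaxation term whose per-atom norm is capped by `delta`.
num_points, dim, nb_concepts = 256, 64, 512
delta = 1.0
points = torch.randn(num_points, dim)

# Row-stochastic weights -> atoms lie inside the convex hull of the points.
A = torch.softmax(torch.randn(nb_concepts, num_points), dim=-1)

# Relaxation: let each atom drift at most `delta` away from the hull.
relax = torch.randn(nb_concepts, dim)
relax = relax * (delta / relax.norm(dim=-1, keepdim=True).clamp(min=delta))

D = A @ points + relax  # dictionary atoms stay near the convex hull
```

With delta=0 the atoms are exact convex combinations of the points; larger delta loosens the constraint, and the optional multiplier additionally rescales the hull.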
RATopKSAE¶
Relaxed Archetypal TopK SAE.
__init__(self,
         input_shape,
         nb_concepts,
         points,
         top_k=None,
         delta=1.0,
         use_multiplier=True,
         **kwargs)¶
Parameters

- input_shape : int
  Dimensionality of the input data (excluding the batch dimension).
- nb_concepts : int
  Number of dictionary atoms (concepts).
- points : torch.Tensor
  The data points used to initialize/define the archetypes. Shape should be (num_points, input_shape).
- top_k : int, optional
  Number of top activations to keep in the latent representation. By default, 10% sparsity is used.
- delta : float, optional
  Relaxation radius for the archetypal dictionary, by default 1.0.
- use_multiplier : bool, optional
  Whether to use a learnable multiplier that scales the ball containing the atoms (e.g., a multiplier of 3 places the dictionary atoms on the ball of radius 3). By default True.
- **kwargs : dict, optional
  Additional arguments passed to the parent TopKSAE (e.g., encoder_module, device).
encode(self, x)¶
Encode input data to latent representation.
Parameters

- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).

Returns

- pre_codes : torch.Tensor
  Pre-codes tensor of shape (batch_size, nb_components), before the ReLU and top-k operations.
- z : torch.Tensor
  Codes, latent representation tensor (z) of shape (batch_size, nb_components).
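The two-stage output above can be sketched as follows. This is an assumed form of TopK encoding with hypothetical weight names (W_enc, b_enc); the library's actual encoder module may differ:

```python
import torch

def topk_encode(x, W_enc, b_enc, top_k):
    # Sketch of a TopK encode step (assumed form, not the library's exact code):
    # compute pre-activations, apply ReLU, then keep only the top_k largest
    # values per sample and zero out the rest.
    pre_codes = x @ W_enc + b_enc
    acts = torch.relu(pre_codes)
    vals, idx = acts.topk(top_k, dim=-1)
    z = torch.zeros_like(acts).scatter_(-1, idx, vals)
    return pre_codes, z

x = torch.randn(4, 64)
W_enc = torch.randn(64, 512) * 0.01  # hypothetical encoder weights
b_enc = torch.zeros(512)
pre_codes, z = topk_encode(x, W_enc, b_enc, top_k=20)
# each row of z has at most top_k nonzero, nonnegative entries
```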
RAJumpSAE¶
Relaxed Archetypal Jump SAE.
__init__(self,
         input_shape,
         nb_concepts,
         points,
         bandwidth=0.001,
         delta=1.0,
         use_multiplier=True,
         **kwargs)¶
Parameters

- input_shape : int
  Dimensionality of the input data (excluding the batch dimension).
- nb_concepts : int
  Number of dictionary atoms (concepts).
- points : torch.Tensor
  The data points used to initialize/define the archetypes. Shape should be (num_points, input_shape).
- bandwidth : float, optional
  Bandwidth parameter for the Jump SAE kernel, by default 1e-3.
- delta : float, optional
  Relaxation radius for the archetypal dictionary, by default 1.0.
- use_multiplier : bool, optional
  Whether to use a learnable multiplier that scales the ball containing the atoms (e.g., a multiplier of 3 places the dictionary atoms on the ball of radius 3). By default True.
- **kwargs : dict, optional
  Additional arguments passed to the parent JumpSAE (e.g., encoder_module, device).
encode(self, x)¶
Encode input data to latent representation.
Parameters

- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).

Returns

- pre_codes : torch.Tensor
  Pre-codes tensor of shape (batch_size, nb_components), before the jump operation.
- codes : torch.Tensor
  Codes, latent representation tensor (z) of shape (batch_size, nb_components).
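The jump operation applied between pre_codes and codes can be sketched as a thresholded identity. This is an assumed forward form: JumpReLU typically zeroes values at or below a learned threshold and passes larger values through unchanged, with bandwidth affecting only the gradient estimator during training, not this forward computation:

```python
import torch

def jumprelu(pre_codes, threshold):
    # Sketch of a JumpReLU forward pass (assumed form): values strictly above
    # the threshold pass through unchanged; all others are zeroed. Note the
    # output "jumps" from 0 to the threshold value rather than ramping up.
    return torch.where(pre_codes > threshold,
                       pre_codes,
                       torch.zeros_like(pre_codes))

pre_codes = torch.tensor([[-0.5, 0.2, 1.3],
                          [0.05, 0.9, -2.0]])
codes = jumprelu(pre_codes, threshold=0.1)
# -> [[0.0, 0.2, 1.3], [0.0, 0.9, 0.0]]
```

In the actual module the threshold is a learned per-concept parameter; the scalar here is purely illustrative.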