Relaxed Archetypal Dictionary¶
Archetypal SAE constrains the dictionary so that each atom is a convex combination of data points, plus a small relaxation term. This constraint improves the stability and interpretability of dictionary learning, making the module a robust drop-in replacement for the dictionary layer of any Sparse Autoencoder: simply initialize the archetypal dictionary and assign it to your SAE (e.g., sae.dictionary = archetypal_dict).
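Schematically, following the construction in Fel et al. (2025) (the exact parameterization in the code may differ slightly), each atom $d_i$ is built from candidate points $c_j$ as

$$
d_i = \sum_j w_{ij}\, c_j + \lambda_i,
\qquad w_{ij} \ge 0, \quad \sum_j w_{ij} = 1,
\qquad \lVert \lambda_i \rVert \le \delta,
$$

where $\delta$ bounds the relaxation term (the delta argument below) and, when use_multiplier=True, a trainable positive scalar rescales the resulting atoms, moving them from the convex hull to the conical hull of the points.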
Basic Usage¶
```python
import torch

from overcomplete.sae.batchtopk_sae import BatchTopKSAE
from overcomplete.sae.archetypal_dictionary import RelaxedArchetypalDictionary

# initialize any SAE
sae = BatchTopKSAE(768, 10_000, top_k=50)

# 'points' is a tensor of candidate data points (e.g., sampled from your dataset);
# the original paper recommends k-means centroids (see the sketch after this block)
points = torch.randn(1000, 768)

# create the relaxed archetypal dictionary
archetypal_dict = RelaxedArchetypalDictionary(
    in_dimensions=768,
    nb_concepts=10_000,
    points=points,
    delta=1.0,
)

# replace the SAE's dictionary with the archetypal dictionary
sae.dictionary = archetypal_dict

# you can now train the SAE as usual
```
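As noted in the comment above, the paper suggests using k-means centroids of the data as candidate points. A minimal sketch, assuming scikit-learn is available (not a dependency of this library) and using a random activations tensor as a stand-in for real activations:

```python
import torch
from sklearn.cluster import KMeans

# stand-in for activations collected over your dataset
activations = torch.randn(50_000, 768)

# k-means centroids serve as the candidate archetypes
kmeans = KMeans(n_clusters=1000, random_state=0).fit(activations.numpy())
points = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)  # (1000, 768)
```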
RelaxedArchetypalDictionary¶
Dictionary used for Relaxed Archetypal SAE (RA-SAE).
__init__(self,
         in_dimensions,
         nb_concepts,
         points,
         delta=1.0,
         use_multiplier=True,
         device='cpu')¶
Parameters
- in_dimensions : int - Dimensionality of the input data (e.g. number of channels).
- nb_concepts : int - Number of components/concepts in the dictionary. The dictionary is overcomplete if nb_concepts > in_dimensions.
- points : torch.Tensor - Real data points (or points in their convex hull) used to find the candidate archetypes.
- delta : float, optional - Constraint on the relaxation term, by default 1.0.
- use_multiplier : bool, optional - Whether to train a positive scalar to multiply the dictionary after the convex combination, placing the atoms in the conical hull (rather than the convex hull) of the data points, by default True (see the sketch after this list).
- device : str, optional - Device to run the model on ('cpu' or 'cuda'), by default 'cpu'.
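For completeness, a small usage sketch of the optional arguments (the points tensor is a placeholder, as in the Basic Usage example): disabling the multiplier keeps the atoms inside the convex hull of the points, and device moves the module to GPU.

```python
import torch
from overcomplete.sae.archetypal_dictionary import RelaxedArchetypalDictionary

points = torch.randn(1000, 768)  # placeholder candidate points

archetypal_dict = RelaxedArchetypalDictionary(
    in_dimensions=768,
    nb_concepts=10_000,
    points=points,
    delta=0.5,               # tighter relaxation budget
    use_multiplier=False,    # atoms stay in the convex hull of `points`
    device='cuda',
)
```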
 
forward(self,
        z)¶
Reconstruct input data from latent representation.
Parameters
- z : torch.Tensor - Latent representation tensor of shape (batch_size, nb_concepts).

Return
- torch.Tensor - Reconstructed input tensor of shape (batch_size, in_dimensions).
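A quick shape-check sketch (assuming the dictionary is a standard torch.nn.Module, so calling it invokes forward; the sparse latent z here is random and purely illustrative):

```python
import torch
from overcomplete.sae.archetypal_dictionary import RelaxedArchetypalDictionary

points = torch.randn(1000, 768)
archetypal_dict = RelaxedArchetypalDictionary(768, 10_000, points)

# illustrative sparse codes of shape (batch_size, nb_concepts)
z = torch.zeros(32, 10_000)
z[:, :50] = torch.rand(32, 50)

x_hat = archetypal_dict(z)  # calls forward(z)
print(x_hat.shape)          # expected: torch.Size([32, 768])
```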
 
train(self,
      mode=True)¶
Hook called when switching between training and evaluation mode.
We use it to fuse W, C, the relaxation term, and the multiplier into a single dictionary.
Parameters
- mode : bool, optional - Whether to set the model to training mode or not, by default True.
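Because switching modes triggers this fusion, it can be convenient to put the module in eval mode before exporting or inspecting the dictionary; a minimal sketch, reusing the archetypal_dict built in the examples above:

```python
# switching to eval mode fuses W, C, the relaxation term and the multiplier
# into a single dictionary tensor
archetypal_dict.eval()   # equivalent to archetypal_dict.train(mode=False)

fused = archetypal_dict.get_dictionary()  # see get_dictionary below

# switch back to training mode before resuming optimization
archetypal_dict.train()
```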
 
get_dictionary(self)¶
Get the dictionary.
Return
- torch.Tensor - The dictionary tensor of shape (nb_concepts, in_dimensions).
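As an illustration of the archetypal constraint, one can check how closely each learned atom aligns with its nearest candidate point; a minimal sketch, assuming points and archetypal_dict from the examples above:

```python
import torch

dictionary = archetypal_dict.get_dictionary().detach().cpu()  # (nb_concepts, in_dimensions)

# cosine similarity between every atom and every candidate point
atoms = torch.nn.functional.normalize(dictionary, dim=-1)
candidates = torch.nn.functional.normalize(points, dim=-1)
nearest_sim = (atoms @ candidates.T).max(dim=-1).values

# atoms built as (near-)convex combinations should stay close to the data
print(nearest_sim.mean())
```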
 
- Archetypal-SAE by Fel et al. (2025). ↩