Relaxed Archetypal Dictionary¶
Archetypal SAE constrains the dictionary so that each atom is a convex combination of data points, plus a small relaxation term. This constraint improves stability and interpretability in dictionary learning, making the method a robust drop-in replacement for the dictionary layer of any Sparse Autoencoder: simply build the archetypal dictionary and assign it to your SAE (e.g., sae.dictionary = archetypal_dict).
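To make the constraint concrete, here is a minimal sketch (plain PyTorch, not the library's internals) of how such a dictionary can be formed: convex weights over candidate points, plus a relaxation term whose per-atom norm is projected to at most delta.

```python
import torch

# Hedged sketch: each dictionary atom is a convex combination of candidate
# points plus a relaxation term with norm bounded by delta.
points = torch.randn(1000, 768)            # candidate data points (rows)
W = torch.rand(10_000, 1000)
W = W / W.sum(dim=1, keepdim=True)         # rows are convex weights (non-negative, sum to 1)

relax = torch.randn(10_000, 768)
delta = 1.0
# project each relaxation row onto the ball of radius delta
norms = relax.norm(dim=1, keepdim=True).clamp(min=delta)
relax = relax * (delta / norms)

dictionary = W @ points + relax            # (10_000, 768) atoms near the convex hull
```

The relaxation lets atoms drift slightly outside the convex hull of the data, trading a little of the archetypal constraint for reconstruction quality.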
Basic Usage¶
import torch
from overcomplete.sae.batchtopk_sae import BatchTopKSAE
from overcomplete.sae.archetypal_dictionary import RelaxedArchetypalDictionary
# initialize any sae
sae = BatchTopKSAE(768, 10_000, top_k=50)
# assume 'points' is a tensor of candidate data points (e.g., sampled from your dataset)
# the original paper recommends k-means centroids as candidates
points = torch.randn(1000, 768)
# create our ra-sae
archetypal_dict = RelaxedArchetypalDictionary(
in_dimensions=768,
nb_concepts=10_000,
points=points,
delta=1.0,
)
# set the SAE's dictionary with the archetypal dictionary
sae.dictionary = archetypal_dict
# you can now train your SAE as usual
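Training then proceeds exactly as for a standard SAE, since the archetypal dictionary is an ordinary differentiable module. The following is a self-contained stand-in sketch (toy encoder and softmax convex weights, not the overcomplete API) illustrating that gradients flow through the convex combination:

```python
import torch

torch.manual_seed(0)
# stand-in sketch: a toy encoder plus convex-combination dictionary,
# optimized with a plain reconstruction loss
points = torch.randn(100, 32)                        # candidate points
W_logits = torch.zeros(64, 100, requires_grad=True)  # convex weights via softmax
encoder = torch.nn.Linear(32, 64)
opt = torch.optim.Adam([W_logits, *encoder.parameters()], lr=1e-2)

x = torch.randn(8, 32)
z = torch.relu(encoder(x))                           # non-negative codes
D = torch.softmax(W_logits, dim=1) @ points          # atoms in the convex hull
loss = ((z @ D - x) ** 2).mean()
loss.backward()
opt.step()
```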
RelaxedArchetypalDictionary¶
Dictionary used for Relaxed Archetypal SAE (RA-SAE).
__init__(self,
in_dimensions,
nb_concepts,
points,
delta=1.0,
use_multiplier=True,
device='cpu')
¶
Parameters
- in_dimensions : int
  Dimensionality of the input data (e.g., number of channels).
- nb_concepts : int
  Number of components/concepts in the dictionary. The dictionary is overcomplete if nb_concepts > in_dimensions.
- points : torch.Tensor
  Real data points (or points in their convex hull) used as candidate archetypes.
- delta : float, optional
  Bound on the norm of the relaxation term, by default 1.0.
- use_multiplier : bool, optional
  Whether to learn a positive scalar that multiplies the dictionary after the convex combination, placing the atoms in the conical hull (rather than the convex hull) of the data points, by default True.
- device : str, optional
  Device to run the model on ('cpu' or 'cuda'), by default 'cpu'.
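Since the paper recommends k-means centroids as the candidate points, here is a minimal sketch of building them with a simple Lloyd's k-means in plain PyTorch (the function name and sizes are illustrative; any k-means implementation works):

```python
import torch

torch.manual_seed(0)

def kmeans(x, k, iters=5):
    # simple Lloyd's k-means; the centroids serve as candidate points
    centroids = x[torch.randperm(len(x))[:k]].clone()
    for _ in range(iters):
        assign = torch.cdist(x, centroids).argmin(dim=1)
        for j in range(k):
            mask = assign == j
            if mask.any():
                centroids[j] = x[mask].mean(dim=0)
    return centroids

data = torch.randn(2000, 64)      # stand-in for your activation dataset
points = kmeans(data, k=100)      # pass these as points= when building the dictionary
```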
forward(self,
z)
¶
Reconstruct input data from latent representation.
Parameters
- z : torch.Tensor
  Latent representation tensor of shape (batch_size, nb_concepts).
Return
- torch.Tensor
  Reconstructed input tensor of shape (batch_size, in_dimensions).
train(self,
mode=True)
¶
Hook called when switching between training and evaluation mode.
We use it to fuse W, C, Relax and multiplier into a single dictionary.
Parameters
- mode : bool, optional
  Whether to set the model to training mode, by default True.
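The fusing mentioned above is a standard optimization: because the forward pass is linear in the dictionary, the convex weights, candidate points, relaxation, and scalar multiplier can be collapsed into a single matrix at evaluation time. A hedged sketch of the idea (illustrative tensors, not the library's internal names):

```python
import torch

# hedged sketch of fusing: convex weights, points, relaxation and
# multiplier collapse into one (nb_concepts, in_dimensions) matrix
W = torch.softmax(torch.randn(64, 100), dim=1)
points = torch.randn(100, 32)
relax = 0.01 * torch.randn(64, 32)
multiplier = torch.tensor(1.5)

fused = multiplier * (W @ points + relax)  # precomputed once at eval time
z = torch.rand(8, 64)
recon_fused = z @ fused                    # one matmul instead of three
recon_full = z @ (multiplier * (W @ points + relax))
assert torch.allclose(recon_fused, recon_full, atol=1e-5)
```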
get_dictionary(self)
¶
Get the dictionary.
Return
- torch.Tensor
  The dictionary tensor of shape (nb_concepts, in_dimensions).
Archetypal-SAE by Fel et al. (2025). ↩