Relaxed Archetypal Dictionary

Archetypal SAE constrains the dictionary so that each atom is formed as a convex combination of data points, plus an additional relaxation term. This constraint improves the stability and interpretability of dictionary learning, making the module a robust drop-in replacement for the dictionary layer of any Sparse Autoencoder: simply initialize the archetypal dictionary and assign it to your SAE (e.g., sae.dictionary = archetypal_dict).

Basic Usage

import torch
from overcomplete.sae.batchtopk_sae import BatchTopKSAE
from overcomplete.sae.archetypal_dictionary import RelaxedArchetypalDictionary

# initialize any SAE
sae = BatchTopKSAE(768, 10_000, top_k=50)

# assume 'points' is a tensor of candidate data points (e.g., sampled from your dataset)
# the original paper recommends k-means centroids as candidate points (see the sketch below)
points = torch.randn(1000, 768)

# create the relaxed archetypal dictionary
archetypal_dict = RelaxedArchetypalDictionary(
    in_dimensions=768,
    nb_concepts=10_000,
    points=points,
    delta=1.0,
)

# replace the SAE's dictionary with the archetypal dictionary
sae.dictionary = archetypal_dict
# you can now train your SAE as usual
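
The comment above mentions k-means: the original paper recommends using k-means centroids of real activations as the candidate points rather than raw random samples. A minimal sketch of that step, assuming scikit-learn is available and `activations` is a placeholder tensor standing in for your real data:

import torch
from sklearn.cluster import KMeans

# 'activations' stands in for real data points collected from your model
activations = torch.randn(50_000, 768)

# cluster the activations and keep the centroids as candidate archetypes
kmeans = KMeans(n_clusters=1000, n_init="auto").fit(activations.numpy())
points = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)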

RelaxedArchetypalDictionary

Dictionary used for Relaxed Archetypal SAE (RA-SAE).

__init__(self,
         in_dimensions,
         nb_concepts,
         points,
         delta=1.0,
         use_multiplier=True,
         device='cpu')

Parameters

  • in_dimensions : int

    • Dimensionality of the input data (e.g., the number of channels).

  • nb_concepts : int

    • Number of components/concepts in the dictionary. The dictionary is overcomplete if the number of concepts > in_dimensions.

  • points : torch.Tensor

    • Real data points (or points in their convex hull) used to find the candidate archetypes.

  • delta : float, optional

    • Constraint on the relaxation term, by default 1.0.

  • use_multiplier : bool, optional

    • Whether to train a positive scalar that multiplies the dictionary after the convex combination, placing the dictionary in the conical hull (rather than the convex hull) of the data points, by default True.

  • device : str, optional

    • Device to run the model on ('cpu' or 'cuda'), by default 'cpu'.
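
As a reference for the optional arguments, a hedged constructor sketch that tightens the relaxation budget and disables the conical scaling (the values are illustrative, not recommendations):

# candidate points, e.g. k-means centroids of real activations
points = torch.randn(1000, 768)

archetypal_dict = RelaxedArchetypalDictionary(
    in_dimensions=768,
    nb_concepts=10_000,
    points=points,
    delta=0.5,             # smaller relaxation budget
    use_multiplier=False,  # keep atoms strictly in the convex hull
    device='cuda',
)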

forward(self,
        z)

Reconstruct input data from latent representation.

Parameters

  • z : torch.Tensor

    • Latent representation tensor of shape (batch_size, nb_concepts).

Return

  • torch.Tensor

    • Reconstructed input tensor of shape (batch_size, in_dimensions).
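
A minimal usage sketch, reusing the archetypal_dict from the example above (the sparse code z is random and purely illustrative):

# sparse latent codes of shape (batch_size, nb_concepts)
z = torch.rand(32, 10_000)

# calling the module dispatches to forward(); decodes back to input space, shape (32, 768)
x_hat = archetypal_dict(z)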


train(self,
      mode=True)

Hook called when switching between training and evaluation mode. We use it to fuse W, C, Relax and the multiplier into a single dictionary tensor.

Parameters

  • mode : bool, optional

    • Whether to set the model to training mode or not, by default True.
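
In practice this hook is triggered through the usual PyTorch mode switches; a small sketch, assuming the dictionary is registered as a submodule of the SAE (as in the usage example above):

# eval mode fuses W, C, Relax and the multiplier into a single tensor
sae.eval()

# training mode restores the factored (relaxed archetypal) form
sae.train()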


get_dictionary(self)

Get the dictionary.

Return

  • torch.Tensor

    • The dictionary tensor of shape (nb_concepts, in_dimensions).
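
A short sketch reusing the objects from the usage example above:

# retrieve the dictionary, e.g. to inspect or visualize concepts
dictionary = archetypal_dict.get_dictionary()
print(dictionary.shape)  # torch.Size([10000, 768])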



  1. Archetypal-SAE by Fel et al. (2025).