Archetypal Sparse Autoencoders (TopK & Jump)

Archetypal SAEs combine the archetypal dictionary constraint with familiar sparse encoders:

- RATopKSAE: TopK selection with archetypal atoms.
- RAJumpSAE: JumpReLU selection with archetypal atoms.

Dictionary atoms are constrained to stay close to convex combinations of the provided data points (a relaxation controlled by delta and an optional learnable multiplier), which stabilizes training and improves interpretability. This is the SAE form of the Archetypal SAE idea.
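To make the constraint concrete, here is a minimal NumPy sketch of a relaxed archetypal parametrization: each atom is a convex combination of anchor points plus a small relaxation term whose norm is capped by delta. The names A, Lam, and the clipping step are illustrative only; the library's actual parametrization may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(100, 16))        # anchor points (e.g. k-means centroids)

# Row-stochastic mixing matrix: each atom is a convex combination of the points.
A = rng.random(size=(32, 100))
A /= A.sum(axis=1, keepdims=True)

# Relaxation term, clipped so each row stays within a ball of radius delta.
delta = 1.0
Lam = rng.normal(size=(32, 16))
row_norms = np.linalg.norm(Lam, axis=1, keepdims=True)
Lam *= np.minimum(1.0, delta / row_norms)

D = A @ points + Lam                       # dictionary atoms near the convex hull
```

With delta = 0 the atoms lie exactly in the convex hull of the points; larger delta allows them to drift further from it.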

Basic Usage

import torch
from overcomplete.sae import RATopKSAE, RAJumpSAE

points = torch.randn(2_000, 768)  # e.g. k-means centroids or sampled activations

ra_topk = RATopKSAE(
    input_shape=768,
    nb_concepts=10_000,
    points=points,
    top_k=20,
    delta=1.0,          # relaxation radius
    use_multiplier=True # learnable scaling of the archetypal hull
)

ra_jump = RAJumpSAE(
    input_shape=768,
    nb_concepts=10_000,
    points=points,
    bandwidth=1e-3,
    delta=1.5
)

Tips:

- Provide reasonably diverse points (e.g., k-means cluster centers) for stable archetypes.
- use_multiplier allows atoms to scale beyond the convex hull; set it to False to stay tighter.
- All standard training utilities (train_sae, custom losses) work unchanged.
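One simple way to obtain a diverse set of points without running k-means is greedy farthest-point sampling. The helper below (`farthest_point_sample` is an illustrative name, not part of overcomplete) picks rows that are mutually far apart:

```python
import numpy as np

def farthest_point_sample(x, n):
    """Greedy farthest-point sampling: pick n mutually distant rows of x."""
    idx = [0]                                   # start from the first row
    d = np.linalg.norm(x - x[0], axis=1)        # distance to the chosen set
    for _ in range(n - 1):
        j = int(d.argmax())                     # farthest remaining point
        idx.append(j)
        d = np.minimum(d, np.linalg.norm(x - x[j], axis=1))
    return x[idx]

acts = np.random.default_rng(0).normal(size=(500, 32))   # stand-in for activations
points = farthest_point_sample(acts, 64)                 # diverse anchor set
```

The resulting array can be converted with torch.from_numpy and passed as points to RATopKSAE or RAJumpSAE.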

RATopKSAE

Relaxed Archetypal TopK SAE.

__init__(self,
         input_shape,
         nb_concepts,
         points,
         top_k=None,
         delta=1.0,
         use_multiplier=True,
         **kwargs)

Parameters

  • input_shape : int

    • Dimensionality of the input data (excluding the batch dimension).

  • nb_concepts : int

    • Number of dictionary atoms (concepts).

  • points : torch.Tensor

    • The data points used to initialize/define the archetypes.

      Shape should be (num_points, input_shape).

  • top_k : int, optional

    • Number of top activations to keep in the latent representation.

      If not provided, 10% sparsity is used by default.

  • delta : float, optional

    • Delta parameter for the archetypal dictionary, by default 1.0.

  • use_multiplier : bool, optional

    • Whether to use a learnable multiplier that parametrizes the radius of the relaxation ball (e.g., if the multiplier is 3, the dictionary atoms all lie on the ball of radius 3). By default True.

  • **kwargs : dict, optional

    • Additional arguments passed to the parent TopKSAE (e.g., encoder_module, device).

encode(self,
       x)

Encode input data to latent representation.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • pre_codes : torch.Tensor

    • Pre-codes tensor of shape (batch_size, nb_components) before the ReLU and top-k operations.

  • z : torch.Tensor

    • Codes: the latent representation tensor (z) of shape (batch_size, nb_components).
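For intuition, the TopK step described above amounts to "ReLU, then keep only the k largest activations per row". A NumPy sketch (topk_codes is an illustrative helper, not part of overcomplete):

```python
import numpy as np

def topk_codes(pre_codes, k):
    """Zero out all but the k largest post-ReLU activations in each row."""
    z = np.maximum(pre_codes, 0.0)                 # ReLU
    out = np.zeros_like(z)
    idx = np.argsort(z, axis=1)[:, -k:]            # indices of k largest per row
    np.put_along_axis(out, idx, np.take_along_axis(z, idx, axis=1), axis=1)
    return out

codes = topk_codes(np.array([[3.0, -1.0, 2.0, 0.5]]), k=2)
```

Each row of the result has at most k nonzero entries, which is the sparsity pattern the encoder enforces.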


RAJumpSAE

Relaxed Archetypal Jump SAE.

__init__(self,
         input_shape,
         nb_concepts,
         points,
         bandwidth=0.001,
         delta=1.0,
         use_multiplier=True,
         **kwargs)

Parameters

  • input_shape : int

    • Dimensionality of the input data (excluding the batch dimension).

  • nb_concepts : int

    • Number of dictionary atoms (concepts).

  • points : torch.Tensor

    • The data points used to initialize/define the archetypes.

      Shape should be (num_points, input_shape).

  • bandwidth : float, optional

    • Bandwidth parameter for the Jump SAE kernel, by default 1e-3.

  • delta : float, optional

    • Delta parameter for the archetypal dictionary, by default 1.0.

  • use_multiplier : bool, optional

    • Whether to use a learnable multiplier that parametrizes the radius of the relaxation ball (e.g., if the multiplier is 3, the dictionary atoms all lie on the ball of radius 3). By default True.

  • **kwargs : dict, optional

    • Additional arguments passed to the parent JumpSAE (e.g., encoder_module, device).

encode(self,
       x)

Encode input data to latent representation.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • pre_codes : torch.Tensor

    • Pre-codes tensor of shape (batch_size, nb_components) before the jump operation.

  • codes : torch.Tensor

    • Codes: the latent representation tensor of shape (batch_size, nb_components).
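For intuition, the jump operation acts like a thresholded identity: a pre-code passes through unchanged only where it exceeds a threshold, and is zeroed otherwise (the bandwidth parameter controls the smoothing used for gradients during training). In JumpReLU SAEs the threshold is learned; the toy NumPy sketch below uses a fixed scalar theta purely for illustration:

```python
import numpy as np

def jump_relu(pre_codes, theta):
    """JumpReLU: identity above the threshold theta, zero at or below it."""
    return pre_codes * (pre_codes > theta)

codes = jump_relu(np.array([0.5, 2.0, -1.0, 1.5]), theta=1.0)
```

Unlike a shifted ReLU, values above the threshold are passed through at full magnitude rather than being reduced by theta.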