Batch TopK SAE

Batch TopK SAE is a variation of the standard Sparse Autoencoder (SAE) that enforces structured sparsity at the batch level using a global Top-K selection mechanism. Instead of selecting the K largest activations per sample, this method selects the top-K activations across the entire batch. This lets the number of active concepts vary from sample to sample while keeping the overall sparsity budget fixed.
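
To make the selection rule concrete, here is a minimal sketch of batch-level top-k thresholding in plain PyTorch. It illustrates the idea only and is not the library's internal implementation:

import torch

def batch_topk(z_pre, top_k):
    # keep only the `top_k` largest activations across the whole
    # (batch_size, nb_concepts) matrix; the k-th largest value acts
    # as a single threshold shared by every sample in the batch
    threshold = z_pre.flatten().topk(top_k).values.min()
    return z_pre * (z_pre >= threshold)

z_pre = torch.randn(4, 16)       # (batch_size, nb_concepts)
z = batch_topk(z_pre, top_k=8)   # ~8 non-zero codes over the entire batch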

The architecture follows the standard SAE framework, consisting of an encoder, a decoder, and a forward method:

  • encode returns the pre-codes (z_pre, before thresholding) and codes (z) given an input (x).
  • decode returns a reconstructed input (x_hat) given the codes (z).
  • forward returns the pre-codes, codes, and reconstructed input (see the example below).
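
A quick tour of these methods, assuming (as documented for forward further down) that the returned SAEOutput unpacks as a (z_pre, z, x_hat) tuple:

import torch
from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)
x = torch.randn(64, 768)

z_pre, z = sae.encode(x)      # pre-codes and sparse codes
x_hat = sae.decode(z)         # reconstruction from the codes
z_pre, z, x_hat = sae(x)      # everything in one forward pass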

We strongly encourage you to read the original paper [1] to learn more about the Batch TopK SAE.

Basic Usage

from overcomplete import BatchTopKSAE

# define a Batch TopK SAE with input dimension 768, 10k concepts
# and top_k = 50 (for the entire batch!)
sae = BatchTopKSAE(768, 10_000, top_k=50)

# the threshold momentum is used to estimate the
# final threshold (used in eval mode)
sae = BatchTopKSAE(768, 10_000, top_k=10, threshold_momentum=0.95)
# ... training sae
sae = sae.eval()
# in eval mode, top_k is no longer used; an
# internal running threshold is applied instead
print(sae.running_threshold)
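
Conceptually, the running threshold can be pictured as an exponential moving average of the per-batch top-k thresholds. The sketch below is an assumed update rule for illustration, not the library's exact code:

import torch

momentum = 0.95
running_threshold = torch.tensor(0.0)
for _ in range(100):  # stand-in for training batches
    z_pre = torch.randn(256, 10_000)
    # k-th largest activation of this batch = its top-k threshold
    batch_threshold = z_pre.flatten().topk(50).values.min()
    # fold it into an exponential moving average (threshold_momentum)
    running_threshold = momentum * running_threshold + (1 - momentum) * batch_threshold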

BatchTopKSAE

Batch Top-K Sparse Autoencoder (SAE).

__init__(self,
         input_shape,
         nb_concepts,
         top_k=None,
         threshold_momentum=0.9,
         encoder_module=None,
         dictionary_params=None,
         device='cpu')

Parameters

  • input_shape : int or tuple of int

    • Dimensionality of the input data (excluding the batch dimension).

  • nb_concepts : int

    • Number of latent dimensions (components) of the autoencoder.

  • top_k : int, optional

    • The number of activations to keep across the batch (the k-th highest activation serves as the threshold).

  • threshold_momentum : float, optional

    • Momentum for the running threshold update (default is 0.9).

  • encoder_module : nn.Module or str, optional

    • Custom encoder module (or its registered name). If None, a default encoder is used.

  • dictionary_params : dict, optional

    • Parameters that will be passed to the dictionary layer.

      See DictionaryLayer for more details.

  • device : str, optional

    • Device on which to run the model (default is 'cpu').

forward(self,
        x)

Perform a forward pass through the autoencoder.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • SAEOutput

    • The pre-codes (z_pre), codes (z), and reconstructed input tensor (x_hat).
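
In a training loop, the forward output is typically used to compute a reconstruction loss. A hypothetical sketch (the MSE objective here is an assumption, not the library's prescribed loss):

import torch
import torch.nn.functional as F
from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)
x = torch.randn(64, 768)

z_pre, z, x_hat = sae(x)
loss = F.mse_loss(x_hat, x)  # reconstruction error to backpropagate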


encode(self,
       x)

Encode input data and apply global top-k thresholding.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • pre_codes : torch.Tensor

    • The raw outputs from the encoder.

  • z : torch.Tensor

    • The sparse latent representation after thresholding.
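
For example, one can check the batch-level sparsity after encoding; ties aside, roughly top_k codes survive across the whole batch:

import torch
from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)
z_pre, z = sae.encode(torch.randn(32, 768))
print((z != 0).sum().item())  # ~50 non-zero codes across the entire batch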


decode(self,
       z)

Decode latent representation to reconstruct input data.

Parameters

  • z : torch.Tensor

    • Latent representation tensor of shape (batch_size, nb_concepts).

Return

  • torch.Tensor

    • Reconstructed input tensor of shape (batch_size, input_size).
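
For example, decoding a one-hot code probes a single concept's contribution in input space. A hypothetical sketch relying only on the shapes documented here:

import torch
from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)
z = torch.zeros(1, 10_000)
z[0, 123] = 1.0          # activate a single concept
x_hat = sae.decode(z)    # its image in input space (decoder bias included, if any)
print(x_hat.shape)       # torch.Size([1, 768])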


get_dictionary(self)

Return the learned dictionary.

Return

  • torch.Tensor

    • Learned dictionary tensor of shape (nb_concepts, input_size).
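
For instance, to inspect the learned concept directions:

from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)
dictionary = sae.get_dictionary()
print(dictionary.shape)  # torch.Size([10000, 768]), one row per concept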



  [1] Batch Top-k Sparse Autoencoders by Bussmann et al. (2024).