Batch TopK SAE¶
Batch TopK SAE is a variation of the standard Sparse Autoencoder (SAE) that enforces structured sparsity at the batch level using a global Top-K selection mechanism. Instead of selecting the K largest activations per sample, this method selects the top-K activations across the entire batch, ensuring a controlled level of sparsity.
The architecture follows the standard SAE framework, consisting of an encoder, a decoder, and a forward method:
- encode returns the pre-codes (z_pre, before thresholding) and codes (z) given an input (x).
- decode returns a reconstructed input (x_hat) based on an input (x).
- forward returns the pre-codes, codes, and reconstructed input.
We strongly encourage you to check the original paper 1 to learn more about Batch TopK SAE.
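To make the batch-level selection concrete, here is a minimal sketch of the idea in plain PyTorch. This is an illustrative re-implementation of the selection step, not the library's internal code: flatten all pre-activations in the batch, keep the K largest, and zero out the rest.
import torch

def batch_topk(z_pre, k):
    # threshold is the k-th largest pre-activation across the *whole batch*,
    # instead of k per sample as in a standard TopK SAE
    threshold = torch.topk(z_pre.flatten(), k).values.min()
    return z_pre * (z_pre >= threshold)

z_pre = torch.relu(torch.randn(32, 10_000))  # pre-codes (batch_size, nb_concepts)
z = batch_topk(z_pre, k=50)
print((z != 0).sum().item())  # ~50 non-zero codes across the entire batch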
Basic Usage¶
from overcomplete import BatchTopKSAE
# define a Batch TopK SAE with input dimension 768, 10k concepts
# and top_k = 50 (for the entire batch!)
sae = BatchTopKSAE(768, 10_000, top_k=50)
# the threshold momentum is used to estimate
# the final threshold (when in eval)
sae = BatchTopKSAE(768, 10_000, top_k=10, threshold_momentum=0.95)
# ... training sae
sae = sae.eval()
# now top_k is no longer used; instead an
# internal (running) threshold is used
print(sae.running_threshold)
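For completeness, here is a minimal training sketch of the elided step above. The data loader, loss, and optimizer are illustrative assumptions, not the library's prescribed recipe, and the output unpacking assumes forward returns the three documented fields (pre_codes, z, x_hat).
import torch
from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-3)

for x in dataloader:  # assumed iterable of (batch_size, 768) activation tensors
    pre_codes, z, x_hat = sae(x)
    loss = torch.nn.functional.mse_loss(x_hat, x)  # plain reconstruction loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# switch to eval: the running threshold replaces the global top_k selection
sae = sae.eval()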
BatchTopKSAE¶
Batch Top-k Sparse Autoencoder (SAE).
__init__(self,
input_shape,
nb_concepts,
top_k=None,
threshold_momentum=0.9,
encoder_module=None,
dictionary_params=None,
device='cpu')¶
Parameters

- input_shape : int or tuple of int
  Dimensionality of the input data (excluding the batch dimension).
- nb_concepts : int
  Number of latent dimensions (components) of the autoencoder.
- top_k : int
  The number of activations to keep across the batch (the k-th highest activation is used as the threshold).
- threshold_momentum : float, optional
  Momentum for the running threshold update (default is 0.9).
- encoder_module : nn.Module or str, optional
  Custom encoder module (or its registered name). If None, a default encoder is used.
- dictionary_params : dict, optional
  Parameters that will be passed to the dictionary layer. See DictionaryLayer for more details.
- device : str, optional
  Device on which to run the model (default is 'cpu').
decode(self,
z)¶
Decode latent representation to reconstruct input data.
Parameters

- z : torch.Tensor
  Latent representation tensor of shape (batch_size, nb_components).

Return

- torch.Tensor
  Reconstructed input tensor of shape (batch_size, input_size).
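A quick shape check, continuing from the sae defined in Basic Usage (the sizes are illustrative):
import torch

z = torch.zeros(32, 10_000)  # sparse codes of shape (batch_size, nb_concepts)
x_hat = sae.decode(z)        # reconstructed inputs
print(x_hat.shape)           # torch.Size([32, 768])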
tied(self,
bias=False)¶
Tie encoder weights to dictionary (use D^T as encoder).
Parameters

- bias : bool, optional
  Whether to include bias in encoder, by default False.

Return

- self
  Returns self for method chaining.
untied(self,
bias=False,
copy_from_dictionary=True)¶
Create a new encoder with weights from the current dictionary (or random init).
Parameters

- bias : bool, optional
  Whether to include bias in encoder, by default False.
- copy_from_dictionary : bool, optional
  If True, initialize encoder with current dictionary weights, by default True.

Return

- self
  Returns self for method chaining.
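Since both methods return self, they can be chained; a small illustrative snippet:
from overcomplete import BatchTopKSAE

sae = BatchTopKSAE(768, 10_000, top_k=50)

# use the transpose of the dictionary as the encoder (weight tying)
sae = sae.tied()

# switch back to an independent encoder, initialized from the dictionary weights
sae = sae.untied(copy_from_dictionary=True)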
get_dictionary(self)¶
Return the learned dictionary.
Return

- torch.Tensor
  Learned dictionary tensor of shape (nb_components, input_size).
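For example, with the constructor used in Basic Usage:
D = sae.get_dictionary()
print(D.shape)  # torch.Size([10000, 768]), one concept direction per row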
forward(self,
x)¶
Perform a forward pass through the autoencoder.
Parameters

- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).

Return

- SAEOuput
  The pre_codes (z_pre), codes (z), and reconstructed input tensor (x_hat).
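Illustrative forward call on the sae from Basic Usage; the unpacking assumes the returned object yields its three fields in the documented order:
import torch

x = torch.randn(32, 768)
pre_codes, z, x_hat = sae(x)
print(pre_codes.shape, z.shape, x_hat.shape)
# torch.Size([32, 10000]) torch.Size([32, 10000]) torch.Size([32, 768])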
encode(self,
x)¶
Encode input data and apply global top-k thresholding.
Parameters

- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).

Return

- pre_codes : torch.Tensor
  The raw outputs from the encoder.
- z : torch.Tensor
  The sparse latent representation after thresholding.
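A quick way to inspect the resulting sparsity (illustrative; in training mode the number of non-zero codes is governed by the global top_k selection, while in eval mode the running threshold is used):
import torch

pre_codes, z = sae.encode(torch.randn(32, 768))
print((z != 0).sum().item())  # number of active (non-zero) codes across the batch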
1. Batch Top-k Sparse Autoencoders by Bussmann et al. (2024). ↩