TopK SAE
TopK SAE is a variant of the standard Sparse Autoencoder (SAE) that enforces a fixed level of sparsity using a Top-K selection mechanism: only the K largest activations are retained in the encoded representation, promoting interpretability and feature selection.
The architecture follows the standard SAE framework, consisting of an encoder, a decoder, and a forward method:
- encode returns the pre-codes (z_pre, before activation) and codes (z) given an input (x).
- decode returns a reconstructed input (x_hat) given the codes (z).
- forward returns the pre-codes, codes, and reconstructed input.
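To make the selection mechanism concrete, here is a minimal, hypothetical sketch of a Top-K activation; the library's internal implementation may differ:

```python
import torch

def top_k_activation(z_pre, k):
    # keep the k largest activations per sample, zero out the rest
    values, indices = torch.topk(z_pre, k, dim=-1)
    z = torch.zeros_like(z_pre)
    z.scatter_(-1, indices, values)
    return z

z_pre = torch.relu(torch.randn(2, 10))
z = top_k_activation(z_pre, k=3)  # each row has at most 3 non-zero entries
```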
We strongly encourage you to check the original paper 1 to learn more about TopK SAE.
Basic Usage
```python
from overcomplete import TopKSAE

# define a TopK SAE with input dimension 768 and 10k concepts
sae = TopKSAE(768, 10_000, top_k=5)

# choose a registered encoder module by name
# (you can also pass your own encoder module directly)
sae = TopKSAE(768, 10_000, top_k=10,
              encoder_module='mlp_bn_1')
```
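With a model defined, a forward pass is just a call on a batch of activations. The sketch below assumes a (batch, 768) input and that the output unpacks in the documented order (pre-codes, codes, reconstruction):

```python
import torch

x = torch.randn(64, 768)  # dummy batch of activations
z_pre, z, x_hat = sae(x)  # pre-codes, codes, reconstruction
print(z.shape)            # torch.Size([64, 10000]), at most top_k non-zeros per row
```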
TopKSAE
Top-k Sparse SAE.
__init__(self, input_shape, nb_concepts, top_k=None, encoder_module=None, dictionary_params=None, device='cpu')
Parameters

- input_shape : int or tuple of int
  Dimensionality of the input data, excluding the batch dimension.
  It is usually 1d (dim), 2d (seq length, dim) or 3d (dim, height, width).
- nb_concepts : int
  Number of components/concepts in the dictionary. The dictionary is overcomplete if nb_concepts is greater than the input dimension.
- top_k : int, optional
  Number of top activations to keep in the latent representation, by default nb_concepts // 10 (a sparsity of 90%).
- encoder_module : nn.Module or string, optional
  Custom encoder module, by default None.
  If None, a simple Linear + BatchNorm default encoder is used.
  If a string, the name of a registered encoder module.
- dictionary_params : dict, optional
  Parameters passed to the dictionary layer. See DictionaryLayer for more details.
- device : str, optional
  Device to run the model on, by default 'cpu'.
forward(self, x)
Perform a forward pass through the autoencoder.
Parameters

- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).

Return

- SAEOutput
  The pre_codes (z_pre), codes (z) and reconstructed input tensor (x_hat).
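For example, the reconstruction error can be computed directly from the forward outputs; a sketch reusing the sae from Basic Usage and assuming the output unpacks in the documented order:

```python
import torch

x = torch.randn(32, 768)
z_pre, z, x_hat = sae(x)                       # unpacks in the documented order
loss = torch.nn.functional.mse_loss(x_hat, x)  # reconstruction error
```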
encode(self, x)
Encode input data to latent representation.
Parameters

- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).

Return

- pre_codes : torch.Tensor
  Pre-codes tensor (z_pre) of shape (batch_size, nb_concepts), before the ReLU and top-k operations.
- z : torch.Tensor
  Codes, the latent representation tensor (z) of shape (batch_size, nb_concepts).
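A quick sanity check on the sparsity of the codes; a sketch reusing the sae from Basic Usage:

```python
import torch

x = torch.randn(8, 768)
pre_codes, z = sae.encode(x)
# each row of z should contain at most top_k non-zero entries
print((z != 0).sum(dim=-1).max())
```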
decode(self, z)
Decode latent representation to reconstruct input data.
Parameters

- z : torch.Tensor
  Latent representation tensor of shape (batch_size, nb_concepts).

Return

- torch.Tensor
  Reconstructed input tensor of shape (batch_size, input_size).
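Codes obtained from encode can be passed back through decode; a sketch reusing the sae from Basic Usage:

```python
import torch

x = torch.randn(8, 768)
pre_codes, z = sae.encode(x)
x_hat = sae.decode(z)
print(x_hat.shape)  # torch.Size([8, 768]), matches x
```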
get_dictionary(self)
Return the learned dictionary.
Return

- torch.Tensor
  Learned dictionary tensor of shape (nb_concepts, input_size).
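For instance, with the model defined in Basic Usage, the dictionary has one row per concept; a short sketch:

```python
D = sae.get_dictionary()
print(D.shape)  # (10000, 768): one row per concept, one column per input dimension
```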
1. Scaling and Evaluating Sparse Autoencoders by Gao et al. (2024).