JumpReLU SAE¶
JumpReLU SAE is a variation of the standard Sparse Autoencoder (SAE) that incorporates the JumpReLU activation function to achieve adaptive sparsity without shrinkage. It does so through a learnable threshold on each concept. Like all SAEs, it includes an encoder, a decoder, and a forward method.
- encode returns the pre-codes (z_pre, before ReLU) and codes (z) given an input (x).
- decode returns a reconstructed input (x_hat) given codes (z).
- forward returns the pre-codes, codes, and reconstructed input.
The specificity of this architecture is that it introduces two hyperparameters, a kernel (by default 'silverman') and a bandwidth (by default 1e-3). We strongly encourage you to check the original paper 1 to learn more about JumpReLU.
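For intuition, here is a minimal sketch of the JumpReLU rule described in the paper (an illustration, not the library's internal implementation): each concept's pre-activation is kept only if it clears a learnable per-concept threshold, and the kernel/bandwidth pair only enters the straight-through gradient estimate for that threshold.

```python
import torch

def jumprelu(z_pre, threshold):
    # Forward rule: keep a pre-activation only where it exceeds its
    # learnable per-concept threshold, otherwise output exactly zero.
    # During training, the `kernel` and `bandwidth` hyperparameters serve
    # to estimate gradients through this hard gate (straight-through style).
    return z_pre * (z_pre > threshold).float()
```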
Basic Usage¶
```python
from overcomplete import JumpSAE

# define a JumpReLU SAE with input dimension 768 and 10k concepts
sae = JumpSAE(768, 10_000)

# adjust kernel and bandwidth
sae = JumpSAE(768, 10_000, kernel='silverman', bandwidth=1e-2)
```
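A minimal end-to-end pass, using the shapes documented below (the random batch is purely illustrative):

```python
import torch

x = torch.randn(64, 768)          # a batch of 64 activation vectors
pre_codes, codes = sae.encode(x)  # codes are sparse after the jump
x_hat = sae.decode(codes)         # reconstruction of shape (64, 768)
```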
JumpSAE¶
JumpReLU Sparse Autoencoder (SAE).
__init__(self,
input_shape,
nb_concepts,
kernel='silverman',
bandwidth=0.001,
encoder_module=None,
dictionary_params=None,
device='cpu')¶
Parameters
- input_shape : int or tuple of int
  Dimensionality of the input data, not including the batch dimension.
  It is usually 1d (dim), 2d (seq_length, dim) or 3d (dim, height, width).
- nb_concepts : int
  Number of components/concepts in the dictionary. The dictionary is overcomplete if the number of concepts > input dimensions.
- kernel : str, optional
  Kernel function to use in the JumpReLU activation, by default 'silverman'.
  Current options are: 'rectangle', 'gaussian', 'triangular', 'cosine', 'epanechnikov', 'quartic', 'silverman', 'cauchy'.
- bandwidth : float, optional
  Bandwidth of the kernel, by default 1e-3.
- encoder_module : nn.Module or string, optional
  Custom encoder module, by default None.
  If None, a simple Linear + BatchNorm default encoder is used.
  If a string, the name of a registered encoder module.
- dictionary_params : dict, optional
  Parameters that will be passed to the dictionary layer.
  See DictionaryLayer for more details.
- device : str, optional
  Device to run the model on, by default 'cpu'.
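As a sketch of the options above, input_shape can also be a tuple and the kernel swapped for any of the listed options (the sequence shape here is illustrative):

```python
# a (seq_length, dim) activation map with a Gaussian kernel
sae_seq = JumpSAE((196, 768), 10_000, kernel='gaussian', bandwidth=1e-3)
```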
decode(self,
z)¶
Decode latent representation to reconstruct input data.
Parameters
- z : torch.Tensor
  Latent representation tensor of shape (batch_size, nb_components).
Return
- torch.Tensor
  Reconstructed input tensor of shape (batch_size, input_size).
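One way to use decode in isolation, assuming the shapes above, is to probe a single concept's direction in input space (an inspection snippet, not part of the API):

```python
z = torch.zeros(1, 10_000)
z[0, 123] = 1.0            # activate only concept 123
direction = sae.decode(z)  # its contribution in input space, shape (1, 768)
```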
tied(self,
bias=False)¶
Tie encoder weights to dictionary (use D^T as encoder).
Parameters
- bias : bool, optional
  Whether to include bias in encoder, by default False.
Return
- self
  Returns self for method chaining.
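Since tied returns self, it chains directly at construction (usage sketch):

```python
sae = JumpSAE(768, 10_000).tied()  # encoder now reuses D^T
```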
untied(self,
bias=False,
copy_from_dictionary=True)¶
Create a new encoder with weights from the current dictionary (or random init).
Parameters
- bias : bool, optional
  Whether to include bias in encoder, by default False.
- copy_from_dictionary : bool, optional
  If True, initialize encoder with current dictionary weights, by default True.
Return
- self
  Returns self for method chaining.
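Conversely, a separate encoder initialized from the current dictionary can be created afterwards (usage sketch):

```python
sae = sae.untied(bias=True, copy_from_dictionary=True)
```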
get_dictionary(self)¶
Return the learned dictionary.
Return
- torch.Tensor
  Learned dictionary tensor of shape (nb_components, input_size).
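For example, with the 10k-concept model above, the returned tensor can be inspected directly:

```python
D = sae.get_dictionary()
print(D.shape)  # torch.Size([10000, 768])
```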
forward(self,
x)¶
Perform a forward pass through the autoencoder.
Parameters
- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).
Return
- SAEOutput
  Returns the pre_codes (z_pre), codes (z) and reconstructed input tensor (x_hat).
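A sketch of a forward call, reusing x from the usage example above; we assume here that the returned SAEOutput unpacks into the three documented tensors (the exact field names are not guaranteed):

```python
z_pre, z, x_hat = sae(x)  # assumption: SAEOutput behaves like a 3-tuple
```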
encode(self,
x)¶
Encode input data to latent representation.
Parameters
- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).
Return
- pre_codes : torch.Tensor
  Pre-codes tensor of shape (batch_size, nb_components) before the jump operation.
- codes : torch.Tensor
  Codes, latent representation tensor (z) of shape (batch_size, nb_components).
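As a quick check of the documented returns, the fraction of inactive concepts can be measured from the codes (illustrative only):

```python
pre_codes, codes = sae.encode(x)
sparsity = (codes == 0).float().mean().item()  # fraction of zeroed concepts
```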
Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders, Rajamanoharan et al. (2024). ↩