
JumpReLU SAE

JumpReLU SAE is a variation of the standard Sparse Autoencoder (SAE) that uses the JumpReLU activation function to achieve adaptive sparsity without shrinkage. This relies on a learnable threshold for each concept. Like all SAEs, it includes an encoder, a decoder, and a forward method.

  • encode returns the pre-codes (z_pre, before ReLU) and codes (z) given an input (x).
  • decode returns a reconstructed input (x_hat) given the codes (z).
  • forward returns the pre-codes, codes, and reconstructed input.
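To make the "no shrinkage" point concrete, here is an illustrative sketch in plain PyTorch (not overcomplete's implementation; in the real model the threshold is learned per concept): JumpReLU zeroes pre-codes below a threshold and passes the survivors through unchanged, unlike L1-penalized ReLU codes, which get shrunk toward zero.

```python
import torch

def jumprelu(z_pre, theta):
    # keep values strictly above the threshold unchanged, zero the rest:
    # JumpReLU(z) = z * 1[z > theta] -- surviving values are NOT shrunk
    return torch.where(z_pre > theta, z_pre, torch.zeros_like(z_pre))

z_pre = torch.tensor([-0.5, 0.05, 0.25, 1.0])
codes = jumprelu(z_pre, torch.tensor(0.1))
# only 0.25 and 1.0 survive, with their values intact
```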


The specificity of this architecture is that it exposes two hyperparameters: a bandwidth and a kernel. We strongly encourage you to check the original paper 1 to learn more about JumpReLU.
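For intuition on what these two hyperparameters do, here is a hypothetical sketch (plain PyTorch, not the library's internals): the hard threshold has zero gradient almost everywhere, so training uses a straight-through estimator in which a kernel K and bandwidth eps replace the Dirac spike with the pseudo-derivative -(theta / eps) * K((z - theta) / eps), following the idea in the JumpReLU paper. The rectangle kernel below is one of the listed options.

```python
import torch

class JumpReLUSTE(torch.autograd.Function):
    # Illustrative straight-through estimator (an assumption about how the
    # kernel and bandwidth are used, not overcomplete's actual code).
    @staticmethod
    def forward(ctx, z, theta, bandwidth):
        ctx.save_for_backward(z, theta)
        ctx.bandwidth = bandwidth
        # hard thresholding in the forward pass
        return torch.where(z > theta, z, torch.zeros_like(z))

    @staticmethod
    def backward(ctx, grad_out):
        z, theta = ctx.saved_tensors
        eps = ctx.bandwidth
        u = (z - theta) / eps
        K = (u.abs() <= 0.5).float()              # rectangle kernel
        grad_z = grad_out * (z > theta).float()   # identity where active
        grad_theta = (grad_out * (-(theta / eps)) * K).sum()
        return grad_z, grad_theta, None

z = torch.tensor([0.5, 5.0])
theta = torch.tensor(0.4, requires_grad=True)
out = JumpReLUSTE.apply(z, theta, 0.3)
out.sum().backward()
# theta receives gradient only from pre-codes within ~eps of the threshold
```

A smaller bandwidth makes the gradient estimate sharper but noisier (fewer pre-codes fall inside the kernel's support), which is why it is exposed as a tunable hyperparameter.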

Basic Usage

from overcomplete import JumpSAE

# define a JumpReLU SAE with input dimension 768 and 10k concepts
sae = JumpSAE(768, 10_000)

# adjust kernel and bandwidth
sae = JumpSAE(768, 10_000, bandwidth=1e-2,
              kernel='silverman')

JumpSAE

JumpReLU Sparse Autoencoder (SAE).

__init__(self,
         input_shape,
         nb_concepts,
         kernel='silverman',
         bandwidth=0.001,
         encoder_module=None,
         dictionary_params=None,
         device='cpu')

Parameters

  • input_shape : int or tuple of int

    • Dimensionality of the input data, excluding batch dimensions.

      It is usually 1d (dim), 2d (seq length, dim) or 3d (dim, height, width).

  • nb_concepts : int

    • Number of components/concepts in the dictionary. The dictionary is overcomplete if the number of concepts exceeds the input dimensionality.

  • kernel : str, optional

    • Kernel function to use in the JumpReLU activation, by default 'silverman'.

      Current options are: 'rectangle', 'gaussian', 'triangular', 'cosine', 'epanechnikov', 'quartic', 'silverman', 'cauchy'.

  • bandwidth : float, optional

    • Bandwidth of the kernel, by default 1e-3.

  • encoder_module : nn.Module or string, optional

    • Custom encoder module, by default None.

      If None, a simple Linear + BatchNorm default encoder is used.

      If string, the name of the registered encoder module.

  • dictionary_params : dict, optional

    • Parameters that will be passed to the dictionary layer.

      See DictionaryLayer for more details.

  • device : str, optional

    • Device to run the model on, by default 'cpu'.

decode(self,
       z)

Decode latent representation to reconstruct input data.

Parameters

  • z : torch.Tensor

    • Latent representation tensor of shape (batch_size, nb_components).

Return

  • torch.Tensor

    • Reconstructed input tensor of shape (batch_size, input_size).
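Assuming a dictionary D of shape (nb_components, input_size), as returned by get_dictionary, decoding amounts to a linear combination of dictionary rows weighted by the codes. A minimal stand-in sketch (not the library code):

```python
import torch

nb_concepts, input_size = 10_000, 768
D = torch.randn(nb_concepts, input_size)   # stand-in for the learned dictionary
z = torch.zeros(4, nb_concepts)            # sparse codes for a batch of 4
z[:, 0] = 1.0                              # activate a single concept

x_hat = z @ D                              # (batch_size, input_size)
# each reconstruction is the weighted sum of the active dictionary rows
```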


tied(self,
     bias=False)

Tie encoder weights to dictionary (use D^T as encoder).

Parameters

  • bias : bool, optional

    • Whether to include bias in encoder, by default False.

Return

  • self

    • Returns self for method chaining.
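Tying means the encoder reuses the dictionary transpose as its weight rather than learning a separate matrix. Roughly (a sketch with shapes assumed from get_dictionary, not the library's internals):

```python
import torch
import torch.nn as nn

input_size, nb_concepts = 768, 10_000
D = torch.randn(nb_concepts, input_size)   # stand-in for the dictionary

# nn.Linear stores weights as (out_features, in_features) = D's shape,
# so copying D makes the encoder compute x @ D^T
encoder = nn.Linear(input_size, nb_concepts, bias=False)
with torch.no_grad():
    encoder.weight.copy_(D)

x = torch.randn(2, input_size)
pre_codes = encoder(x)                     # equivalent to x @ D.T
```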


untied(self,
       bias=False,
       copy_from_dictionary=True)

Create a new encoder with weights from the current dictionary (or random initialization).

Parameters

  • bias : bool, optional

    • Whether to include bias in encoder, by default False.

  • copy_from_dictionary : bool, optional

    • If True, initialize encoder with current dictionary weights, by default True.

Return

  • self

    • Returns self for method chaining.


get_dictionary(self)

Return the learned dictionary.

Return

  • torch.Tensor

    • Learned dictionary tensor of shape (nb_components, input_size).


forward(self,
        x)

Perform a forward pass through the autoencoder.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • SAEOuput

    • Return the pre_codes (z_pre), codes (z) and reconstructed input tensor (x_hat).


encode(self,
       x)

Encode input data to latent representation.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • pre_codes : torch.Tensor

    • Pre-codes tensor of shape (batch_size, nb_components) before the jump operation.

  • codes : torch.Tensor

    • Codes, latent representation tensor (z) of shape (batch_size, nb_components).
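To tie the pieces together, here is a toy stand-in module (plain PyTorch; names and structure are illustrative assumptions, NOT the overcomplete implementation) that makes the shapes and the pre_codes/codes distinction concrete:

```python
import torch
import torch.nn as nn

class ToyJumpSAE(nn.Module):
    # Minimal illustrative stand-in: linear encoder, per-concept learnable
    # threshold, linear decoder through an explicit dictionary.
    def __init__(self, input_size, nb_concepts):
        super().__init__()
        self.encoder = nn.Linear(input_size, nb_concepts)
        self.theta = nn.Parameter(torch.full((nb_concepts,), 0.1))
        self.dictionary = nn.Parameter(torch.randn(nb_concepts, input_size))

    def encode(self, x):
        pre_codes = self.encoder(x)                       # before the jump
        codes = torch.where(pre_codes > self.theta,
                            pre_codes,
                            torch.zeros_like(pre_codes))  # after the jump
        return pre_codes, codes

    def decode(self, z):
        return z @ self.dictionary                        # (batch, input_size)

    def forward(self, x):
        pre_codes, codes = self.encode(x)
        return pre_codes, codes, self.decode(codes)

sae = ToyJumpSAE(768, 1_000)
x = torch.randn(8, 768)
pre_codes, codes, x_hat = sae(x)
# codes are sparse: every nonzero entry sits above its threshold
```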