Vanilla SAE
The vanilla SAE is the most basic sparse autoencoder. It consists of an encoder and a decoder. The decoder is simply a dictionary, while the encoder is configurable: by default, it is a linear module with bias followed by a ReLU activation. All SAEs expose an encode, a decode, and a forward method:
- encode returns the pre-codes (z_pre, before ReLU) and the codes (z) given an input (x).
- decode returns a reconstructed input (x_hat) given the codes (z).
- forward returns the pre-codes, codes, and reconstructed input.
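To make the relationship between the three methods concrete, here is a minimal NumPy sketch of the default configuration described above (linear encoder with bias, ReLU, dictionary decoder). `TinySAE` and its attributes `W`, `b`, and `D` are hypothetical names used only for illustration; this is not the library's implementation, which is a PyTorch module.

```python
import numpy as np


class TinySAE:
    """Illustrative NumPy stand-in, NOT the overcomplete implementation."""

    def __init__(self, input_size, nb_concepts, seed=0):
        rng = np.random.default_rng(seed)
        # Encoder: a linear map with bias (ReLU is applied in encode).
        self.W = rng.normal(scale=0.1, size=(input_size, nb_concepts))
        self.b = np.zeros(nb_concepts)
        # Decoder: a dictionary with one row per concept.
        self.D = rng.normal(scale=0.1, size=(nb_concepts, input_size))

    def encode(self, x):
        z_pre = x @ self.W + self.b   # pre-codes, before ReLU
        z = np.maximum(z_pre, 0.0)    # codes, after ReLU
        return z_pre, z

    def decode(self, z):
        return z @ self.D             # reconstructed input x_hat

    def forward(self, x):
        z_pre, z = self.encode(x)
        return z_pre, z, self.decode(z)


sae_sketch = TinySAE(input_size=8, nb_concepts=32)
x = np.random.default_rng(1).normal(size=(4, 8))
z_pre, z, x_hat = sae_sketch.forward(x)
print(z_pre.shape, z.shape, x_hat.shape)  # (4, 32) (4, 32) (4, 8)
```

The codes have one entry per concept and are non-negative because of the ReLU, while the reconstruction has the same shape as the input.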
Basic Usage
from overcomplete import SAE
# Define a basic SAE where input dimension is 768, with 10k concepts
# Using a simple linear encoding
sae = SAE(768, 10_000)
# Define a more complex SAE with batch normalization in the encoder
# The dictionary is normalized on the L1 ball instead of L2
sae = SAE(768, 10_000, encoder_module='mlp_bn_1',
          dictionary_params={'normalization': 'l1'})
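The difference between the two normalization options can be illustrated outside the library. The sketch below assumes that normalizing the dictionary means rescaling each atom (row) to unit L1 or L2 norm; the actual behavior of the library's dictionary layer may differ (for example, it could project onto the ball rather than rescale). `normalize_atoms` is a hypothetical helper, not part of overcomplete.

```python
import numpy as np


def normalize_atoms(dictionary, normalization="l2", eps=1e-12):
    """Rescale each row (atom) of `dictionary` to unit L1 or L2 norm.

    Hypothetical helper for illustration only.
    """
    if normalization == "l2":
        norms = np.linalg.norm(dictionary, ord=2, axis=1, keepdims=True)
    elif normalization == "l1":
        norms = np.linalg.norm(dictionary, ord=1, axis=1, keepdims=True)
    else:
        raise ValueError(f"unknown normalization: {normalization!r}")
    # Guard against division by zero for all-zero atoms.
    return dictionary / np.maximum(norms, eps)


atoms = np.array([[3.0, 4.0], [1.0, -1.0]])
atoms_l2 = normalize_atoms(atoms, "l2")  # rows have Euclidean norm 1
atoms_l1 = normalize_atoms(atoms, "l1")  # rows have absolute values summing to 1
```

Under this assumption, the first atom becomes (0.6, 0.8) with L2 normalization and (3/7, 4/7) with L1 normalization.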
SAE
Sparse Autoencoder (SAE) model for dictionary learning.
__init__(self,
         input_shape,
         nb_concepts,
         encoder_module=None,
         dictionary_params=None,
         device='cpu')
Parameters
- input_shape : int or tuple of int - Dimensionality of the input data, excluding the batch dimension. It is usually 1D (dim), 2D (seq length, dim) or 3D (dim, height, width).
- nb_concepts : int - Number of components/concepts in the dictionary. The dictionary is overcomplete if the number of concepts exceeds the input dimension.
- encoder_module : nn.Module or str, optional - Custom encoder module, by default None. If None, a simple Linear + BatchNorm default encoder is used. If a string, the name of a registered encoder module.
- dictionary_params : dict, optional - Parameters that will be passed to the dictionary layer. See DictionaryLayer for more details.
- device : str, optional - Device to run the model on, by default 'cpu'.
forward(self, x)
Perform a forward pass through the autoencoder.
Parameters
- x : torch.Tensor - Input tensor of shape (batch_size, input_size).
Return
- SAEOutput - The pre-codes (z_pre), codes (z) and reconstructed input tensor (x_hat).
encode(self, x)
Encode input data to latent representation.
Parameters
- x : torch.Tensor - Input tensor of shape (batch_size, input_size).
Return
- pre_codes : torch.Tensor - Pre-codes tensor (z_pre) of shape (batch_size, nb_concepts), before the activation function.
- codes : torch.Tensor - Codes, the latent representation tensor (z) of shape (batch_size, nb_concepts).
decode(self, z)
Decode latent representation to reconstruct input data.
Parameters
- z : torch.Tensor - Latent representation tensor of shape (batch_size, nb_concepts).
Return
- torch.Tensor - Reconstructed input tensor (x_hat) of shape (batch_size, input_size).
get_dictionary(self)
Return the learned dictionary.
Return
- torch.Tensor - Learned dictionary tensor of shape (nb_concepts, input_size).
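The dictionary shape suggests a useful way to read it: each row is one concept expressed in input space. The NumPy sketch below assumes the decoder is a plain linear map (x_hat = z @ D, with no bias or extra processing), which may not hold for every configuration of the library; `decode_sketch` and the toy `dictionary` are hypothetical stand-ins for decode and the tensor returned by get_dictionary.

```python
import numpy as np

nb_concepts, input_size = 5, 3

# Stand-in for the tensor returned by get_dictionary(): one row per concept.
dictionary = np.arange(nb_concepts * input_size, dtype=float).reshape(
    nb_concepts, input_size
)


def decode_sketch(z, dictionary):
    # Assumed linear decoder: x_hat = z @ D, no bias.
    return z @ dictionary


# A batch of one code that activates only concept 2.
z = np.zeros((1, nb_concepts))
z[0, 2] = 1.0

x_hat = decode_sketch(z, dictionary)
print(x_hat)  # [[6. 7. 8.]]
```

Under this assumption, decoding a one-hot code returns the corresponding dictionary row, so a reconstruction is a weighted sum of the rows selected by the non-zero codes.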