Orthogonal Matching Pursuit SAE (OMPSAE)

OMPSAE uses orthogonal matching pursuit (OMP) for sparse coding. Each iteration selects the atom most correlated with the current residual, then re-solves a non-negative least squares (NNLS) problem over all atoms selected so far to refine the codes. Because every selected coefficient is refit jointly, reconstruction is often better than with plain matching pursuit, which only updates the coefficient of the newest atom. For background, see Sparse Autoencoders via Matching Pursuit.
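The loop described above can be sketched without any dependencies. This is an illustration of the technique, not the library's implementation: the helper names (dot, reconstruct, nnls_refit, omp_encode), the projected-gradient NNLS solver, and the step size lr are all assumptions, and the sketch encodes a single vector stored in plain Python lists rather than a batched tensor.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def reconstruct(coeffs, atoms, dim):
    """Weighted sum of the selected atoms."""
    return [sum(c * a[i] for c, a in zip(coeffs, atoms)) for i in range(dim)]


def nnls_refit(x, atoms, coeffs, max_iter, lr):
    """Projected gradient descent on ||x - sum_j c_j * atoms[j]||^2 with c >= 0."""
    c = list(coeffs)
    for _ in range(max_iter):
        recon = reconstruct(c, atoms, len(x))
        resid = [xi - ri for xi, ri in zip(x, recon)]
        # Gradient step on every coefficient, then clamp at zero (the projection).
        c = [max(0.0, cj + lr * dot(a, resid)) for cj, a in zip(c, atoms)]
    return c


def omp_encode(x, dictionary, k, max_iter=10, lr=0.5):
    """Encode one vector; returns (residual, codes).

    lr=0.5 is an assumed step size that suits unit-norm, weakly correlated atoms.
    """
    selected, coeffs = [], []
    residual = list(x)
    for _ in range(k):
        # 1) Select the atom most correlated with the current residual.
        scores = [dot(atom, residual) for atom in dictionary]
        best = max(range(len(dictionary)), key=lambda j: scores[j])
        if best not in selected:
            selected.append(best)
            coeffs.append(0.0)
        # 2) Refit all selected coefficients jointly under the non-negativity constraint.
        atoms = [dictionary[j] for j in selected]
        coeffs = nnls_refit(x, atoms, coeffs, max_iter, lr)
        # 3) Recompute the residual from the refitted reconstruction.
        recon = reconstruct(coeffs, atoms, len(x))
        residual = [xi - ri for xi, ri in zip(x, recon)]
    # Scatter the selected coefficients into a dense code vector.
    codes = [0.0] * len(dictionary)
    for j, c in zip(selected, coeffs):
        codes[j] = c
    return residual, codes


# Toy dictionary of three orthonormal atoms.
D = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
residual, codes = omp_encode([2.0, 0.0, 1.0], D, k=2, max_iter=30)
```

The signed correlation is used for selection (rather than its absolute value) because the refit only allows non-negative coefficients, so an anti-correlated atom could not reduce the residual.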

Basic Usage

import torch
from overcomplete.sae import OMPSAE

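# Fall back to CPU when no GPU is available, and keep data and model on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(64, 512, device=device)  # (batch_size, input_dim)
sae = OMPSAE(
    input_shape=512,
    nb_concepts=4_096,
    k=4,            # pursuit steps
    max_iter=15,    # NNLS iterations
    dropout=0.1,    # optional atom dropout
    encoder_module="identity",
    device=device
)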

residual, codes = sae.encode(x)

Notes:

  • encode returns (residual, codes); residual is the reconstruction error remaining after the pursuit steps.

  • Set dropout to randomly mask atoms at each iteration.

  • Inputs must be 1D features (no 3D/4D tensors).

  • k and max_iter must be positive.
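The input constraints in these notes can be checked before calling encode. The sketch below is a hypothetical pre-flight helper, not part of the library: the function names are invented for illustration, and it assumes activations are stored with the feature dimension last, so a 3D or 4D activation shape collapses to (rows, features).

```python
from math import prod


def as_feature_rows(shape):
    """Map an activation shape to a 2D (rows, features) shape, assuming features come last."""
    if len(shape) < 2:
        raise ValueError(f"expected a batch dimension, got shape {tuple(shape)}")
    return (prod(shape[:-1]), shape[-1])


def check_pursuit_args(k, max_iter):
    """Reject non-positive iteration counts before encoding."""
    if k <= 0 or max_iter <= 0:
        raise ValueError("k and max_iter must be positive")


# A (batch, tokens, features) activation becomes one row per token.
flat_shape = as_feature_rows((8, 196, 512))
check_pursuit_args(k=4, max_iter=15)
```

With a torch tensor whose last dimension holds the features, x.reshape(-1, x.shape[-1]) performs the equivalent flattening.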

OMPSAE

Orthogonal Matching Pursuit Sparse Autoencoder (OMPSAE).

__init__(self,
         input_shape,
         nb_concepts,
         k=1,
         dropout=None,
         encoder_module='identity',
         dictionary_params=None,
         device='cpu',
         max_iter=10)

Parameters

  • input_shape : int or tuple of int

    • Dimensionality of the input data (excluding batch dimension).

  • nb_concepts : int

    • Number of latent components (atoms) in the dictionary.

  • k : int, optional

    • Default number of pursuit iterations (must be > 0; default: 1).

  • dropout : float, optional

    • Dropout rate applied to dictionary atoms (range [0.0, 1.0]).

  • encoder_module : str or nn.Module, optional

    • Encoder module or name of registered encoder.

  • dictionary_params : dict, optional

    • Parameters passed to the dictionary layer.

  • device : str, optional

    • Device to run the model on (default is 'cpu').

  • max_iter : int, optional

    • Default number of NNLS iterations (default: 10).

decode(self,
       z)

Decode latent representation to reconstruct input data.

Parameters

  • z : torch.Tensor

    • Latent representation tensor of shape (batch_size, nb_concepts).

Return

  • torch.Tensor

    • Reconstructed input tensor of shape (batch_size, input_size).
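As a rough picture of what decoding computes, the sketch below assumes the dictionary acts as a plain linear map, so each reconstruction is a code-weighted sum of atoms. The library's dictionary layer may apply normalization or other steps not shown here, and the standalone decode function is illustrative only.

```python
def decode(z, dictionary):
    """Reconstruct each row of z as a code-weighted sum of dictionary atoms."""
    input_size = len(dictionary[0])
    return [
        [sum(code * atom[i] for code, atom in zip(row, dictionary)) for i in range(input_size)]
        for row in z
    ]


dictionary = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 atoms, input_size = 2
z = [[2.0, 0.0, 0.5]]                              # one sample, 3 codes
x_hat = decode(z, dictionary)                      # shape (1, 2)
```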


tied(self,
     bias=False)

Tie encoder weights to dictionary (use D^T as encoder).

Parameters

  • bias : bool, optional

    • Whether to include bias in encoder, by default False.

Return

  • self

    • Returns self for method chaining.


untied(self,
       bias=False,
       copy_from_dictionary=True)

Create a new, untied encoder whose weights are copied from the current dictionary (or randomly initialized).

Parameters

  • bias : bool, optional

    • Whether to include bias in encoder, by default False.

  • copy_from_dictionary : bool, optional

    • If True, initialize encoder with current dictionary weights, by default True.

Return

  • self

    • Returns self for method chaining.


get_dictionary(self)

Return the learned dictionary.

Return

  • torch.Tensor

    • Learned dictionary tensor of shape (nb_concepts, input_size).


forward(self,
        x)

Perform a forward pass through the autoencoder.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • SAEOutput

    • The pre-codes (z_pre), codes (z), and reconstructed input tensor (x_hat).


train(self,
      mode=True)

Hook called when switching between training and evaluation mode. We use it to ensure no dropout is applied during evaluation.

Parameters

  • mode : bool, optional

    • Whether to set the model to training mode or not, by default True.
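One way to picture the train/eval distinction is a keep-mask over atoms that is only randomized in training mode. The sketch below is an assumption about how such masking could work, not the library's code: atom_keep_mask and masked_scores are hypothetical helpers, and giving dropped atoms a score of -inf is just one simple way to exclude them from selection.

```python
import random


def atom_keep_mask(nb_concepts, dropout, training, rng=random):
    """Return a keep-mask over atoms; every atom is kept in eval mode or when dropout is unset."""
    if not training or not dropout:
        return [True] * nb_concepts
    # Each atom is dropped independently with probability `dropout`.
    return [rng.random() >= dropout for _ in range(nb_concepts)]


def masked_scores(scores, keep):
    """Give dropped atoms a score of -inf so a pursuit step cannot select them."""
    return [s if kept else float("-inf") for s, kept in zip(scores, keep)]


eval_mask = atom_keep_mask(5, dropout=0.9, training=False)
train_mask = atom_keep_mask(5, dropout=0.1, training=True, rng=random.Random(0))
```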


encode(self,
       x,
       k=None,
       max_iter=None)

Encode input using Orthogonal Matching Pursuit.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_dim).

  • k : int, optional

    • Override the number of pursuit iterations.

  • max_iter : int, optional

    • Override the number of NNLS iterations.

Return

  • residual : torch.Tensor

    • Final residual after k iterations.

  • codes : torch.Tensor

    • Sparse codes of shape (batch_size, nb_concepts).
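Because each pursuit step selects at most one new atom, a row of codes should contain at most k non-zero entries, which gives a cheap sanity check on the output. The sketch below works on plain nested lists; active_counts is a hypothetical helper name, and the example codes are made up for illustration.

```python
def active_counts(codes):
    """Number of non-zero codes in each row."""
    return [sum(1 for c in row if c != 0.0) for row in codes]


k = 2
codes = [
    [0.0, 1.2, 0.0, 0.3],
    [0.0, 0.0, 0.0, 2.0],
]
counts = active_counts(codes)
within_budget = all(n <= k for n in counts)
```

For a torch tensor, (codes != 0).sum(dim=1) gives the same per-row counts.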