
Matching Pursuit SAE (MpSAE)

MpSAE replaces thresholding with a greedy matching pursuit loop: at each step it picks the atom most correlated with the residual, adds its coefficient to the codes, and subtracts the atom's contribution from the residual. The resulting sparse codes track reconstruction progress step by step. We encourage reading [1] for the full method.
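The loop described above can be sketched in NumPy. This is textbook matching pursuit under stated assumptions, not MpSAE's implementation: atoms are the rows of a unit-norm dictionary `D`, the atom is selected by absolute correlation, and codes may be negative. MpSAE's own selection rule, normalization, and dropout are not reproduced, and `matching_pursuit` is a hypothetical name.

```python
import numpy as np


def matching_pursuit(x, D, k):
    """Textbook matching pursuit (sketch, not the MpSAE implementation).

    x: inputs of shape (n, d); D: dictionary of shape (c, d) with unit-norm rows.
    Returns (residual, codes), with codes of shape (n, c).
    """
    residual = np.array(x, dtype=float)          # copy so the input is untouched
    codes = np.zeros((residual.shape[0], D.shape[0]))
    rows = np.arange(residual.shape[0])
    for _ in range(k):
        corr = residual @ D.T                    # (n, c) correlation with each atom
        idx = np.argmax(np.abs(corr), axis=1)    # most correlated atom per sample
        coef = corr[rows, idx]                   # its coefficient
        codes[rows, idx] += coef                 # an atom can be picked more than once
        residual -= coef[:, None] * D[idx]       # remove its contribution
    return residual, codes


rng = np.random.default_rng(0)
D = rng.normal(size=(64, 16))
D /= np.linalg.norm(D, axis=1, keepdims=True)    # unit-norm atoms
x = rng.normal(size=(8, 16))
residual, codes = matching_pursuit(x, D, k=4)
print((codes != 0).sum(axis=1))                  # at most 4 active codes per row
```

Because each step subtracts exactly what it adds to the codes, `codes @ D + residual` equals the input, and with unit-norm atoms the residual norm can only decrease from one step to the next.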

Basic Usage

import torch
from overcomplete import MpSAE

# define a Matching Pursuit SAE with input dimension 512 and 4,096 concepts
# k = number of pursuit steps; dropout optionally masks atoms at each step
sae = MpSAE(512, 4_096, k=4, dropout=0.1)

x = torch.randn(32, 512)  # a batch of 32 input vectors
residual, codes = sae.encode(x)

Advanced: auxiliary loss to revive dead codes

To ensure high dictionary utilization in MpSAE, we strongly recommend adding an auxiliary loss term that revives dead codes. Here is an example of such a loss:

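def criterion(x, x_hat, residual, z, d):
    # x: input [n, d_in], x_hat: reconstruction [n, d_in]
    # residual: matching pursuit residual [n, d_in]
    # z: codes [n, c], d: dictionary [c, d_in]
    recon_loss = ((x - x_hat) ** 2).mean()

    # atoms whose code never exceeds 1e-2 in this batch are treated as dead
    revive_mask = (z.amax(dim=0) < 1e-2).detach()  # shape: [c]

    # only add the revival term when more than 10 atoms are dead
    if revive_mask.sum() > 10:
        projected = residual @ d.T  # shape: [n, c], atom/residual correlations
        revive_term = projected[:, revive_mask].mean()
        # subtracting the term rewards dead atoms that align with the residual
        recon_loss = recon_loss - revive_term * 1e-2

    return recon_loss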
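Why subtracting the revive term helps can be seen from its gradient. Write $r_n$ for the residual of sample $n$, $d_j$ for row $j$ of the dictionary `d`, $\mathcal{D}$ for the set of dead atoms selected by `revive_mask`, $N$ for the batch size, $\lambda = 10^{-2}$ for the weight, and $\eta$ for a plain SGD learning rate (an assumption; adaptive optimizers rescale the step). The derivation below keeps only the explicit dependence on $d_j$ and ignores the path through the residual, which itself depends on the dictionary.

```latex
\begin{aligned}
\mathcal{R} &= \frac{1}{N\,|\mathcal{D}|}\sum_{n=1}^{N}\sum_{j\in\mathcal{D}} r_n^\top d_j,
\qquad
\mathcal{L} = \mathcal{L}_{\text{recon}} - \lambda\,\mathcal{R} \\
\frac{\partial(-\lambda\mathcal{R})}{\partial d_j} &= -\frac{\lambda}{N\,|\mathcal{D}|}\sum_{n=1}^{N} r_n,
\qquad j \in \mathcal{D} \\
d_j &\leftarrow d_j + \frac{\eta\,\lambda}{N\,|\mathcal{D}|}\sum_{n=1}^{N} r_n
\end{aligned}
```

Dead atoms receive little or no gradient from the reconstruction term because their codes are (near) zero, so this update nudges each of them toward the batch's mean residual, that is, toward signal the active atoms fail to explain.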

MpSAE

Matching Pursuit Sparse Autoencoder (MpSAE).

__init__(self,
         input_shape,
         nb_concepts,
         k=1,
         dropout=None,
         encoder_module='identity',
         dictionary_params=None,
         device='cpu')

Parameters

  • input_shape : int or tuple of int

    • Dimensionality of the input data (excluding the batch dimension).

  • nb_concepts : int

    • Number of latent dimensions (components) of the autoencoder.

  • k : int, optional

    • The number of matching pursuit iterations to perform (must be > 0).

  • dropout : float, optional

    • Probability of dropping a dictionary element at each iteration (range: 0.0 - 1.0). If None, no dropout is applied.

  • encoder_module : nn.Module or str, optional

    • Custom encoder module (or its registered name). Defaults to 'identity'.

  • dictionary_params : dict, optional

    • Parameters that will be passed to the dictionary layer.

      See DictionaryLayer for more details.

  • device : str, optional

    • Device on which to run the model (default is 'cpu').
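The `dropout` parameter masks dictionary elements at each pursuit iteration. The hypothetical helper `masked_select_atom` sketches one selection step under assumptions not confirmed by this page: each atom is dropped independently with probability `dropout`, a single mask is shared by the whole batch, dropped atoms are simply excluded from the argmax, and no mask is applied outside training (mirroring the documented `train()` hook). Calling it once per iteration draws a fresh mask each time.

```python
import numpy as np


def masked_select_atom(residual, D, dropout, rng, training=True):
    """Hypothetical single selection step of matching pursuit with atom dropout.

    Each atom is dropped independently with probability `dropout`; one mask is
    shared by the whole batch (an assumption). No mask is applied when
    `training` is False.
    """
    corr = residual @ D.T                           # (n, c) atom correlations
    score = np.abs(corr)
    if training and dropout:
        keep = rng.random(D.shape[0]) >= dropout    # True = atom survives this step
        if not keep.any():
            keep[:] = True                          # sketch choice: never drop every atom
        score = np.where(keep, score, -np.inf)      # dropped atoms cannot be selected
    idx = np.argmax(score, axis=1)
    coef = corr[np.arange(residual.shape[0]), idx]
    return idx, coef


rng = np.random.default_rng(0)
D = np.eye(4)                                       # 4 orthonormal atoms in 4-D
r = np.array([[1.0, 2.0, 3.0, 4.0]])
print(masked_select_atom(r, D, dropout=0.5, rng=rng))
print(masked_select_atom(r, D, dropout=0.5, rng=rng, training=False))
```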

decode(self,
       z)

Decode latent representation to reconstruct input data.

Parameters

  • z : torch.Tensor

    • Latent representation tensor of shape (batch_size, nb_concepts).

Return

  • torch.Tensor

    • Reconstructed input tensor of shape (batch_size, input_size).
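Assuming decoding is a plain linear map through the dictionary of shape (nb_concepts, input_size) that `get_dictionary` returns (the DictionaryLayer may apply normalization or other details not covered here), each reconstruction is a weighted sum of the atoms its sparse code selects. `decode` below is a hypothetical stand-in, not the library method.

```python
import numpy as np


def decode(z, D):
    """Hypothetical linear decoder: (batch, nb_concepts) @ (nb_concepts, input_size)."""
    return z @ D


rng = np.random.default_rng(0)
D = rng.normal(size=(6, 3))          # 6 concepts, input_size = 3
z = np.zeros((2, 6))
z[0, 1] = 2.0                        # sample 0 uses concept 1 with weight 2
z[1, 4] = 0.5                        # sample 1 uses concept 4 with weight 0.5
x_hat = decode(z, D)                 # shape (2, 3)
```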


tied(self,
     bias=False)

Tie encoder weights to dictionary (use D^T as encoder).

Parameters

  • bias : bool, optional

    • Whether to include bias in encoder, by default False.

Return

  • self

    • Returns self for method chaining.


untied(self,
       bias=False,
       copy_from_dictionary=True)

Create a new encoder with weights copied from the current dictionary (or randomly initialized).

Parameters

  • bias : bool, optional

    • Whether to include bias in encoder, by default False.

  • copy_from_dictionary : bool, optional

    • If True, initialize encoder with current dictionary weights, by default True.

Return

  • self

    • Returns self for method chaining.
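The difference between `tied()` and `untied()` comes down to weight sharing. The NumPy sketch below only illustrates that idea with hypothetical variable names (MpSAE's actual encoders are PyTorch modules): a tied encoder is a view of D^T and follows every dictionary update, whereas an encoder copied from the dictionary starts equal and then evolves independently.

```python
import numpy as np

D = np.arange(6, dtype=float).reshape(3, 2)   # dictionary: 3 concepts, input_size 2

W_tied = D.T                 # tied: a view that shares memory with D
W_untied = D.T.copy()        # untied, copied from the dictionary: independent weights

D[0] += 10.0                 # simulate a dictionary update during training

x = np.array([[1.0, 1.0]])
pre_tied = x @ W_tied        # reflects the updated dictionary
pre_untied = x @ W_untied    # still uses the initial weights
```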


get_dictionary(self)

Return the learned dictionary.

Return

  • torch.Tensor

    • Learned dictionary tensor of shape (nb_concepts, input_size).


forward(self,
        x)

Perform a forward pass through the autoencoder.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • SAEOutput

    • Returns the pre-codes (z_pre), codes (z), and reconstructed input tensor (x_hat).


train(self,
      mode=True)

Hook called when switching between training and evaluation mode. We use it to ensure no dropout is applied during evaluation.

Parameters

  • mode : bool, optional

    • Whether to set the model to training mode or not, by default True.


encode(self,
       x)

Encode input data with a greedy Matching Pursuit approach.

Parameters

  • x : torch.Tensor

    • Input tensor of shape (batch_size, input_size).

Return

  • residual : torch.Tensor

    • The residual after k Matching Pursuit iterations.

  • codes : torch.Tensor

    • The final sparse codes obtained after k Matching Pursuit iterations.