Orthogonal Matching Pursuit SAE (OMPSAE)
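OMPSAE uses orthogonal matching pursuit (OMP) for sparse coding. At each iteration it selects the atom most correlated with the current residual, then re-solves a non-negative least squares (NNLS) problem over all atoms selected so far to refine the codes. Plain matching pursuit only sets the coefficient of the newly chosen atom and never revisits earlier ones, so this joint refit often improves reconstruction for the same number of atoms. For background, see Sparse Autoencoders via Matching Pursuit.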
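The loop described above can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the library's implementation: the dictionary D holds one atom per row (nb_concepts, input_dim), atoms are selected by the largest positive correlation because the codes are non-negative, and the NNLS refit is approximated by max_iter steps of projected gradient descent. The name omp_nnls_sketch is hypothetical.

```python
import torch


def omp_nnls_sketch(x: torch.Tensor, D: torch.Tensor, k: int, max_iter: int):
    """Hypothetical OMP with an approximate NNLS refit.

    x: (batch, input_dim) inputs; D: (nb_concepts, input_dim), one atom per row.
    Returns (residual, codes) with codes of shape (batch, nb_concepts).
    """
    batch, nb_concepts = x.shape[0], D.shape[0]
    codes = torch.zeros(batch, nb_concepts, dtype=x.dtype, device=x.device)
    support = torch.zeros(batch, nb_concepts, dtype=torch.bool, device=x.device)
    rows = torch.arange(batch, device=x.device)
    residual = x.clone()
    # Step size 1/L, where L = ||D||_2^2 bounds the curvature of 0.5 * ||x - codes @ D||^2
    step = 1.0 / torch.linalg.matrix_norm(D, ord=2) ** 2

    for _ in range(k):
        # 1) Pick the not-yet-selected atom most positively correlated with the residual
        corr = (residual @ D.T).masked_fill(support, float("-inf"))
        support[rows, corr.argmax(dim=1)] = True

        # 2) Refit all selected atoms jointly: projected gradient descent on the
        #    NNLS objective, warm-started from the previous codes
        for _ in range(max_iter):
            grad = (codes @ D - x) @ D.T
            codes = torch.clamp(codes - step * grad, min=0.0) * support

        residual = x - codes @ D
    return residual, codes
```

Each outer step only enlarges the support, and a projected gradient step of size 1/L never increases the objective, so in this sketch the residual norm can never exceed the input norm.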
Basic Usage
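import torch
from overcomplete.sae import OMPSAE

# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Inputs must live on the same device as the model
x = torch.randn(64, 512, device=device)

sae = OMPSAE(
    input_shape=512,
    nb_concepts=4_096,
    k=4,             # pursuit steps
    max_iter=15,     # NNLS iterations
    dropout=0.1,     # optional atom dropout
    encoder_module="identity",
    device=device,
)
residual, codes = sae.encode(x)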
Notes:
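- encode returns (residual, codes); residual is the reconstruction error left after the k pursuit steps.
- Set dropout to randomly mask dictionary atoms during training; it is disabled in evaluation mode (see train).
- Inputs must be 2D tensors of shape (batch_size, input_dim), i.e. flat feature vectors; 3D/4D tensors are not supported. Both k and max_iter must be positive.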
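Since encode only accepts 2D batches, activations from convolutional or transformer layers need flattening first. A minimal sketch, assuming channels-first (B, C, H, W) feature maps and (B, T, D) token activations; to_feature_batch is a hypothetical helper, not part of overcomplete.

```python
import torch


def to_feature_batch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten activations into a (n_vectors, n_features) batch."""
    if x.dim() == 2:
        return x  # already (batch, features)
    if x.dim() == 3:
        # (B, T, D) token activations -> one row per token
        return x.reshape(-1, x.shape[-1])
    if x.dim() == 4:
        # (B, C, H, W) channels-first maps -> one row per spatial position
        return x.permute(0, 2, 3, 1).reshape(-1, x.shape[1])
    raise ValueError(f"unsupported input rank: {x.dim()}")


feats = to_feature_batch(torch.randn(2, 512, 7, 7))  # shape (98, 512)
```

Each spatial position or token becomes one row, so per-position codes of shape (B * H * W, nb_concepts) can be reshaped back to (B, H, W, nb_concepts) the same way.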
OMPSAE
Orthogonal Matching Pursuit Sparse Autoencoder (OMPSAE).
__init__(self,
         input_shape,
         nb_concepts,
         k=1,
         dropout=None,
         encoder_module='identity',
         dictionary_params=None,
         device='cpu',
         max_iter=10)
Parameters
- input_shape : int or tuple of int
  Dimensionality of the input data (excluding the batch dimension).
- nb_concepts : int
  Number of latent components (atoms) in the dictionary.
- k : int, optional
  Default number of pursuit iterations; must be > 0 (default: 1).
- dropout : float, optional
  Dropout rate applied to dictionary atoms, in the range [0.0, 1.0] (default: None, no dropout).
- encoder_module : str or nn.Module, optional
  Encoder module, or the name of a registered encoder (default: 'identity').
- dictionary_params : dict, optional
  Parameters passed to the dictionary layer.
- device : str, optional
  Device to run the model on (default: 'cpu').
- max_iter : int, optional
  Default number of NNLS iterations; must be > 0 (default: 10).
decode(self, z)
Decode latent representation to reconstruct input data.
Parameters
- z : torch.Tensor
  Latent representation tensor of shape (batch_size, nb_concepts).
Return
- torch.Tensor
  Reconstructed input tensor of shape (batch_size, input_size).
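Using the dictionary shape documented for get_dictionary, decoding can be pictured as a single matrix product of the codes with the dictionary. A minimal sketch, assuming decode is purely linear (no bias or atom normalization); linear_decode is a hypothetical name, and the library's decode may do more.

```python
import torch


def linear_decode(z: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    """Hypothetical decode: (batch, nb_concepts) @ (nb_concepts, input_size)."""
    return z @ D


nb_concepts, input_size, batch_size = 16, 8, 4
D = torch.randn(nb_concepts, input_size)  # stand-in dictionary, one atom per row

# Sparse non-negative codes with two active atoms per sample
z = torch.zeros(batch_size, nb_concepts)
z[:, 3] = 1.5
z[:, 7] = 0.5

x_hat = linear_decode(z, D)  # (batch_size, input_size)
```

Each reconstructed row is then the weighted sum 1.5 * D[3] + 0.5 * D[7], which is why sparse codes make reconstructions easy to attribute to individual atoms.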
tied(self, bias=False)
Tie the encoder weights to the dictionary (use D^T as the encoder).
Parameters
- bias : bool, optional
  Whether to include a bias in the encoder (default: False).
Return
- self
  Returns self for method chaining.
untied(self, bias=False, copy_from_dictionary=True)
Create a new encoder whose weights are initialized from the current dictionary (or randomly).
Parameters
- bias : bool, optional
  Whether to include a bias in the encoder (default: False).
- copy_from_dictionary : bool, optional
  If True, initialize the encoder with the current dictionary weights (default: True).
Return
- self
  Returns self for method chaining.
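The difference between tied and untied can be sketched with a toy module: a tied encoder reads D^T directly, so it follows every dictionary update, while untied creates an independent trainable copy. ToyEncoder is hypothetical, omits the bias option, and, as an assumption, uses a small random init when copy_from_dictionary=False.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical encoder illustrating tied vs. untied weights."""

    def __init__(self, dictionary: nn.Parameter):
        super().__init__()
        self.dictionary = dictionary  # (nb_concepts, input_size), shared with the decoder
        self.weight = None            # None means "tied"

    def tied(self):
        self.weight = None            # reuse D^T: no extra parameters
        return self

    def untied(self, copy_from_dictionary=True):
        if copy_from_dictionary:
            init = self.dictionary.detach().clone()
        else:
            init = torch.randn_like(self.dictionary) * 0.01
        self.weight = nn.Parameter(init)  # independent, trainable copy
        return self

    def forward(self, x):
        w = self.dictionary if self.weight is None else self.weight
        return x @ w.T                # (batch, nb_concepts)
```

Tying halves the parameter count and keeps encoder and decoder consistent; untying lets the encoder drift from the dictionary during training.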
get_dictionary(self)
Return the learned dictionary.
Return
- torch.Tensor
  Learned dictionary tensor of shape (nb_concepts, input_size).
forward(self, x)
Perform a forward pass through the autoencoder.
Parameters
- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).
Return
- SAEOutput
  The pre-codes (z_pre), codes (z), and reconstructed input tensor (x_hat).
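The three outputs can be pictured as encode followed by decode. A minimal sketch of that contract, assuming a linear decode; forward_sketch and relu_encode are hypothetical, and the toy ReLU encoder only stands in for the pursuit that OMPSAE actually runs in encode.

```python
import torch


def forward_sketch(x, encode_fn, D):
    """Hypothetical forward pass: encode, then decode linearly with D."""
    z_pre, z = encode_fn(x)  # pre-codes and (sparse) codes
    x_hat = z @ D            # reconstruction, (batch, input_size)
    return z_pre, z, x_hat


def relu_encode(x, D):
    """Toy stand-in encoder: correlations with the atoms, then ReLU."""
    z_pre = x @ D.T
    return z_pre, torch.relu(z_pre)


torch.manual_seed(0)
D = torch.randn(32, 8)  # (nb_concepts, input_size)
x = torch.randn(4, 8)   # (batch_size, input_size)
z_pre, z, x_hat = forward_sketch(x, lambda t: relu_encode(t, D), D)
```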
train(self, mode=True)
Hook called when switching between training and evaluation modes; it ensures that no dropout is applied during evaluation.
Parameters
- mode : bool, optional
  Whether to set the model to training mode (default: True).
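Dropout on dictionary atoms, and its deactivation in evaluation mode, can be sketched as a small module whose behavior depends on self.training, the flag that train() and eval() toggle on any PyTorch module. AtomDropout is hypothetical; it zeroes whole atoms (rows of the dictionary) and, as an assumption, does not rescale the surviving atoms.

```python
import torch
from torch import nn


class AtomDropout(nn.Module):
    """Hypothetical: randomly zero whole dictionary atoms while training."""

    def __init__(self, p: float):
        super().__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError("dropout must be in [0.0, 1.0]")
        self.p = p

    def forward(self, D: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return D  # evaluation mode: dictionary left untouched
        # One Bernoulli draw per atom (row); the (n, 1) mask broadcasts over features
        keep = torch.rand(D.shape[0], 1, device=D.device) >= self.p
        return D * keep
```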
encode(self, x, k=None, max_iter=None)
Encode input using Orthogonal Matching Pursuit.
Parameters
- x : torch.Tensor
  Input tensor of shape (batch_size, input_dim).
- k : int, optional
  Override the number of pursuit iterations.
- max_iter : int, optional
  Override the number of NNLS iterations.
Return
- residual : torch.Tensor
  Final residual after k pursuit iterations.
- codes : torch.Tensor
  Sparse codes of shape (batch_size, nb_concepts).
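When k or max_iter is left as None, the constructor values are presumably used instead. A minimal sketch of that fallback (defaults 1 and 10 taken from the signature) plus a sanity check one might run on the returned codes: assuming each pursuit step adds at most one atom, no row should have more than k non-zero entries. Both helpers are hypothetical.

```python
import torch


def resolve_pursuit_args(k=None, max_iter=None, default_k=1, default_max_iter=10):
    """Hypothetical: per-call overrides fall back to the constructor defaults."""
    k = default_k if k is None else k
    max_iter = default_max_iter if max_iter is None else max_iter
    if k <= 0 or max_iter <= 0:
        raise ValueError("k and max_iter must be positive")
    return k, max_iter


def max_active_atoms(codes: torch.Tensor) -> int:
    """Largest number of non-zero entries in any row of a (batch, nb_concepts) tensor."""
    return int((codes != 0).sum(dim=1).max())
```

For example, after residual, codes = sae.encode(x, k=8), one would expect max_active_atoms(codes) <= 8 under the one-atom-per-step assumption.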