Matching Pursuit SAE (MpSAE)¶
MpSAE replaces thresholding with a greedy matching pursuit loop: at each step it selects the dictionary atom most correlated with the current residual, updates the corresponding code, and subtracts that atom's contribution from the residual, yielding sparse codes that track reconstruction progress. We encourage reading [1] for the full method.
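The loop described above can be sketched in plain Python (illustrative only: matching_pursuit is a hypothetical single-vector helper, whereas the real MpSAE operates on batched torch tensors):

```python
def matching_pursuit(x, atoms, k):
    """Return (residual, codes) after k greedy pursuit steps.

    x     : list of floats, the input vector
    atoms : list of unit-norm atom vectors (the dictionary rows)
    k     : number of pursuit iterations
    """
    residual = list(x)
    codes = [0.0] * len(atoms)
    for _ in range(k):
        # pick the atom most correlated with the current residual
        corrs = [sum(r * a for r, a in zip(residual, atom)) for atom in atoms]
        best = max(range(len(atoms)), key=lambda i: abs(corrs[i]))
        coef = corrs[best]
        # update the code and subtract the atom's contribution
        codes[best] += coef
        residual = [r - coef * a for r, a in zip(residual, atoms[best])]
    return residual, codes

# example: two orthonormal atoms reconstruct x exactly in two steps
atoms = [[1.0, 0.0], [0.0, 1.0]]
residual, codes = matching_pursuit([3.0, 4.0], atoms, k=2)
# residual is ~[0, 0]; codes are [3.0, 4.0]
```

Because each step subtracts the residual's projection onto the selected atom, the residual norm never increases across iterations.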
Basic Usage¶
from overcomplete import MpSAE
# define a Matching Pursuit SAE: input dimension 512, 4,096 concepts
sae = MpSAE(512, 4_096, k=4, dropout=0.1)
# k = number of pursuit steps; dropout optionally masks atoms at each step
residual, codes = sae.encode(x)  # x: tensor of shape (batch_size, 512)
Advanced: auxiliary loss to revive dead codes¶
To ensure high dictionary utilization in MP-SAE, we strongly recommend adding an auxiliary loss term that revives rarely used codes. Here is an example of such a loss:
def criterion(x, x_hat, residual, z, d):
    recon_loss = ((x - x_hat) ** 2).mean()
    # codes whose maximum activation over the batch is near zero are "dead"
    revive_mask = (z.amax(dim=0) < 1e-2).detach()  # shape: (nb_concepts,)
    if revive_mask.sum() > 10:
        # encourage dead atoms to align with the remaining residual
        projected = residual @ d.T  # shape: (batch_size, nb_concepts)
        revive_term = projected[:, revive_mask].mean()
        recon_loss -= revive_term * 1e-2
    return recon_loss
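To show how this loss plugs into a training step, here is a hedged sketch: the loss is repeated for self-containment, and random stand-in tensors mimic the shapes MpSAE produces (in a real loop they would come from the model and its dictionary).

```python
import torch

def criterion(x, x_hat, residual, z, d):
    # same auxiliary loss as defined above
    recon_loss = ((x - x_hat) ** 2).mean()
    revive_mask = (z.amax(dim=0) < 1e-2).detach()
    if revive_mask.sum() > 10:
        projected = residual @ d.T
        revive_term = projected[:, revive_mask].mean()
        recon_loss -= revive_term * 1e-2
    return recon_loss

# stand-in tensors that mimic MpSAE's output shapes; in a real training
# loop these come from the model's forward pass and its dictionary
n, d_in, c = 8, 16, 64
x = torch.randn(n, d_in)
x_hat = torch.randn(n, d_in)       # stand-in for the reconstruction
residual = x - x_hat               # stand-in for the pursuit residual
z = torch.relu(torch.randn(n, c))  # stand-in for the sparse codes
d = torch.randn(c, d_in)           # stand-in for the dictionary

loss = criterion(x, x_hat, residual, z, d)  # scalar, ready for backward()
```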
MpSAE¶
Matching Pursuit Sparse Autoencoder (MpSAE).
__init__(self, input_shape, nb_concepts, k=1, dropout=None, encoder_module='identity', dictionary_params=None, device='cpu')¶
Parameters
- input_shape : int or tuple of int
  Dimensionality of the input data (excluding the batch dimension).
- nb_concepts : int
  Number of latent dimensions (components) of the autoencoder.
- k : int, optional
  The number of matching pursuit iterations to perform (must be > 0).
- dropout : float, optional
  Probability of dropping a dictionary element at each iteration (range: 0.0 - 1.0). If None, no dropout is applied.
- encoder_module : nn.Module or str, optional
  Custom encoder module (or its registered name). If None, a default encoder is used.
- dictionary_params : dict, optional
  Parameters passed to the dictionary layer. See DictionaryLayer for more details.
- device : str, optional
  Device on which to run the model (default is 'cpu').
decode(self, z)¶
Decode latent representation to reconstruct input data.
Parameters
- z : torch.Tensor
  Latent representation tensor of shape (batch_size, nb_components).
Return
- torch.Tensor
  Reconstructed input tensor of shape (batch_size, input_size).
tied(self, bias=False)¶
Tie encoder weights to dictionary (use D^T as encoder).
Parameters
- bias : bool, optional
  Whether to include bias in the encoder, by default False.
Return
- self
  Returns self for method chaining.
untied(self, bias=False, copy_from_dictionary=True)¶
Create a new encoder with weights from the current dictionary (or random initialization).
Parameters
- bias : bool, optional
  Whether to include bias in the encoder, by default False.
- copy_from_dictionary : bool, optional
  If True, initialize the encoder with the current dictionary weights, by default True.
Return
- self
  Returns self for method chaining.
get_dictionary(self)¶
Return the learned dictionary.
Return
- torch.Tensor
  Learned dictionary tensor of shape (nb_components, input_size).
forward(self, x)¶
Perform a forward pass through the autoencoder.
Parameters
- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).
Return
- SAEOuput
  The pre_codes (z_pre), codes (z), and reconstructed input tensor (x_hat).
train(self, mode=True)¶
Hook called when switching between training and evaluation mode.
We use it to ensure no dropout is applied during evaluation.
Parameters
- mode : bool, optional
  Whether to set the model to training mode, by default True.
encode(self, x)¶
Encode input data with a greedy Matching Pursuit approach.
Parameters
- x : torch.Tensor
  Input tensor of shape (batch_size, input_size).
Return
- residual : torch.Tensor
  The residual after k matching pursuit iterations.
- codes : torch.Tensor
  The final sparse codes obtained after k matching pursuit iterations.
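For intuition about what encode returns, here is a minimal batched torch sketch of a k-step matching pursuit (illustrative only, not the library's actual implementation; mp_encode is a hypothetical helper and the dictionary rows are assumed unit-norm):

```python
import torch

def mp_encode(x, dictionary, k):
    # x: (batch, input_size); dictionary: (nb_concepts, input_size), unit-norm rows
    residual = x.clone()
    codes = torch.zeros(x.shape[0], dictionary.shape[0])
    for _ in range(k):
        corr = residual @ dictionary.T        # (batch, nb_concepts)
        best = corr.abs().argmax(dim=1)       # best atom per sample
        coef = corr.gather(1, best[:, None])  # (batch, 1)
        codes.scatter_add_(1, best[:, None], coef)
        # remove the selected atom's contribution from the residual
        residual = residual - coef * dictionary[best]
    return residual, codes

x = torch.randn(4, 8)
dictionary = torch.nn.functional.normalize(torch.randn(16, 8), dim=1)
residual, codes = mp_encode(x, dictionary, k=3)
# the residual norm shrinks (or stays equal) as k grows
```

With unit-norm atoms, each step removes the residual's projection onto the chosen atom, so more pursuit steps (larger k) can only decrease the per-sample residual norm.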