Metrics for Dictionary Learning

The overcomplete.metrics module provides a collection of evaluation metrics designed for dictionary learning algorithms. These metrics help assess sparsity, reconstruction accuracy, and dictionary quality.

Overview

This module includes metrics for:

  • Norm-based evaluations: L0, L1, L2, and Lp norms.
  • Reconstruction losses: Absolute and relative errors.
  • Sparsity metrics: Hoyer score, L1/L2 ratio, and Kappa-4.
  • Dictionary similarity: Hungarian loss, cosine Hungarian loss, and collinearity.
  • Distribution-based metrics: Wasserstein-1D and Fréchet distance.
  • Code analysis: Detecting dead codes and assessing sparsity.

Example Usage (Key Metrics)

import torch
from overcomplete.metrics import (dead_codes, r2_score, hungarian_loss,
                                  cosine_hungarian_loss, wasserstein_1d)

# inputs and their reconstructions
x = torch.randn(100, 10)
x_hat = torch.randn(100, 10)
# dictionaries, shape (num_codes, dim)
dict1 = torch.randn(512, 256)
dict2 = torch.randn(512, 256)
# concept values (codes), shape (batch_size, num_codes)
codes = torch.randn(100, 512)
codes_2 = torch.randn(100, 512)

# check for inactive dictionary elements
dead_code_ratio = dead_codes(codes).mean()

# compare dictionary structures
hungarian_dist = hungarian_loss(dict1, dict2)
cosine_hungarian_dist = cosine_hungarian_loss(dict1, dict2)

# compute reconstruction quality
r2 = r2_score(x, x_hat)
# distrib. reconstruction quality
wasserstein_dist = wasserstein_1d(x, x_hat)
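
The sparsity metrics documented below follow the same calling convention; continuing the example above:

from overcomplete.metrics import hoyer, l1_l2_ratio, kappa_4

# per-sample sparsity diagnostics on the codes
hoyer_scores = hoyer(codes)      # (100,), 1.0 means perfectly sparse
ratios = l1_l2_ratio(codes)      # L1/L2 per row, dimension-sensitive
kurtosis = kappa_4(codes)        # kurtosis-based sparsity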

Available Metrics

Norm-Based Metrics

  • l0(x), l1(x), l2(x), lp(x, p)
  • l1_l2_ratio(x): Ratio of L1 to L2 norm.
  • hoyer(x): Normalized sparsity measure.

Reconstruction Losses

  • avg_l2_loss(x, x_hat), avg_l1_loss(x, x_hat)
  • relative_avg_l2_loss(x, x_hat), relative_avg_l1_loss(x, x_hat)
  • r2_score(x, x_hat): Measures reconstruction accuracy.

Sparsity Metrics

  • sparsity(x): Alias for l0(x).
  • sparsity_eps(x, threshold): L0 with an epsilon threshold.
  • kappa_4(x): Kurtosis-based sparsity measure.
  • dead_codes(x): Identifies unused codes in a dictionary.

Dictionary Evaluation

  • hungarian_loss(dict1, dict2): Finds best-matching dictionary elements.
  • cosine_hungarian_loss(dict1, dict2): Cosine distance-based Hungarian loss.
  • dictionary_collinearity(dictionary): Measures collinearity between dictionary elements.

Distribution-Based Metrics

  • wasserstein_1d(x1, x2): 1D Wasserstein-1 distance.
  • frechet_distance(x1, x2): Fréchet distance for distributions.

For further details, refer to the module documentation.

l2(v,
   dims=None)

Compute the L2 norm across 'dims'.

Parameters

  • v : torch.Tensor

    • Input tensor.

  • dims : tuple, optional

    • Dimensions over which to compute the L2 norm, by default None.

Return

  • torch.Tensor

    • L2 norm of v if dims=None else L2 norm across dims.


l1(v,
   dims=None)

Compute the L1 norm across 'dims'.

Parameters

  • v : torch.Tensor

    • Input tensor.

  • dims : tuple, optional

    • Dimensions over which to compute the L1 norm, by default None.

Return

  • torch.Tensor

    • L1 norm of v if dims=None else L1 norm across dims.


lp(v,
   p=0.5,
   dims=None)

Compute the Lp norm across 'dims'.

Parameters

  • v : torch.Tensor

    • Input tensor.

  • p : float, optional

    • Power of the norm, by default 0.5.

  • dims : tuple, optional

    • Dimensions over which to compute the Lp norm, by default None.

Return

  • torch.Tensor

    • Lp norm of v if dims=None else Lp norm across dims.
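
For intuition, a short sketch of the expected reduction behavior (assuming dims acts like the dim argument of the underlying torch reductions):

v = torch.randn(100, 512)
full_norm = l2(v)               # scalar norm over all elements
row_norms = l2(v, dims=-1)      # one L2 norm per row, shape (100,)
row_lp = lp(v, p=0.5, dims=-1)  # Lp quasi-norm per row (p < 1)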


avg_l2_loss(x,
            x_hat)

Compute the L2 loss, averaged across samples.

Parameters

  • x : torch.Tensor

    • Original input tensor of shape (batch_size, d).

  • x_hat : torch.Tensor

    • Reconstructed input tensor of shape (batch_size, d).

Return

  • float

    • Average L2 loss per sample.


avg_l1_loss(x,
            x_hat)

Compute the L1 loss, averaged across samples.

Parameters

  • x : torch.Tensor

    • Original input tensor of shape (batch_size, d).

  • x_hat : torch.Tensor

    • Reconstructed input tensor of shape (batch_size, d).

Return

  • float

    • Average L1 loss per sample.


relative_avg_l2_loss(x,
                     x_hat,
                     epsilon=1e-06)

Compute the relative L2 reconstruction loss, averaged across samples.

Parameters

  • x : torch.Tensor

    • Original input tensor of shape (batch_size, d).

  • x_hat : torch.Tensor

    • Reconstructed input tensor of shape (batch_size, d).

  • epsilon : float, optional

    • Small value to avoid division by zero, by default 1e-6.

Return

  • float

    • Average relative L2 loss per sample.


relative_avg_l1_loss(x,
                     x_hat,
                     epsilon=1e-06)

Compute the relative L1 reconstruction loss, averaged across samples.

Parameters

  • x : torch.Tensor

    • Original input tensor of shape (batch_size, d).

  • x_hat : torch.Tensor

    • Reconstructed input tensor of shape (batch_size, d).

  • epsilon : float, optional

    • Small value to avoid division by zero, by default 1e-6.

Return

  • float

    • Average relative L1 loss per sample.
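
The relative variants presumably normalize the per-sample error by the sample's own norm. A minimal reference sketch of that idea (the library's exact reduction may differ):

def relative_avg_l2_loss_ref(x, x_hat, epsilon=1e-6):
    # per-sample L2 error over per-sample L2 norm, averaged on the batch
    err = (x - x_hat).norm(dim=-1)
    return (err / (x.norm(dim=-1) + epsilon)).mean().item()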


l0(x,
   dims=None)

Compute the average number of zero elements.

Parameters

  • x : torch.Tensor

    • Input tensor.

  • dims : tuple, optional

    • Dimensions over which to compute the sparsity, by default None.

Return

  • torch.Tensor

    • Average sparsity if dims=None else sparsity across dims.
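
Taking the docstring literally (fraction of zero entries), a hedged equivalent:

z = torch.tensor([[0., 1., 0., 2.]])
zero_ratio = (z == 0).float().mean()  # 0.5; see sparsity_eps for a tolerance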


l1_l2_ratio(x,
            dims=-1)

Compute the L1/L2 ratio of a tensor. By default, the ratio is computed across the last dimension. This score is a useful metric to evaluate the sparsity of a tensor; it is, however, sensitive to the dimensionality of the data. For a dimension-normalized metric, consider the Hoyer score.

Parameters

  • x : torch.Tensor

    • Input tensor.

  • dims : tuple, optional

    • Dimensions over which to compute the ratio, by default -1.

Return

  • torch.Tensor

    • The L1/L2 ratio.


hoyer(x)

Compute the Hoyer sparsity of a tensor. The Hoyer score includes the dimension normalization factor: a score of 1 indicates a perfectly sparse representation, while a score of 0 indicates a dense representation.

Parameters

  • x : torch.Tensor

    • A 2D tensor of shape (batch_size, d).

Return

  • torch.Tensor (batch_size,)

    • Hoyer sparsity for each vector in the batch.
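
The standard Hoyer measure rescales the L1/L2 ratio by the dimension d so the score lands in [0, 1]. A sketch of that textbook definition (edge cases such as all-zero rows are ignored here, and the module's implementation may differ in such details):

import math

def hoyer_ref(x):
    # Hoyer (2004): (sqrt(d) - l1/l2) / (sqrt(d) - 1), one score per row
    d = x.shape[-1]
    ratio = x.abs().sum(dim=-1) / x.norm(dim=-1)
    return (math.sqrt(d) - ratio) / (math.sqrt(d) - 1)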


kappa_4(x)

Compute the Kappa-4 sparsity of a tensor. The Kappa-4 score evaluates the sparsity of a distribution via its kurtosis, which measures the "peakedness" of the distribution.

Parameters

  • x : torch.Tensor

    • Input tensor.

Return

  • torch.Tensor

    • the Kappa-4 sparsity.
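
Kurtosis-based sparsity is commonly the fourth moment normalized by the squared second moment; a hedged sketch (not necessarily the module's exact normalization):

def kappa_4_ref(x):
    # E[x^4] / E[x^2]^2 over the last dimension (reference sketch)
    return (x ** 4).mean(dim=-1) / (x ** 2).mean(dim=-1) ** 2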


r2_score(x,
         x_hat)

Compute the R^2 score (coefficient of determination) for the reconstruction. A score of 1 indicates a perfect reconstruction while a score of 0 indicates that the reconstruction is as good as the mean.

Parameters

  • x : torch.Tensor

    • Original input tensor of shape (batch_size, d).

  • x_hat : torch.Tensor

    • Reconstructed input tensor of shape (batch_size, d).

Return

  • float

    • R^2 score.
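
This is the usual coefficient of determination, R^2 = 1 - SS_res / SS_tot; a reference sketch:

def r2_score_ref(x, x_hat):
    # residual sum of squares vs. total sum of squares around the mean
    ss_res = (x - x_hat).pow(2).sum()
    ss_tot = (x - x.mean(dim=0)).pow(2).sum()
    return (1.0 - ss_res / ss_tot).item()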


dead_codes(z)

Check for codes that never fire. The returned tensor flags the dead codes; its mean gives the fraction of codes that never fire.

Parameters

  • z : torch.Tensor

    • Input tensor of shape (batch_size, num_codes).

Return

  • torch.Tensor

    • Tensor indicating which codes are dead.
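
A code is dead when it is zero on every sample; a hedged sketch consistent with the usage in the example at the top of the page:

def dead_codes_ref(z):
    # 1.0 for codes that never fire, 0.0 otherwise;
    # .mean() of the result gives the dead-code ratio
    return (z.abs().sum(dim=0) == 0).float()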


hungarian_loss(dictionary1,
               dictionary2,
               p_norm=2)

Compute the Hungarian loss between two dictionaries.

Parameters

  • dictionary1 : torch.Tensor

    • First dictionary tensor of shape (num_codes, dim).

  • dictionary2 : torch.Tensor

    • Second dictionary tensor of shape (num_codes, dim).

  • p_norm : int, optional

    • Norm to use for computing the distance, by default 2.

Return

  • float

    • Hungarian loss.
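
Conceptually, this finds the optimal one-to-one matching between the rows of the two dictionaries under pairwise Lp distances. A hedged sketch using scipy's assignment solver (the reduction over matched pairs may differ in the module):

from scipy.optimize import linear_sum_assignment

def hungarian_loss_ref(d1, d2, p_norm=2):
    # pairwise Lp distances between all rows, then optimal assignment
    costs = torch.cdist(d1, d2, p=p_norm)
    rows, cols = linear_sum_assignment(costs.detach().cpu().numpy())
    return costs[torch.as_tensor(rows), torch.as_tensor(cols)].mean().item()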


cosine_hungarian_loss(dictionary1,
                      dictionary2)

Compute the cosine Hungarian loss between two dictionaries. A score of 0 indicates that the two dictionaries are identical up to a permutation. A score of 'dim' indicates that the two dictionaries are orthogonal. For a normalized score, we recommend dividing the score by the dimension of the dictionary.

Parameters

  • dictionary1 : torch.Tensor

    • First dictionary tensor of shape (num_codes, dim).

  • dictionary2 : torch.Tensor

    • Second dictionary tensor of shape (num_codes, dim).

Return

  • float

    • Cosine Hungarian loss.
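
Following the recommendation above:

raw = cosine_hungarian_loss(dict1, dict2)
normalized = raw / dict1.shape[1]  # divide by dim for a normalized score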


dictionary_collinearity(dictionary)

Compute the collinearity of a dictionary.

Parameters

  • dictionary : torch.Tensor

    • Dictionary tensor of shape (num_codes, dim).

Return

  • max_collinearity : float

    • Maximum collinearity across dictionary elements (non diagonal).

  • cosine_similarity_matrix : torch.Tensor

    • Matrix of cosine similarities across dictionary elements.
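
A hedged sketch of what such a collinearity check typically computes (exact handling of signs and the diagonal may differ in the module):

import torch.nn.functional as F

def dictionary_collinearity_ref(dictionary):
    # cosine similarity between every pair of dictionary rows
    d = F.normalize(dictionary, dim=-1)
    sims = d @ d.T
    off_diag = sims - torch.eye(sims.shape[0])  # zero out the diagonal
    return off_diag.abs().max().item(), sims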


wasserstein_1d(x1,
               x2)

Compute the 1D Wasserstein-1 distance between two sets of codes, averaged across dimensions.

Parameters

  • x1 : torch.Tensor

    • First set of samples of shape (num_samples, d).

  • x2 : torch.Tensor

    • Second set of samples of shape (num_samples, d).

Return

  • torch.Tensor

    • Wasserstein distance.
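
With equal sample counts, the 1D Wasserstein-1 distance has a closed form over sorted samples; a sketch of that identity:

def wasserstein_1d_ref(x1, x2):
    # per dimension: mean |sorted(x1) - sorted(x2)|, then average over d
    s1, _ = torch.sort(x1, dim=0)
    s2, _ = torch.sort(x2, dim=0)
    return (s1 - s2).abs().mean()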


frechet_distance(x1,
                 x2)

Compute the Fréchet distance (Wasserstein-2 distance) between two sets of activations. Assumes that the activations are normally distributed.

Parameters

  • x1 : torch.Tensor

    • First set of samples of shape (num_samples, d).

  • x2 : torch.Tensor

    • Second set of samples of shape (num_samples, d).

Return

  • torch.Tensor

    • Fréchet distance.
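
Under the Gaussian assumption above, the squared Fréchet distance has the closed form ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 (C1 C2)^(1/2)); a hedged sketch using scipy for the matrix square root:

import numpy as np
from scipy.linalg import sqrtm

def frechet_distance_ref(x1, x2):
    # fit Gaussians to both sample sets, compare them in closed form
    a1, a2 = x1.numpy(), x2.numpy()
    m1, m2 = a1.mean(0), a2.mean(0)
    c1 = np.cov(a1, rowvar=False)
    c2 = np.cov(a2, rowvar=False)
    covmean = sqrtm(c1 @ c2).real  # drop numerical imaginary residue
    return float(((m1 - m2) ** 2).sum() + np.trace(c1 + c2 - 2 * covmean))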


codes_correlation_matrix(codes)

Compute the correlation matrix of codes.

Parameters

  • codes : torch.Tensor

    • Codes tensor of shape (batch_size, num_codes).

Return

  • max_corr : float

    • Maximum correlation across codes (non diagonal).

  • corrs : torch.Tensor

    • Correlation matrix of codes.
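
A hedged sketch using torch.corrcoef (which expects variables on rows, hence the transpose):

def codes_correlation_ref(codes):
    corrs = torch.corrcoef(codes.T)               # (num_codes, num_codes)
    off_diag = corrs - torch.eye(corrs.shape[0])  # ignore the diagonal
    return off_diag.abs().max().item(), corrs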


energy_of_codes(codes,
                dictionary)

Compute the energy of codes given a dictionary. For example, with input samples X, codes Z, and dictionary D such that X = ZD, Energy(Z) = || E[Z]D ||^2 corresponds to the average energy the codes bring to the reconstruction.

Parameters

  • codes : torch.Tensor

    • Codes tensor of shape (batch_size, num_codes).

  • dictionary : torch.Tensor

    • Dictionary tensor of shape (num_codes, dim).

Return

  • torch.Tensor

    • Energy of codes, one value per code.
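
Reading the formula above literally, one per-code energy could be computed as below; this interpretation (scaling each atom by its mean code) is an assumption, not necessarily the module's implementation:

def energy_of_codes_ref(codes, dictionary):
    # mean activation of each code scales its dictionary atom;
    # squared norm of that contribution = energy, one value per code
    mean_codes = codes.mean(dim=0)              # (num_codes,)
    contrib = mean_codes[:, None] * dictionary  # (num_codes, dim)
    return contrib.pow(2).sum(dim=-1)           # (num_codes,)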