FID

This module makes it easy to compute FID metrics for a MultiVae model.

We are grateful to https://github.com/mseitzer/pytorch-fid, on which our code is heavily based.

A simple example:

from multivae.metrics import FIDEvaluator, FIDEvaluatorConfig

fid_config = FIDEvaluatorConfig(
    batch_size=128,
    inception_weights_path='your_path',
    wandb_path='your_wandb_path',  # optional: log metrics to wandb
)

fid_module = FIDEvaluator(
    model=your_model,
    test_dataset=test_data,
    output='your_output_path',  # where to save metrics
    eval_config=fid_config,  # pass the configuration defined above
    sampler=None,  # optionally, a trained MultiVae sampler
    custom_encoders=None,  # optionally, custom networks per modality instead of the Inception network
)

# Compute FID for unconditional generation
fid_module.eval()

# Compute FID for conditional generation
fid_module.compute_all_conditional_fids(gen_mod='modality_to_generate')

class multivae.metrics.FIDEvaluatorConfig(batch_size=512, wandb_path=None, inception_weights_path='../fid_model/model.pt', dims_inception=2048)[source]

Config class for the FID evaluation module.

Parameters:
  • batch_size (int) – The batch size to use in the evaluation. Defaults to 512.

  • wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format ‘entity/project_name/run_id’. See “Where to find the WandB path for a trained model?” for more information. If None is provided, the metrics are not logged to wandb. Defaults to None.

  • inception_weights_path (str) – The path to the InceptionV3 weights. Defaults to ‘../fid_model/model.pt’.

  • dims_inception (int) – Selects the embedding layer of the Inception network by its output size. Defaults to 2048.

class multivae.metrics.FIDEvaluator(model, test_dataset, output=None, eval_config=FIDEvaluatorConfig(name='FIDEvaluatorConfig', batch_size=512, wandb_path=None, inception_weights_path='../fid_model/model.pt', dims_inception=2048), sampler=None, custom_encoders=None, transform=None)[source]

Class for computing Fréchet inception distance (FID) metrics.

Parameters:
  • model (BaseMultiVAE) – The model to evaluate.

  • test_dataset (MultimodalBaseDataset) – The dataset to use for computing the metrics.

  • output (str) – The folder path to save metrics. The metrics will be saved in a metrics.txt file.

  • eval_config (FIDEvaluatorConfig) – The configuration class to specify parameters for the evaluation.

  • sampler (BaseSampler) – The sampler used to generate from the latent space. If None is provided, the latent codes are sampled from the prior. Defaults to None.

  • custom_encoders (Dict[str,torch.nn.Module]) – Optionally, your own embedding architectures to use instead of the InceptionV3 model when computing Fréchet distances. By default, the pretrained InceptionV3 network is used for all modalities. Defaults to None.

  • transform (torchvision.transforms) – The transform to apply to the images before computing the embeddings. If None is provided, a default resizing to (3,299,299) is applied. Defaults to None.

calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-06)[source]

NumPy implementation of the Fréchet distance. The Fréchet distance between two multivariate Gaussians \(X_1 \sim \mathcal{N}(\mu_1, C_1)\) and \(X_2 \sim \mathcal{N}(\mu_2, C_2)\) is \(d^2 = \lVert \mu_1 - \mu_2\rVert^2 + \mathrm{Tr}(C_1 + C_2 - 2\sqrt{(C_1\cdot C_2)})\), where the covariances \(C_1\) and \(C_2\) are passed as sigma1 and sigma2. Stable version by Dougal J. Sutherland.

Parameters:
  • mu1 (numpy.ndarray) – The sample mean over activations of a layer of the inception net (as returned by the function ‘get_predictions’) for generated samples.

  • mu2 (numpy.ndarray) – The sample mean over activations, precalculated on a representative data set.

  • sigma1 (numpy.ndarray) – The covariance matrix over activations for generated samples.

  • sigma2 (numpy.ndarray) – The covariance matrix over activations, precalculated on a representative data set.

Returns:

The Fréchet distance.

Return type:

numpy.ndarray
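
To make the formula above concrete, here is a minimal NumPy sketch of the same quantity. It is not the library's implementation: the function name frechet_distance_sketch is hypothetical, and instead of a matrix square root with an eps offset it uses the fact that, for positive semi-definite covariances, the trace of the square root of C_1 C_2 equals the sum of the square roots of the eigenvalues of C_1 C_2.

```python
import numpy as np


def frechet_distance_sketch(mu1, sigma1, mu2, sigma2):
    """Illustrative Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Hypothetical helper, not part of multivae. Assumes sigma1 and sigma2
    are symmetric positive semi-definite matrices of the same shape.
    """
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=float))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=float))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=float))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=float))

    # Squared Euclidean distance between the two means.
    diff = mu1 - mu2
    mean_term = diff @ diff

    # Tr(sqrt(C1 C2)) computed from the eigenvalues of the product.
    # Small negative or imaginary parts caused by round-off are discarded.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()

    return float(mean_term + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

With identical covariances the trace term vanishes and only the squared distance between the means remains, which is a convenient sanity check for any implementation.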

compute_all_conditional_fids(gen_mod)[source]

For every subset of the modalities other than gen_mod, compute the FID of the gen_mod images generated when conditioning on that subset.
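
The enumeration of conditioning subsets can be sketched as follows. This is only an illustration under the assumption that the subsets are all non-empty combinations of the modalities other than gen_mod; the helper name conditioning_subsets and the modality names in the usage note are hypothetical and not part of the library.

```python
from itertools import combinations


def conditioning_subsets(modalities, gen_mod):
    """List every non-empty subset of the modalities other than gen_mod.

    Hypothetical helper: it only shows which subsets would be used as
    conditioning inputs, not how the FID itself is computed.
    """
    others = [m for m in modalities if m != gen_mod]
    subsets = []
    # Subsets are listed by increasing size, from single modalities
    # up to the full set of remaining modalities.
    for size in range(1, len(others) + 1):
        subsets.extend(list(combo) for combo in combinations(others, size))
    return subsets
```

For a model with modalities 'image', 'text' and 'audio' and gen_mod='image', this yields the subsets ['text'], ['audio'] and ['text', 'audio'], each of which would then be passed to compute_fid_from_conditional_generation.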

compute_fid_from_conditional_generation(subset, gen_mod)[source]

Generate samples from the conditional distribution conditioned on subset and compute the Fréchet distance for gen_mod.

get_frechet_distance(mod, generate_latent_function)[source]

Calculates the activations of the pool_3 layer for all images.
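
Once activations are available, FID implementations in the style of pytorch-fid summarize them by their sample mean and covariance, which are the mu and sigma arguments of calculate_frechet_distance. A minimal sketch, assuming the activations are stored as an array of shape (n_samples, dims) with more than one sample; the function name activation_statistics is hypothetical and the library's internal code may differ.

```python
import numpy as np


def activation_statistics(activations):
    """Return the mean vector and covariance matrix of a set of activations.

    Hypothetical helper. `activations` is expected to have shape
    (n_samples, dims), one row per image.
    """
    activations = np.asarray(activations, dtype=float)
    # Mean over the sample axis gives one value per embedding dimension.
    mu = activations.mean(axis=0)
    # rowvar=False treats each column as a variable and each row as a sample.
    sigma = np.cov(activations, rowvar=False)
    return mu, sigma
```

The statistics of real and generated images are computed separately in this way and then compared with the Fréchet distance.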

unconditional_fids()[source]

Generate data from prior or sampler fitted in the latent space and compute the FID for each modality.

Returns:

FIDs for all modalities.

Return type:

ModelOutput