Base class for samplers

Abstract class.

This is the base Sampler module from which all future samplers should inherit. All samplers are adapted from Pythae's sampler implementations to work with multimodal VAEs.

class multivae.samplers.BaseSamplerConfig

BaseSampler config class.

class multivae.samplers.BaseSampler(model, sampler_config=None)

Base class for samplers used to generate from the MultiVae models’ joint latent spaces.

Parameters:
  • model (BaseMultivae) – The model to sample from.

  • sampler_config (BaseSamplerConfig) – An instance of BaseSamplerConfig that holds the sampler's parameters. If None, a default configuration is used. Default: None
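The inheritance pattern and the default-config fallback described above can be sketched with Python's abc module. This is a self-contained illustration, not MultiVae's code: SketchSamplerConfig, SketchBaseSampler and ConstantSampler are hypothetical stand-ins that only mirror the documented interface (model, sampler_config, fit, sample).

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchSamplerConfig:
    """Hypothetical stand-in for BaseSamplerConfig (no fields)."""


class SketchBaseSampler(ABC):
    """Hypothetical stand-in mirroring the documented BaseSampler interface."""

    def __init__(self, model, sampler_config: Optional[SketchSamplerConfig] = None):
        self.model = model
        # Fall back to a default configuration when none is provided.
        self.sampler_config = (
            sampler_config if sampler_config is not None else SketchSamplerConfig()
        )
        self.is_fitted = False

    def fit(self, train_data, **kwargs):
        # Default: nothing to learn; subclasses that need fitting override this.
        self.is_fitted = True

    @abstractmethod
    def sample(self, n_samples: int = 1, batch_size: int = 500, return_gen: bool = True):
        """Every concrete sampler must implement sampling."""


class ConstantSampler(SketchBaseSampler):
    """Toy subclass: returns the same 2-d latent vector n_samples times."""

    def sample(self, n_samples=1, batch_size=500, return_gen=True):
        # batch_size and return_gen are ignored in this toy example.
        return [[0.0, 0.0] for _ in range(n_samples)]


sampler = ConstantSampler(model=None)
print(type(sampler.sampler_config).__name__)  # SketchSamplerConfig
print(len(sampler.sample(n_samples=3)))  # 3
```

Because sample is declared abstract, SketchBaseSampler itself cannot be instantiated; only subclasses that implement it can.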

fit(train_data, **kwargs)

Method to call to fit the sampler on train_data before sampling.
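To illustrate why some samplers need a fit step, here is a minimal sketch of a hypothetical sampler that estimates a diagonal Gaussian from latent codes and then samples from it. DiagonalGaussianSampler is not a MultiVae class, and the assumption that fit receives latent codes directly is a simplification: a real sampler would presumably obtain them by encoding train_data with the model.

```python
import random
import statistics
from typing import List, Sequence


class DiagonalGaussianSampler:
    """Hypothetical sampler: fit() estimates the per-dimension mean and
    standard deviation of latent codes; sample() draws from that Gaussian."""

    def __init__(self, seed: int = 0):
        self.means: List[float] = []
        self.stds: List[float] = []
        self.rng = random.Random(seed)

    def fit(self, latent_codes: Sequence[Sequence[float]]) -> None:
        # Transpose the rows of codes into one column per latent dimension.
        columns = list(zip(*latent_codes))
        self.means = [statistics.fmean(col) for col in columns]
        self.stds = [statistics.pstdev(col) for col in columns]

    def sample(self, n_samples: int = 1) -> List[List[float]]:
        if not self.means:
            raise RuntimeError("Call fit() before sample().")
        return [
            [self.rng.gauss(m, s) for m, s in zip(self.means, self.stds)]
            for _ in range(n_samples)
        ]


sampler = DiagonalGaussianSampler(seed=0)
sampler.fit([[0.0, 1.0], [2.0, 3.0]])
print(sampler.means, sampler.stds)  # [1.0, 2.0] [1.0, 1.0]
print(len(sampler.sample(n_samples=5)))  # 5
```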

sample(n_samples=1, batch_size=500, return_gen=True)

Main sampling function of the sampler.

Parameters:
  • n_samples (int) – The number of samples to generate. Default: 1.

  • batch_size (int) – The batch size to use during sampling. Default: 500.

  • return_gen (bool) – Whether the sampler should directly return the generated data. Default: True.

Returns:

The generated images

Return type:

Tensor
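The interplay between n_samples and batch_size can be sketched as follows: samples are drawn in chunks of at most batch_size and then concatenated. This is an assumption about how batching is typically done, shown with plain Python lists instead of tensors; sample_in_batches and draw_batch are hypothetical names, not part of MultiVae.

```python
import random
from typing import Callable, List


def sample_in_batches(
    n_samples: int,
    batch_size: int,
    draw_batch: Callable[[int], List[float]],
) -> List[float]:
    """Draw n_samples items in chunks of at most batch_size, then concatenate."""
    if n_samples < 1 or batch_size < 1:
        raise ValueError("n_samples and batch_size must be positive")
    full_batches, remainder = divmod(n_samples, batch_size)
    # e.g. 1201 samples with batch_size=500 -> batches of 500, 500 and 201
    sizes = [batch_size] * full_batches + ([remainder] if remainder else [])
    samples: List[float] = []
    for size in sizes:
        samples.extend(draw_batch(size))
    return samples


rng = random.Random(0)


def draw_from_prior(k: int) -> List[float]:
    # Toy 1-d standard normal "prior"; a real sampler would draw latent codes.
    return [rng.gauss(0.0, 1.0) for _ in range(k)]


out = sample_in_batches(1201, 500, draw_from_prior)
print(len(out))  # 1201
```

Batching bounds peak memory by batch_size regardless of how many samples are requested.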

save(dir_path)

Method to save the sampler config. The config is saved as a sampler_config.json file in dir_path.
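A minimal sketch of what saving a config as dir_path/sampler_config.json can look like, assuming the config is a dataclass serialized with json. ExampleSamplerConfig, its n_components field and save_config are hypothetical; the exact JSON layout written by MultiVae may differ.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class ExampleSamplerConfig:
    """Hypothetical config with one illustrative field."""
    n_components: int = 10


def save_config(config, dir_path: str) -> str:
    """Write config to <dir_path>/sampler_config.json and return the file path."""
    os.makedirs(dir_path, exist_ok=True)
    path = os.path.join(dir_path, "sampler_config.json")
    with open(path, "w") as f:
        json.dump(asdict(config), f, indent=2)
    return path


with tempfile.TemporaryDirectory() as tmp:
    saved = save_config(ExampleSamplerConfig(n_components=4), tmp)
    with open(saved) as f:
        print(json.load(f))  # {'n_components': 4}
```

Storing the config as plain JSON keeps it human-readable and lets it be reloaded with ExampleSamplerConfig(**json.load(f)).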