gt4sd.algorithms.generation.diffusion.implementation module

Implementation details for Hugging Face Diffusers generation algorithms.

Parts of the implementation are inspired by: https://github.com/huggingface/diffusers/blob/main/examples/train_unconditional.py.

Summary

Classes:

Generator

Implementation of a generator.

Functions:

set_seed

Set seed for all random number generators.

Reference

set_seed(seed=42)

Set seed for all random number generators.

Parameters

seed (int) – seed to set. Defaults to 42.

Return type

None
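The seeding step can be sketched as follows. This is a minimal sketch, not the GT4SD source: it assumes that "all random number generators" means Python's `random` module, NumPy and PyTorch, and it skips any of the latter two that are not installed.

```python
import random

try:
    import numpy as np
except ImportError:  # NumPy is optional in this sketch
    np = None

try:
    import torch
except ImportError:  # PyTorch is optional in this sketch
    torch = None


def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs (whichever are installed)."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        # Seeding CUDA is lazy, so this is safe even on CPU-only machines.
        torch.cuda.manual_seed_all(seed)
```

Calling the function twice with the same seed should make subsequent draws repeat, which is what makes diffusion sampling runs reproducible.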

class Generator(resources_path, model_type, model_name, scheduler_type, auth_token=True, prompt=None, device=None)

Bases: object

Implementation of a generator.

__init__(resources_path, model_type, model_name, scheduler_type, auth_token=True, prompt=None, device=None)

A Diffusers generation algorithm.

Parameters
  • resources_path (str) – path to the cache.

  • model_type (str) – type of the model.

  • model_name (str) – name of the model weights/version.

  • scheduler_type (str) – type of the scheduler.

  • auth_token (bool) – whether to use an authentication token, as required for private models.

  • prompt (Union[str, Dict[str, Any], None]) – target for conditional generation.

  • device (Union[device, str, None]) – device on which inference runs, given either as a torch.device or as a string. If not provided, it is inferred.
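One way a `device` argument of type `Union[device, str, None]` could be resolved is sketched below. `resolve_device` is a hypothetical helper, not part of GT4SD; it assumes the common convention of falling back to CUDA when a GPU is available and to CPU otherwise, and it normalizes the result to a string for simplicity.

```python
from typing import Any, Optional

try:
    import torch
except ImportError:  # PyTorch is optional in this sketch
    torch = None


def resolve_device(device: Optional[Any] = None) -> str:
    """Hypothetical: return a device name such as 'cpu', 'cuda' or 'cuda:0'."""
    if device is not None:
        # Accept either a torch.device or a plain string like "cuda:0".
        return str(device)
    # Not provided: prefer a GPU when PyTorch can see one.
    if torch is not None and torch.cuda.is_available():
        return "cuda"
    return "cpu"
```

Keeping the explicit argument authoritative and only inferring when it is `None` mirrors the "if not provided" behaviour described above.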

load_model()

Load a pretrained diffusion generative model.

Return type

None

sample(number_samples=1)

Sample images with optional conditioning.

Parameters

number_samples (int) – number of images to generate.

Return type

List[Any]

Returns

generated samples.
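How an optional `prompt` might be turned into keyword arguments for a Diffusers pipeline call is sketched below. `build_call_kwargs` is hypothetical and GT4SD's actual mapping may differ; the sketch assumes that unconditional pipelines take `batch_size` (as Diffusers' `DDPMPipeline` does) and that text-conditioned pipelines take `prompt` and `num_images_per_prompt` (as `StableDiffusionPipeline` does).

```python
from typing import Any, Dict, Union


def build_call_kwargs(
    prompt: Union[str, Dict[str, Any], None], number_samples: int = 1
) -> Dict[str, Any]:
    """Hypothetical: map a prompt to Diffusers pipeline keyword arguments."""
    if prompt is None:
        # Unconditional generation: ask for a batch of images.
        return {"batch_size": number_samples}
    if isinstance(prompt, str):
        # Text-conditioned generation: one prompt, several images.
        return {"prompt": prompt, "num_images_per_prompt": number_samples}
    # Dict prompt: pass its entries through as pipeline arguments.
    return {**prompt, "num_images_per_prompt": number_samples}
```

Under these assumptions, `sample(number_samples)` would reduce to a single pipeline call with the returned keyword arguments, collecting the resulting images into the returned list.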
