gt4sd.frameworks.granular.ml.module module

Model combiner module.

Summary

Classes:

GranularModule – Module from granular.

Reference

class GranularModule(architecture_autoencoders, architecture_latent_models, lr=0.0001, test_output_path='./test', **kwargs)

Bases: LightningModule

Module from granular.

__init__(architecture_autoencoders, architecture_latent_models, lr=0.0001, test_output_path='./test', **kwargs)

Construct GranularModule.

Parameters
  • architecture_autoencoders (List[Dict[str, Any]]) – list of autoencoder architecture configurations.

  • architecture_latent_models (List[Dict[str, Any]]) – list of latent model architecture configurations.

  • lr (float) – learning rate for the Adam optimizer. Defaults to 1e-4.

  • test_output_path (str) – path where latent encodings and predictions for the test set are saved when an epoch ends. Defaults to a folder called “test” in the current working directory.
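
As a minimal sketch of how the constructor arguments fit together, the snippet below assembles them as plain Python objects. Only the four keyword names and their defaults come from the signature above; the keys inside each configuration dictionary (`name`, `type`, `position`, `latent_size`, `from_position`) are hypothetical placeholders and are not confirmed by this page.

```python
from typing import Any, Dict, List

# Hypothetical configuration entries: the keys below are illustrative only.
architecture_autoencoders: List[Dict[str, Any]] = [
    {"name": "autoencoder_0", "type": "vae", "position": 0, "latent_size": 16},
]
architecture_latent_models: List[Dict[str, Any]] = [
    {"name": "predictor_0", "type": "mlp", "from_position": [0]},
]

# Keyword arguments matching the documented signature and defaults.
module_kwargs: Dict[str, Any] = {
    "architecture_autoencoders": architecture_autoencoders,
    "architecture_latent_models": architecture_latent_models,
    "lr": 1e-4,
    "test_output_path": "./test",
}

# With gt4sd installed, the module would then be built as:
# from gt4sd.frameworks.granular.ml.module import GranularModule
# module = GranularModule(**module_kwargs)
```

Keeping the arguments in a single dictionary makes it easy to log or serialize the exact settings used for a run.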

_autoencoder_step(batch, model, model_step_fn)

Autoencoder module forward pass.

Parameters
  • batch (Any) – batch representation.

  • model (GranularEncoderDecoderModel) – the encoder-decoder model the step is run for.

  • model_step_fn (Callable) – callable for the step.

Return type

Tuple[Any, Any, Any]

Returns

a tuple containing the latent representation, the loss and the logs for the module.
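
The contract described above (a step callable that yields a latent representation, a loss and logs) can be sketched without any deep learning dependency. `ToyAutoencoder`, its `step` method and the single-argument call of `model_step_fn` are assumptions made for illustration; the real step functions may take different arguments.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyAutoencoder:
    """Hypothetical stand-in for a GranularEncoderDecoderModel."""

    def step(self, batch: Dict[str, List[float]]) -> Tuple[List[float], float, Dict[str, float]]:
        x = batch["x"]
        z = [value / 2.0 for value in x]               # toy "encoding"
        reconstruction = [value * 2.0 for value in z]  # toy "decoding"
        loss = sum((a - b) ** 2 for a, b in zip(x, reconstruction))
        return z, loss, {"reconstruction_loss": loss}


def autoencoder_step(
    batch: Any, model: Any, model_step_fn: Callable[[Any], Tuple[Any, Any, Any]]
) -> Tuple[Any, Any, Any]:
    """Run one step and return (latent representation, loss, logs).

    `model` is unused in this sketch; it is kept only to mirror the
    documented signature.
    """
    z, loss, logs = model_step_fn(batch)
    return z, loss, logs


model = ToyAutoencoder()
z, loss, logs = autoencoder_step({"x": [1.0, 2.0]}, model, model.step)
```

Passing the step callable separately from the model lets the same helper serve different phases, for example by supplying a different bound method.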

_latent_step(batch, model, model_step_fn, z)

Latent module forward pass.

Parameters
  • batch (Any) – batch representation.

  • model (GranularBaseModel) – the latent model the step is run for.

  • model_step_fn (Callable) – callable for the step.

  • z (Dict[int, Any]) – latent encodings.

Return type

Tuple[Any, Any, Any]

Returns

a tuple containing the latent step output, the loss and the logs for the module.
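
A rough sketch of a latent step follows, assuming that `z` maps an integer position to the encoding produced by the autoencoder at that position. The `from_position` attribute, the concatenation rule and the `y` batch key are hypothetical and serve only to show how a `Dict[int, Any]` of encodings could feed a latent model.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyLatentModel:
    """Hypothetical stand-in for a GranularBaseModel acting on latent codes."""

    from_position = [0, 1]  # assumed: positions of the encodings to consume

    def step(self, z_input: List[float], target: float) -> Tuple[float, float, Dict[str, float]]:
        prediction = sum(z_input)  # toy predictor
        loss = (prediction - target) ** 2
        return prediction, loss, {"latent_loss": loss}


def latent_step(
    batch: Dict[str, Any],
    model: Any,
    model_step_fn: Callable[..., Tuple[Any, Any, Any]],
    z: Dict[int, Any],
) -> Tuple[Any, Any, Any]:
    """Return (latent step output, loss, logs) for one latent model."""
    # Concatenate the encodings the model depends on (assumed selection rule).
    z_input = [value for position in model.from_position for value in z[position]]
    return model_step_fn(z_input, batch["y"])


z = {0: [0.5, 1.0], 1: [0.25]}
model = ToyLatentModel()
output, loss, logs = latent_step({"y": 2.0}, model, model.step, z)
```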

training_step(batch, *args, **kwargs)

Training step implementation.

Parameters

batch (Any) – batch representation.

Return type

Dict[str, Any]

Returns

loss and logs.
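
One plausible way a combined step could aggregate its sub-models is sketched below: each autoencoder contributes an encoding and a loss, each latent model consumes the collected encodings, and the losses are summed. The `"loss"` and `"logs"` dictionary keys, the summation rule and the toy step functions are assumptions for illustration, not the confirmed implementation. The same pattern would apply to a validation step.

```python
from typing import Any, Callable, Dict, List, Tuple

StepFn = Callable[..., Tuple[Any, float, Dict[str, float]]]


def combined_step(
    batch: Dict[str, Any],
    autoencoder_steps: Dict[int, StepFn],
    latent_steps: List[StepFn],
) -> Dict[str, Any]:
    """Accumulate loss and logs over all sub-models (assumed aggregation)."""
    total_loss = 0.0
    logs: Dict[str, float] = {}
    z: Dict[int, Any] = {}
    # First pass: every autoencoder produces an encoding keyed by its position.
    for position, step_fn in autoencoder_steps.items():
        z[position], loss, step_logs = step_fn(batch)
        total_loss += loss
        logs.update(step_logs)
    # Second pass: latent models operate on the collected encodings.
    for step_fn in latent_steps:
        _, loss, step_logs = step_fn(batch, z)
        total_loss += loss
        logs.update(step_logs)
    logs["total_loss"] = total_loss
    return {"loss": total_loss, "logs": logs}


def toy_autoencoder(batch):
    # Toy encoding with a fixed, made-up loss value.
    return [v / 2.0 for v in batch["x"]], 0.5, {"ae_loss": 0.5}


def toy_latent(batch, z):
    prediction = sum(z[0])
    loss = (prediction - batch["y"]) ** 2
    return prediction, loss, {"latent_loss": loss}


result = combined_step({"x": [1.0, 2.0], "y": 2.0}, {0: toy_autoencoder}, [toy_latent])
```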

validation_step(batch, *args, **kwargs)

Validation step implementation.

Parameters

batch (Any) – batch representation.

Return type

Dict[str, Any]

Returns

loss and logs.

configure_optimizers()

Configure optimizers.

Return type

Optimizer

Returns

an optimizer; currently only Adam is supported.
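
To give a feel for what the `lr` value means under Adam, the dependency-free sketch below computes the very first Adam update for a single scalar parameter. It uses the standard Adam update rule with the common defaults `beta1=0.9`, `beta2=0.999` and `eps=1e-8`; whether this module overrides those defaults is not stated here, so treat them as assumptions.

```python
import math


def adam_first_step(
    grad: float,
    lr: float = 1e-4,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> float:
    """Size of the first Adam update for one parameter, starting from zero moments."""
    m = (1.0 - beta1) * grad          # first-moment estimate
    v = (1.0 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1.0 - beta1)         # bias correction at step 1
    v_hat = v / (1.0 - beta2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)


# After bias correction the first update is close to lr for any gradient scale.
step_small = adam_first_step(0.01)
step_large = adam_first_step(100.0)
```

Because the first update is roughly `lr` in magnitude regardless of the gradient scale, the default of 1e-4 bounds how far each parameter moves at the start of training.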
