gt4sd.frameworks.granular.ml.models.utils module

Model utilities.

Summary

Classes:

KLAnnealer

Annealer that scales the KL weight (beta) linearly over the training epochs.

Reference

class KLAnnealer(kl_low, kl_high, n_epochs, start_epoch)

Bases: object

Annealer that scales the KL weight (beta) linearly over the training epochs.
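For context, a KL weight (beta) usually multiplies the KL-divergence term of a variational autoencoder objective. The sketch below illustrates that assumed usage with a diagonal-Gaussian posterior and a standard-normal prior. The helpers `gaussian_kl` and `beta_weighted_loss` are hypothetical and not part of gt4sd; how the granular models actually combine beta with their losses is not stated on this page.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions.

    Hypothetical helper: assumes mu and logvar have the same length.
    """
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_weighted_loss(
    reconstruction: float,
    mu: Sequence[float],
    logvar: Sequence[float],
    beta: float,
) -> float:
    """Hypothetical VAE objective: reconstruction term plus beta-weighted KL term."""
    return reconstruction + beta * gaussian_kl(mu, logvar)
```

In a training loop, `beta` would be the value returned by the annealer for the current epoch, so a small early beta lets the reconstruction term dominate before the KL penalty ramps up.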

__init__(kl_low, kl_high, n_epochs, start_epoch)

Construct KLAnnealer.

Parameters
  • kl_low (float) – lowest KL weight (beta) of the schedule.

  • kl_high (float) – highest KL weight (beta) of the schedule.

  • n_epochs (int) – number of epochs.

  • start_epoch (int) – epoch at which the annealing starts.

__call__(epoch)

Return the KL weight (beta) for the given epoch.

Parameters

epoch (int) – current epoch number.

Return type

float

Returns

the beta weight.
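A minimal sketch of a linear schedule with this interface, assuming (not confirmed by these docs) that beta stays at `kl_low` until `start_epoch`, then grows by a constant step per epoch, reaches `kl_high` at epoch `n_epochs`, and is clamped there afterwards. `LinearKLAnnealer` is a hypothetical stand-in, not the gt4sd class, and it assumes `kl_low <= kl_high`.

```python
class LinearKLAnnealer:
    """Hypothetical linear KL-weight (beta) schedule mirroring KLAnnealer's signature.

    Assumed behaviour: beta = kl_low before start_epoch, then increases linearly,
    reaching kl_high at n_epochs and staying there (requires kl_low <= kl_high).
    """

    def __init__(self, kl_low: float, kl_high: float, n_epochs: int, start_epoch: int) -> None:
        if n_epochs <= start_epoch:
            raise ValueError("n_epochs must be greater than start_epoch")
        self.kl_low = kl_low
        self.kl_high = kl_high
        self.n_epochs = n_epochs
        self.start_epoch = start_epoch
        # Constant per-epoch increment over the annealing window.
        self.step = (kl_high - kl_low) / (n_epochs - start_epoch)

    def __call__(self, epoch: int) -> float:
        # Epochs before start_epoch contribute no increment.
        k = max(epoch - self.start_epoch, 0)
        beta = self.kl_low + k * self.step
        # Clamp so beta never exceeds the upper KL weight.
        return min(beta, self.kl_high)
```

With `kl_low=0.0`, `kl_high=1.0`, `n_epochs=6` and `start_epoch=2`, this sketch returns 0.0 for epochs 0 to 2, 0.5 at epoch 4, and 1.0 from epoch 6 onwards. Precomputing the step in the constructor keeps each call to a single multiply-add.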
