gt4sd.frameworks.granular.ml.models.utils module
Model utilities.
Summary
Classes:
KLAnnealer: Annealer scaling KL weights (beta) linearly according to the number of epochs.
Reference
- class KLAnnealer(kl_low, kl_high, n_epochs, start_epoch)
Bases: object
Annealer scaling KL weights (beta) linearly according to the number of epochs.
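"Linearly" here means beta grows by a constant increment per epoch, from kl_low toward kl_high. One way to write that down, assuming the ramp starts at start_epoch, spans the epochs up to n_epochs, holds kl_low before start_epoch and is capped at kl_high (the library's exact boundary handling may differ), is:

$$
\beta(e) = \min\!\left(k_{\text{high}},\; k_{\text{low}} + \max(0,\, e - e_0)\,\frac{k_{\text{high}} - k_{\text{low}}}{N - e_0}\right)
$$

where $e$ is the current epoch, $k_{\text{low}}$ is kl_low, $k_{\text{high}}$ is kl_high, $N$ is n_epochs and $e_0$ is start_epoch. The cap assumes kl_high is at least kl_low.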
- __init__(kl_low, kl_high, n_epochs, start_epoch)
Construct KLAnnealer.
- Parameters
kl_low (float) – low KL weight.
kl_high (float) – high KL weight.
n_epochs (int) – number of epochs.
start_epoch (int) – starting epoch.
- __call__(epoch)
Call the annealer to obtain the KL weight (beta) for the given epoch.
- Parameters
epoch (int) – current epoch number.
- Return type
float
- Returns
the beta weight.
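The sketch below mimics the documented interface (a constructor taking kl_low, kl_high, n_epochs, start_epoch and a __call__(epoch) returning a float beta) so the schedule can be inspected without gt4sd installed. The class name LinearKLAnnealerSketch and the helper weighted_vae_loss are hypothetical, and the boundary behaviour (beta held at kl_low before start_epoch, clipped at kl_high afterwards) is an assumption rather than a copy of the gt4sd source. The helper shows the usual way a beta weight enters a VAE objective: reconstruction loss plus beta times the KL term.

```python
"""Hypothetical stand-in for a linear KL annealer (not the gt4sd source).

Assumption: beta stays at kl_low before start_epoch, then rises by a
constant step per epoch and is clipped at kl_high.
"""


class LinearKLAnnealerSketch:
    """Mirrors the documented signature KLAnnealer(kl_low, kl_high, n_epochs, start_epoch)."""

    def __init__(self, kl_low: float, kl_high: float, n_epochs: int, start_epoch: int) -> None:
        if n_epochs <= start_epoch:
            raise ValueError("n_epochs must be greater than start_epoch")
        self.kl_low = kl_low
        self.kl_high = kl_high
        self.start_epoch = start_epoch
        # Constant per-epoch increment of beta across the annealing window.
        self.step = (kl_high - kl_low) / (n_epochs - start_epoch)

    def __call__(self, epoch: int) -> float:
        # Epochs before start_epoch contribute no increment.
        k = max(0, epoch - self.start_epoch)
        # Clip so beta never overshoots the high KL weight.
        return min(self.kl_high, self.kl_low + k * self.step)


def weighted_vae_loss(reconstruction: float, kl_divergence: float, beta: float) -> float:
    """beta-weighted VAE objective: reconstruction term plus beta times the KL term."""
    return reconstruction + beta * kl_divergence


if __name__ == "__main__":
    annealer = LinearKLAnnealerSketch(kl_low=0.0, kl_high=1.0, n_epochs=10, start_epoch=2)
    for epoch in range(12):
        beta = annealer(epoch)
        # Toy per-epoch loss values, for illustration only.
        print(epoch, beta, weighted_vae_loss(1.5, 0.8, beta))
```

With these settings beta stays at 0.0 up to epoch 2, rises by 0.125 per epoch, and reaches 1.0 at epoch 10. With gt4sd installed, the documented class would be driven the same way: construct it once, then call it with the current epoch at the start of each epoch and multiply the KL term by the returned value.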