gt4sd.training_pipelines.moses.vae.core module

Moses VAE training pipeline.

Summary

Classes:

MosesVAEModelArguments

Arguments related to Moses VAE model.

MosesVAETrainingArguments

Arguments related to Moses VAE training.

MosesVAETrainingPipeline

Moses VAE training pipeline.

Reference

class MosesVAETrainingPipeline[source]

Bases: MosesTrainingPipeline

Moses VAE training pipeline.

train(training_args, model_args, dataset_args)[source]

Generic training function for Moses VAE training.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Return type

None

__annotations__ = {}
__doc__ = 'Moses VAE training pipelines.'
__module__ = 'gt4sd.training_pipelines.moses.vae.core'
class MosesVAEModelArguments(q_cell='gru', q_bidir=True, q_d_h=256, q_n_layers=1, q_dropout=0.5, d_cell='gru', d_n_layers=3, d_dropout=0, d_z=128, d_d_h=512, freeze_embeddings=False)[source]

Bases: TrainingPipelineArguments

Arguments related to Moses VAE model.

__name__ = 'MosesVAEModelArguments'
q_cell: str = 'gru'
q_bidir: bool = True
q_d_h: int = 256
q_n_layers: int = 1
q_dropout: float = 0.5
d_cell: str = 'gru'
d_n_layers: int = 3
d_dropout: float = 0
d_z: int = 128
d_d_h: int = 512
freeze_embeddings: bool = False
__annotations__ = {'d_cell': <class 'str'>, 'd_d_h': <class 'int'>, 'd_dropout': <class 'float'>, 'd_n_layers': <class 'int'>, 'd_z': <class 'int'>, 'freeze_embeddings': <class 'bool'>, 'q_bidir': <class 'bool'>, 'q_cell': <class 'str'>, 'q_d_h': <class 'int'>, 'q_dropout': <class 'float'>, 'q_n_layers': <class 'int'>}
__dataclass_fields__ = {'d_cell': Field(name='d_cell',type=<class 'str'>,default='gru',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Decoder rnn cell type.'}),kw_only=False,_field_type=_FIELD), 'd_d_h': Field(name='d_d_h',type=<class 'int'>,default=512,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Decoder h dimensionality.'}),kw_only=False,_field_type=_FIELD), 'd_dropout': Field(name='d_dropout',type=<class 'float'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Decoder layers dropout'}),kw_only=False,_field_type=_FIELD), 'd_n_layers': Field(name='d_n_layers',type=<class 'int'>,default=3,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Decoder number of layers.'}),kw_only=False,_field_type=_FIELD), 'd_z': Field(name='d_z',type=<class 'int'>,default=128,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Latent vector dimensionality'}),kw_only=False,_field_type=_FIELD), 'freeze_embeddings': Field(name='freeze_embeddings',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to freeze embeddings while training'}),kw_only=False,_field_type=_FIELD), 'q_bidir': Field(name='q_bidir',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to add second direction in the encoder.'}),kw_only=False,_field_type=_FIELD), 'q_cell': Field(name='q_cell',type=<class 'str'>,default='gru',default_factory=<dataclasses._MISSING_TYPE 
object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Encoder rnn cell type.'}),kw_only=False,_field_type=_FIELD), 'q_d_h': Field(name='q_d_h',type=<class 'int'>,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Encoder h dimensionality.'}),kw_only=False,_field_type=_FIELD), 'q_dropout': Field(name='q_dropout',type=<class 'float'>,default=0.5,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Encoder layers dropout.'}),kw_only=False,_field_type=_FIELD), 'q_n_layers': Field(name='q_n_layers',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Encoder number of layers.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments related to Moses VAE model.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(q_cell='gru', q_bidir=True, q_d_h=256, q_n_layers=1, q_dropout=0.5, d_cell='gru', d_n_layers=3, d_dropout=0, d_z=128, d_d_h=512, freeze_embeddings=False)
__match_args__ = ('q_cell', 'q_bidir', 'q_d_h', 'q_n_layers', 'q_dropout', 'd_cell', 'd_n_layers', 'd_dropout', 'd_z', 'd_d_h', 'freeze_embeddings')
__module__ = 'gt4sd.training_pipelines.moses.vae.core'
__repr__()

Return repr(self).

class MosesVAETrainingArguments(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu', n_batch=512, grad_clipping=50, kl_start=0, kl_w_start=0, kl_w_end=0.05, lr_start=0.00030000000000000003, lr_n_period=10, lr_n_restarts=10, lr_n_mult=1, lr_end=0.00030000000000000003, n_last=1000, n_jobs=1, n_workers=1, warm_start='')[source]

Bases: MosesTrainingArguments

Arguments related to Moses VAE training.

n_batch: int = 512
grad_clipping: int = 50
kl_start: int = 0
kl_w_start: float = 0
kl_w_end: float = 0.05
lr_start: float = 0.00030000000000000003
lr_n_period: int = 10
lr_n_restarts: int = 10
lr_n_mult: int = 1
lr_end: float = 0.00030000000000000003
n_last: int = 1000
n_jobs: int = 1
n_workers: int = 1
warm_start: str = ''
__annotations__ = {'config_save': 'str', 'device': 'str', 'grad_clipping': <class 'int'>, 'kl_start': <class 'int'>, 'kl_w_end': <class 'float'>, 'kl_w_start': <class 'float'>, 'log_file': 'str', 'lr_end': <class 'float'>, 'lr_n_mult': <class 'int'>, 'lr_n_period': <class 'int'>, 'lr_n_restarts': <class 'int'>, 'lr_start': <class 'float'>, 'model_save': 'str', 'n_batch': <class 'int'>, 'n_jobs': <class 'int'>, 'n_last': <class 'int'>, 'n_workers': <class 'int'>, 'save_frequency': 'int', 'seed': 'int', 'vocab_save': 'str', 'warm_start': <class 'str'>}
__dataclass_fields__ = {'config_save': Field(name='config_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path for the config.'}),kw_only=False,_field_type=_FIELD), 'device': Field(name='device',type=<class 'str'>,default='cpu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Device to run: 'cpu' or 'cuda:<device number>'"}),kw_only=False,_field_type=_FIELD), 'grad_clipping': Field(name='grad_clipping',type=<class 'int'>,default=50,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Gradients clipping size.'}),kw_only=False,_field_type=_FIELD), 'kl_start': Field(name='kl_start',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Epoch to start change kl weight from.'}),kw_only=False,_field_type=_FIELD), 'kl_w_end': Field(name='kl_w_end',type=<class 'float'>,default=0.05,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum kl weight value.'}),kw_only=False,_field_type=_FIELD), 'kl_w_start': Field(name='kl_w_start',type=<class 'float'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Initial kl weight value.'}),kw_only=False,_field_type=_FIELD), 'log_file': Field(name='log_file',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where to save the logs.'}),kw_only=False,_field_type=_FIELD), 'lr_end': Field(name='lr_end',type=<class 
'float'>,default=0.00030000000000000003,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum lr weight value.'}),kw_only=False,_field_type=_FIELD), 'lr_n_mult': Field(name='lr_n_mult',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Mult coefficient after restart in SGDR.'}),kw_only=False,_field_type=_FIELD), 'lr_n_period': Field(name='lr_n_period',type=<class 'int'>,default=10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Epochs before first restart in SGDR.'}),kw_only=False,_field_type=_FIELD), 'lr_n_restarts': Field(name='lr_n_restarts',type=<class 'int'>,default=10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of restarts in SGDR.'}),kw_only=False,_field_type=_FIELD), 'lr_start': Field(name='lr_start',type=<class 'float'>,default=0.00030000000000000003,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Initial lr value.'}),kw_only=False,_field_type=_FIELD), 'model_save': Field(name='model_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the trained model is saved.'}),kw_only=False,_field_type=_FIELD), 'n_batch': Field(name='n_batch',type=<class 'int'>,default=512,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Batch size.'}),kw_only=False,_field_type=_FIELD), 'n_jobs': Field(name='n_jobs',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE 
object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of threads.'}),kw_only=False,_field_type=_FIELD), 'n_last': Field(name='n_last',type=<class 'int'>,default=1000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of iters to smooth loss calc.'}),kw_only=False,_field_type=_FIELD), 'n_workers': Field(name='n_workers',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of workers.'}),kw_only=False,_field_type=_FIELD), 'save_frequency': Field(name='save_frequency',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'How often to save the model.'}),kw_only=False,_field_type=_FIELD), 'seed': Field(name='seed',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Seed used for random number generation.'}),kw_only=False,_field_type=_FIELD), 'vocab_save': Field(name='vocab_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to save the model vocabulary.'}),kw_only=False,_field_type=_FIELD), 'warm_start': Field(name='warm_start',type=<class 'str'>,default='',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to a folder to warm start from. Set empty string to not use. This has to contain files `model.pt`, `vocab.pt` and `config.pt`.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments related to Moses VAE training.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu', n_batch=512, grad_clipping=50, kl_start=0, kl_w_start=0, kl_w_end=0.05, lr_start=0.00030000000000000003, lr_n_period=10, lr_n_restarts=10, lr_n_mult=1, lr_end=0.00030000000000000003, n_last=1000, n_jobs=1, n_workers=1, warm_start='')
__match_args__ = ('model_save', 'log_file', 'config_save', 'vocab_save', 'save_frequency', 'seed', 'device', 'n_batch', 'grad_clipping', 'kl_start', 'kl_w_start', 'kl_w_end', 'lr_start', 'lr_n_period', 'lr_n_restarts', 'lr_n_mult', 'lr_end', 'n_last', 'n_jobs', 'n_workers', 'warm_start')
__module__ = 'gt4sd.training_pipelines.moses.vae.core'
__repr__()

Return repr(self).