gt4sd.training_pipelines.moses.vae.core module
Moses VAE training pipeline.
Summary
Classes:
MosesVAEModelArguments | Arguments related to Moses VAE model.
MosesVAETrainingArguments | Arguments related to Moses VAE training.
MosesVAETrainingPipeline | Moses VAE training pipelines.
Reference
- class MosesVAETrainingPipeline
Bases: MosesTrainingPipeline
Moses VAE training pipelines.
- train(training_args, model_args, dataset_args)
Generic training function for Moses VAE training.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
None
- __annotations__ = {}¶
- __doc__ = 'Moses VAE training pipelines.'¶
- __module__ = 'gt4sd.training_pipelines.moses.vae.core'¶
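The call shape of `train` can be sketched as follows. This is a minimal illustration, not the library's own example: `run_vae_training` is a hypothetical helper, the file names under `output_dir` are assumptions, and `dataset_args` is passed through untouched because its keys are not documented on this page. Only the four required training fields plus `device` are set. Whether `train` falls back to defaults for the remaining fields is not confirmed here, so a real call may need every field.

```python
from pathlib import Path
from typing import Any, Dict


def run_vae_training(
    pipeline: Any, output_dir: str, dataset_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical helper: build the argument dicts and call pipeline.train."""
    out = Path(output_dir)
    # Required fields of MosesVAETrainingArguments; the file names are assumptions.
    training_args: Dict[str, Any] = {
        "model_save": str(out / "model.pt"),
        "log_file": str(out / "log.txt"),
        "config_save": str(out / "config.pt"),
        "vocab_save": str(out / "vocab.pt"),
        "device": "cpu",
    }
    # A subset of MosesVAEModelArguments, using the documented defaults.
    model_args: Dict[str, Any] = {"q_cell": "gru", "d_cell": "gru", "d_z": 128}
    # train() is documented to return None, so nothing is captured from it.
    pipeline.train(
        training_args=training_args,
        model_args=model_args,
        dataset_args=dataset_args,
    )
    return training_args
```

With GT4SD installed, `pipeline` would presumably be an instance of `MosesVAETrainingPipeline` imported from `gt4sd.training_pipelines.moses.vae.core`. How that instance is constructed is not shown on this page.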
- class MosesVAEModelArguments(q_cell='gru', q_bidir=True, q_d_h=256, q_n_layers=1, q_dropout=0.5, d_cell='gru', d_n_layers=3, d_dropout=0, d_z=128, d_d_h=512, freeze_embeddings=False)
Bases: TrainingPipelineArguments
Arguments related to Moses VAE model.
- __name__ = 'MosesVAEModelArguments'¶
- q_cell: str = 'gru'
Encoder RNN cell type.
- q_bidir: bool = True
Whether to add a second direction in the encoder.
- q_d_h: int = 256
Encoder hidden-state (h) dimensionality.
- q_n_layers: int = 1
Number of encoder layers.
- q_dropout: float = 0.5
Encoder layers dropout.
- d_cell: str = 'gru'
Decoder RNN cell type.
- d_n_layers: int = 3
Number of decoder layers.
- d_dropout: float = 0
Decoder layers dropout.
- d_z: int = 128
Latent vector dimensionality.
- d_d_h: int = 512
The help metadata reads 'Latent vector dimensionality', the same string as for d_z; by analogy with q_d_h this is presumably the decoder hidden-state (h) dimensionality.
- freeze_embeddings: bool = False
Whether to freeze embeddings while training.
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)¶
- __doc__ = 'Arguments related to Moses VAE model.'¶
- __eq__(other)¶
Return self==value.
- __hash__ = None¶
- __init__(q_cell='gru', q_bidir=True, q_d_h=256, q_n_layers=1, q_dropout=0.5, d_cell='gru', d_n_layers=3, d_dropout=0, d_z=128, d_d_h=512, freeze_embeddings=False)¶
- __match_args__ = ('q_cell', 'q_bidir', 'q_d_h', 'q_n_layers', 'q_dropout', 'd_cell', 'd_n_layers', 'd_dropout', 'd_z', 'd_d_h', 'freeze_embeddings')¶
- __module__ = 'gt4sd.training_pipelines.moses.vae.core'¶
- __repr__()¶
Return repr(self).
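Because the class is a dataclass whose fields carry a `help` entry in their metadata, the field descriptions and a `Dict[str, Any]` suitable for `model_args` can both be derived with the standard `dataclasses` module. The sketch below uses a stand-in class, `VAEModelArgsSketch`, that mirrors only four of the documented fields so that it runs without GT4SD installed. It is not the real class.

```python
from dataclasses import asdict, dataclass, field, fields


@dataclass
class VAEModelArgsSketch:
    """Stand-in mirroring four MosesVAEModelArguments fields (not the real class)."""

    q_cell: str = field(default="gru", metadata={"help": "Encoder rnn cell type."})
    q_bidir: bool = field(
        default=True,
        metadata={"help": "Whether to add second direction in the encoder."},
    )
    d_n_layers: int = field(
        default=3, metadata={"help": "Decoder number of layers."}
    )
    d_z: int = field(
        default=128, metadata={"help": "Latent vector dimensionality"}
    )


# Override one default and keep the rest.
args = VAEModelArgsSketch(d_z=64)

# Plain dict form, matching the Dict[str, Any] type that train() documents.
model_args = asdict(args)

# Help strings read from the per-field metadata.
help_by_field = {f.name: f.metadata["help"] for f in fields(args)}
```

The same two calls, `asdict` and `fields`, should apply to the real class, since this page shows it exposing `__dataclass_fields__`.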
- class MosesVAETrainingArguments(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu', n_batch=512, grad_clipping=50, kl_start=0, kl_w_start=0, kl_w_end=0.05, lr_start=0.00030000000000000003, lr_n_period=10, lr_n_restarts=10, lr_n_mult=1, lr_end=0.00030000000000000003, n_last=1000, n_jobs=1, n_workers=1, warm_start='')
Bases: MosesTrainingArguments
Arguments related to Moses VAE training.
- model_save: str
Required. Path where the trained model is saved.
- log_file: str
Required. Path where to save the logs.
- config_save: str
Required. Path for the config.
- vocab_save: str
Required. Path to save the model vocabulary.
- save_frequency: int = 1
How often to save the model.
- seed: int = 0
Seed used for random number generation.
- device: str = 'cpu'
Device to run on: 'cpu' or 'cuda:<device number>'.
- n_batch: int = 512
Batch size.
- grad_clipping: int = 50
Gradient clipping size.
- kl_start: int = 0
Epoch from which the KL weight starts to change.
- kl_w_start: float = 0
Initial KL weight value.
- kl_w_end: float = 0.05
Maximum KL weight value.
- lr_start: float = 0.00030000000000000003
Initial lr value.
- lr_n_period: int = 10
Epochs before first restart in SGDR.
- lr_n_restarts: int = 10
Number of restarts in SGDR.
- lr_n_mult: int = 1
Multiplier coefficient after restart in SGDR.
- lr_end: float = 0.00030000000000000003
Maximum lr weight value.
- n_last: int = 1000
Number of iterations used to smooth the loss calculation.
- n_jobs: int = 1
Number of threads.
- n_workers: int = 1
Number of workers.
- warm_start: str = ''
Path to a folder to warm start from; set to an empty string to disable. The folder has to contain the files `model.pt`, `vocab.pt` and `config.pt`.
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)¶
- __doc__ = 'Arguments related to Moses VAE training.'¶
- __eq__(other)¶
Return self==value.
- __hash__ = None¶
- __init__(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu', n_batch=512, grad_clipping=50, kl_start=0, kl_w_start=0, kl_w_end=0.05, lr_start=0.00030000000000000003, lr_n_period=10, lr_n_restarts=10, lr_n_mult=1, lr_end=0.00030000000000000003, n_last=1000, n_jobs=1, n_workers=1, warm_start='')¶
- __match_args__ = ('model_save', 'log_file', 'config_save', 'vocab_save', 'save_frequency', 'seed', 'device', 'n_batch', 'grad_clipping', 'kl_start', 'kl_w_start', 'kl_w_end', 'lr_start', 'lr_n_period', 'lr_n_restarts', 'lr_n_mult', 'lr_end', 'n_last', 'n_jobs', 'n_workers', 'warm_start')¶
- __module__ = 'gt4sd.training_pipelines.moses.vae.core'¶
- __repr__()¶
Return repr(self).
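The schedule-related fields (`kl_*` and `lr_*`) are easier to read with numbers attached. This page does not show how the trainer consumes them, so the sketch below rests on three stated assumptions: the epoch count is the sum of the SGDR period lengths, the KL weight ramps linearly from `kl_w_start` to `kl_w_end`, and the learning rate follows a cosine curve from `lr_start` to `lr_end` within each period. The function names are hypothetical.

```python
import math


def total_epochs(lr_n_period: int = 10, lr_n_restarts: int = 10, lr_n_mult: int = 1) -> int:
    """Assumed epoch count: each SGDR period is lr_n_mult times the previous one."""
    return sum(lr_n_period * lr_n_mult ** i for i in range(lr_n_restarts))


def kl_weight(
    epoch: int,
    n_epoch: int,
    kl_start: int = 0,
    kl_w_start: float = 0.0,
    kl_w_end: float = 0.05,
) -> float:
    """Assumed linear ramp of the KL weight, flat before kl_start."""
    step = (kl_w_end - kl_w_start) / (n_epoch - kl_start)
    elapsed = max(epoch - kl_start, 0)
    return kl_w_start + elapsed * step


def sgdr_lr(epoch_in_period: int, period: int, lr_start: float, lr_end: float) -> float:
    """Assumed cosine annealing from lr_start down to lr_end within one period."""
    cosine = (1 + math.cos(math.pi * epoch_in_period / period)) / 2
    return lr_end + (lr_start - lr_end) * cosine


n_epoch = total_epochs()              # documented defaults: 10 periods of 10 epochs
halfway_kl = kl_weight(50, n_epoch)   # halfway through the assumed linear ramp
```

Under these assumptions the defaults give 100 epochs, and because `lr_start` and `lr_end` share the same default value, the learning rate would stay constant unless one of them is changed.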