gt4sd.training_pipelines.moses.core module
Moses baselines training utilities.
Summary
Classes:
- MosesDataArguments – Arguments related to Moses data loading.
- MosesSavingArguments – Saving arguments related to Moses trainer.
- MosesTrainingArguments – Arguments related to Moses trainer.
- MosesTrainingPipeline – PyTorch Lightning training pipelines.
Reference
- class MosesTrainingPipeline
Bases: TrainingPipeline
PyTorch Lightning training pipelines.
- train(training_args, model_args, common_args)
Generic training function for the Moses baselines.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
common_args (Dict[str, Any]) – common arguments passed to the configuration.
- Raises
NotImplementedError – the generic trainer does not implement the pipeline.
- Return type
None
- __annotations__ = {}
- __doc__ = 'PyTorch lightining training pipelines.'
- __module__ = 'gt4sd.training_pipelines.moses.core'
- class MosesDataArguments(train_load, val_load)
Bases: TrainingPipelineArguments
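The documented contract of `train` (three dictionaries in, `None` out, `NotImplementedError` raised by the generic pipeline) can be sketched with plain-Python stand-ins. `TrainingPipeline` and `MosesTrainingPipeline` below are hypothetical re-declarations that only mirror the signature shown above; nothing is imported from gt4sd. `ToyMosesPipeline` is an invented subclass that marks where a concrete trainer would override `train`.

```python
from typing import Any, Dict


class TrainingPipeline:
    """Hypothetical stand-in for the gt4sd TrainingPipeline base class."""

    def train(
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        common_args: Dict[str, Any],
    ) -> None:
        raise NotImplementedError("Can't train a generic pipeline.")


class MosesTrainingPipeline(TrainingPipeline):
    """Stand-in mirroring the documented MosesTrainingPipeline.train signature."""

    def train(
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        common_args: Dict[str, Any],
    ) -> None:
        # The generic Moses pipeline does not implement training itself.
        raise NotImplementedError("The generic trainer does not implement the pipeline.")


class ToyMosesPipeline(MosesTrainingPipeline):
    """Hypothetical concrete subclass; it records its inputs instead of training."""

    def __init__(self) -> None:
        self.received: Dict[str, Dict[str, Any]] = {}

    def train(
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        common_args: Dict[str, Any],
    ) -> None:
        # A real subclass would build and fit a Moses model here.
        self.received = {
            "training": training_args,
            "model": model_args,
            "common": common_args,
        }


# Calling the generic pipeline raises, as documented under "Raises".
try:
    MosesTrainingPipeline().train(training_args={}, model_args={}, common_args={})
except NotImplementedError as error:
    message = str(error)

toy = ToyMosesPipeline()
toy.train(training_args={"seed": 0}, model_args={}, common_args={})
```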
Arguments related to Moses data loading.
- __name__ = 'MosesDataArguments'
- train_load: str
- val_load: str
- __annotations__ = {'train_load': <class 'str'>, 'val_load': <class 'str'>}
- __dataclass_fields__ = {'train_load': Field(name='train_load',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Input data in csv format used for training.'}),kw_only=False,_field_type=_FIELD), 'val_load': Field(name='val_load',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Input data in csv format used for validation.'}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Arguments related to Moses data loading.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(train_load, val_load)
- __match_args__ = ('train_load', 'val_load')
- __module__ = 'gt4sd.training_pipelines.moses.core'
- __repr__()
Return repr(self).
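`MosesDataArguments` is a plain dataclass: both fields lack a default (so both paths are required) and each carries a `help` entry in its field metadata, as the `__dataclass_fields__` dump above shows. The sketch below re-declares an equivalent class with the standard `dataclasses` module and reads the help strings back with `dataclasses.fields`. It is a local stand-in (the `TrainingPipelineArguments` base is omitted), `help_texts` is a hypothetical helper, and the CSV file names are placeholders.

```python
from dataclasses import dataclass, field, fields
from typing import Dict


@dataclass
class MosesDataArguments:
    """Stand-in mirroring the documented fields (TrainingPipelineArguments base omitted)."""

    train_load: str = field(
        metadata={"help": "Input data in csv format used for training."}
    )
    val_load: str = field(
        metadata={"help": "Input data in csv format used for validation."}
    )


def help_texts(cls: type) -> Dict[str, str]:
    """Collect the 'help' metadata of every dataclass field, keyed by field name."""
    return {f.name: f.metadata["help"] for f in fields(cls)}


# Neither field has a default, so both paths must be supplied (placeholder names).
data_args = MosesDataArguments(train_load="train.csv", val_load="valid.csv")

# Fields are reported in declaration order: train_load, then val_load.
for name, text in help_texts(MosesDataArguments).items():
    print(f"--{name}: {text}")
```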
- class MosesTrainingArguments(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu')
Bases: TrainingPipelineArguments
Arguments related to Moses trainer.
- __name__ = 'MosesTrainingArguments'
- model_save: str
- log_file: str
- config_save: str
- vocab_save: str
- save_frequency: int = 1
- seed: int = 0
- device: str = 'cpu'
- __annotations__ = {'config_save': <class 'str'>, 'device': <class 'str'>, 'log_file': <class 'str'>, 'model_save': <class 'str'>, 'save_frequency': <class 'int'>, 'seed': <class 'int'>, 'vocab_save': <class 'str'>}
- __dataclass_fields__ = {'config_save': Field(name='config_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path for the config.'}),kw_only=False,_field_type=_FIELD), 'device': Field(name='device',type=<class 'str'>,default='cpu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Device to run: 'cpu' or 'cuda:<device number>'"}),kw_only=False,_field_type=_FIELD), 'log_file': Field(name='log_file',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where to save the the logs.'}),kw_only=False,_field_type=_FIELD), 'model_save': Field(name='model_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the trained model is saved.'}),kw_only=False,_field_type=_FIELD), 'save_frequency': Field(name='save_frequency',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'How often to save the model.'}),kw_only=False,_field_type=_FIELD), 'seed': Field(name='seed',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Seed used for random number generation.'}),kw_only=False,_field_type=_FIELD), 'vocab_save': Field(name='vocab_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to save the model vocabulary.'}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Arguments related to Moses trainer.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu')
- __match_args__ = ('model_save', 'log_file', 'config_save', 'vocab_save', 'save_frequency', 'seed', 'device')
- __module__ = 'gt4sd.training_pipelines.moses.core'
- __repr__()
Return repr(self).
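Because the defaults and help strings of `MosesTrainingArguments` live in the dataclass fields, a command-line interface can be derived from them. The sketch below does this with the standard-library `argparse` module; it is an illustration under that assumption, not the parser gt4sd itself uses. `build_parser` and `parse_training_args` are hypothetical helper names, the stand-in dataclass mirrors the documented fields and defaults (base class omitted), and the file paths are placeholders.

```python
import argparse
from dataclasses import MISSING, dataclass, field, fields
from typing import List


@dataclass
class MosesTrainingArguments:
    """Stand-in mirroring the documented fields and defaults (base class omitted)."""

    model_save: str = field(metadata={"help": "Path where the trained model is saved."})
    log_file: str = field(metadata={"help": "Path where to save the logs."})
    config_save: str = field(metadata={"help": "Path for the config."})
    vocab_save: str = field(metadata={"help": "Path to save the model vocabulary."})
    save_frequency: int = field(default=1, metadata={"help": "How often to save the model."})
    seed: int = field(default=0, metadata={"help": "Seed used for random number generation."})
    device: str = field(
        default="cpu", metadata={"help": "Device to run: 'cpu' or 'cuda:<device number>'"}
    )


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper: one --<field> flag per dataclass field."""
    parser = argparse.ArgumentParser(description="Moses training arguments (sketch).")
    for f in fields(MosesTrainingArguments):
        required = f.default is MISSING  # no default, so the flag is mandatory
        parser.add_argument(
            f"--{f.name}",
            type=f.type,  # str or int, taken from the field annotation
            required=required,
            default=None if required else f.default,
            help=f.metadata["help"],
        )
    return parser


def parse_training_args(argv: List[str]) -> MosesTrainingArguments:
    """Hypothetical helper: parse argv and build the dataclass from the result."""
    namespace = build_parser().parse_args(argv)
    return MosesTrainingArguments(**vars(namespace))


args = parse_training_args(
    [
        "--model_save", "model.pt",
        "--log_file", "train.log",
        "--config_save", "config.json",
        "--vocab_save", "vocab.pt",
        "--seed", "42",
    ]
)
print(args.seed, args.save_frequency, args.device)  # 42 1 cpu
```

Omitting any of the four path flags makes `argparse` exit with an error, which matches those fields having no default; `save_frequency`, `seed` and `device` fall back to 1, 0 and `'cpu'`.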
- class MosesSavingArguments(model_path, config_path, vocab_path)
Bases: TrainingPipelineArguments
Saving arguments related to Moses trainer.
- __name__ = 'MosesSavingArguments'
- model_path: str
- config_path: str
- vocab_path: str
- __annotations__ = {'config_path': <class 'str'>, 'model_path': <class 'str'>, 'vocab_path': <class 'str'>}
- __dataclass_fields__ = {'config_path': Field(name='config_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the config is stored.'}),kw_only=False,_field_type=_FIELD), 'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the model is stored.'}),kw_only=False,_field_type=_FIELD), 'vocab_path': Field(name='vocab_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the vocab is stored.'}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Saving arguments related to PaccMann trainer.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(model_path, config_path, vocab_path)
- __match_args__ = ('model_path', 'config_path', 'vocab_path')
- __module__ = 'gt4sd.training_pipelines.moses.core'
- __repr__()
Return repr(self).
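The generated members listed above (`__eq__`, `__hash__ = None`, and an `__init__` taking the fields in declaration order) are standard `dataclasses` behaviour for `eq=True, frozen=False`. A minimal stand-in shows what that means in practice. The class is re-declared locally rather than imported from gt4sd, `is_hashable` is a hypothetical helper, and the file names are placeholders.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class MosesSavingArguments:
    """Stand-in mirroring the documented fields of MosesSavingArguments."""

    model_path: str = field(metadata={"help": "Path where the model is stored."})
    config_path: str = field(metadata={"help": "Path where the config is stored."})
    vocab_path: str = field(metadata={"help": "Path where the vocab is stored."})


def is_hashable(obj: object) -> bool:
    """Return True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


positional = MosesSavingArguments("model.pt", "config.json", "vocab.pt")
keyword = MosesSavingArguments(
    model_path="model.pt", config_path="config.json", vocab_path="vocab.pt"
)

# eq=True: the generated __eq__ compares the field values in order.
print(positional == keyword)  # True
# eq=True and frozen=False: __hash__ is set to None, so instances cannot be dict keys.
print(is_hashable(positional))  # False
# asdict() gives a plain dict, e.g. for logging or writing to JSON.
print(asdict(positional))  # {'model_path': 'model.pt', 'config_path': 'config.json', 'vocab_path': 'vocab.pt'}
```

If hashable instances were needed (for example as cache keys), `@dataclass(frozen=True)` would generate a field-based `__hash__`; the documented class uses `frozen=False`.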