gt4sd.training_pipelines.moses.core module

Moses baselines training utilities.

Summary

Classes:

MosesDataArguments

Arguments related to Moses data loading.

MosesSavingArguments

Saving arguments related to Moses trainer.

MosesTrainingArguments

Arguments related to Moses trainer.

MosesTrainingPipeline

PyTorch Lightning training pipelines.

Reference

class MosesTrainingPipeline[source]

Bases: TrainingPipeline

PyTorch Lightning training pipelines.

train(training_args, model_args, common_args)[source]

Generic training function for Moses baselines training.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • common_args (Dict[str, Any]) – common arguments passed to the configuration.

Raises

NotImplementedError – the generic trainer does not implement the pipeline.

Return type

None

__annotations__ = {}
__doc__ = 'PyTorch lightining training pipelines.'
__module__ = 'gt4sd.training_pipelines.moses.core'
class MosesDataArguments(train_load, val_load)[source]

Bases: TrainingPipelineArguments

Arguments related to Moses data loading.

__name__ = 'MosesDataArguments'
train_load: str
val_load: str
__annotations__ = {'train_load': <class 'str'>, 'val_load': <class 'str'>}
__dataclass_fields__ = {'train_load': Field(name='train_load',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Input data in csv format used for training.'}),kw_only=False,_field_type=_FIELD), 'val_load': Field(name='val_load',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Input data in csv format used for validation.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments related to Moses data loading.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(train_load, val_load)
__match_args__ = ('train_load', 'val_load')
__module__ = 'gt4sd.training_pipelines.moses.core'
__repr__()

Return repr(self).

class MosesTrainingArguments(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu')[source]

Bases: TrainingPipelineArguments

Arguments related to Moses trainer.

__name__ = 'MosesTrainingArguments'
model_save: str
log_file: str
config_save: str
vocab_save: str
save_frequency: int = 1
seed: int = 0
device: str = 'cpu'
__annotations__ = {'config_save': <class 'str'>, 'device': <class 'str'>, 'log_file': <class 'str'>, 'model_save': <class 'str'>, 'save_frequency': <class 'int'>, 'seed': <class 'int'>, 'vocab_save': <class 'str'>}
__dataclass_fields__ = {'config_save': Field(name='config_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path for the config.'}),kw_only=False,_field_type=_FIELD), 'device': Field(name='device',type=<class 'str'>,default='cpu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Device to run: 'cpu' or 'cuda:<device number>'"}),kw_only=False,_field_type=_FIELD), 'log_file': Field(name='log_file',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where to save the the logs.'}),kw_only=False,_field_type=_FIELD), 'model_save': Field(name='model_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the trained model is saved.'}),kw_only=False,_field_type=_FIELD), 'save_frequency': Field(name='save_frequency',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'How often to save the model.'}),kw_only=False,_field_type=_FIELD), 'seed': Field(name='seed',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Seed used for random number generation.'}),kw_only=False,_field_type=_FIELD), 'vocab_save': Field(name='vocab_save',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to save the model vocabulary.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments related to Moses trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(model_save, log_file, config_save, vocab_save, save_frequency=1, seed=0, device='cpu')
__match_args__ = ('model_save', 'log_file', 'config_save', 'vocab_save', 'save_frequency', 'seed', 'device')
__module__ = 'gt4sd.training_pipelines.moses.core'
__repr__()

Return repr(self).

class MosesSavingArguments(model_path, config_path, vocab_path)[source]

Bases: TrainingPipelineArguments

Saving arguments related to Moses trainer.

__name__ = 'MosesSavingArguments'
model_path: str
config_path: str
vocab_path: str
__annotations__ = {'config_path': <class 'str'>, 'model_path': <class 'str'>, 'vocab_path': <class 'str'>}
__dataclass_fields__ = {'config_path': Field(name='config_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the config is stored.'}),kw_only=False,_field_type=_FIELD), 'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the model is stored.'}),kw_only=False,_field_type=_FIELD), 'vocab_path': Field(name='vocab_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the vocab is stored.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Saving arguments related to PaccMann trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(model_path, config_path, vocab_path)
__match_args__ = ('model_path', 'config_path', 'vocab_path')
__module__ = 'gt4sd.training_pipelines.moses.core'
__repr__()

Return repr(self).