gt4sd.training_pipelines.paccmann.core module
PaccMann training utilities.
Summary
Classes:
PaccMannDataArguments | Arguments related to PaccMann data loading.
PaccMannSavingArguments | Saving arguments related to PaccMann trainer.
PaccMannTrainingArguments | Arguments related to PaccMann trainer.
PaccMannTrainingPipeline | PyTorch Lightning training pipelines.
Reference
- class PaccMannTrainingPipeline
Bases: TrainingPipeline
PyTorch Lightning training pipelines.
- train(training_args, model_args, dataset_args)
Generic training function for PaccMann training.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Raises
NotImplementedError – the generic trainer does not implement the pipeline.
- Return type
None
- __annotations__ = {}
- __doc__ = 'PyTorch Lightning training pipelines.'
- __module__ = 'gt4sd.training_pipelines.paccmann.core'
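Since the generic `train` raises `NotImplementedError`, a concrete pipeline is expected to override it. The following is a minimal sketch of that pattern, not the gt4sd implementation: `TrainingPipeline` and `PaccMannTrainingPipeline` are local stand-ins that only mirror the documented signature, and `ToyPaccMannPipeline` and `run` are hypothetical names introduced for illustration.

```python
from typing import Any, Dict


class TrainingPipeline:
    """Local stand-in for the gt4sd base class (assumption, not the real code)."""

    def train(self, **kwargs: Any) -> None:
        raise NotImplementedError


class PaccMannTrainingPipeline(TrainingPipeline):
    """Stand-in mirroring the documented signature; the generic trainer raises."""

    def train(
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        dataset_args: Dict[str, Any],
    ) -> None:
        raise NotImplementedError


class ToyPaccMannPipeline(PaccMannTrainingPipeline):
    """Hypothetical concrete pipeline that only records its inputs."""

    def __init__(self) -> None:
        self.seen: Dict[str, Dict[str, Any]] = {}

    def train(
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        dataset_args: Dict[str, Any],
    ) -> None:
        # A real pipeline would build the model and data loaders here.
        self.seen = {
            "training": dict(training_args),
            "model": dict(model_args),
            "dataset": dict(dataset_args),
        }


def run(
    pipeline: TrainingPipeline,
    training_args: Dict[str, Any],
    model_args: Dict[str, Any],
    dataset_args: Dict[str, Any],
) -> bool:
    """Return True if the pipeline trained, False if it is the generic one."""
    try:
        pipeline.train(
            training_args=training_args,
            model_args=model_args,
            dataset_args=dataset_args,
        )
    except NotImplementedError:
        return False
    return True
```

Calling `run` with the generic stand-in returns `False`, while the toy subclass returns `True` and keeps the three dictionaries it received.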
- class PaccMannTrainingArguments(model_path, training_name, checkpoint_path=None, epochs=50, batch_size=256, learning_rate=0.0005, optimizer='adam', log_interval=100, save_interval=1000, eval_interval=500)
Bases: TrainingPipelineArguments
Arguments related to PaccMann trainer.
- __name__ = 'PaccMannTrainingArguments'
- model_path: str
- training_name: str
- checkpoint_path: Optional[str] = None
- epochs: int = 50
- batch_size: int = 256
- learning_rate: float = 0.0005
- optimizer: str = 'adam'
- log_interval: int = 100
- save_interval: int = 1000
- eval_interval: int = 500
- __annotations__ = {'batch_size': <class 'int'>, 'checkpoint_path': typing.Optional[str], 'epochs': <class 'int'>, 'eval_interval': <class 'int'>, 'learning_rate': <class 'float'>, 'log_interval': <class 'int'>, 'model_path': <class 'str'>, 'optimizer': <class 'str'>, 'save_interval': <class 'int'>, 'training_name': <class 'str'>}
- __dataclass_fields__ = {
    'batch_size': Field(name='batch_size',type=<class 'int'>,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Size of the batch.'}),kw_only=False,_field_type=_FIELD),
    'checkpoint_path': Field(name='checkpoint_path',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to model checkpoint for weights initialization. Leave None if you want to train a model from scratch'}),kw_only=False,_field_type=_FIELD),
    'epochs': Field(name='epochs',type=<class 'int'>,default=50,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of epochs.'}),kw_only=False,_field_type=_FIELD),
    'eval_interval': Field(name='eval_interval',type=<class 'int'>,default=500,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between evaluation intervals.'}),kw_only=False,_field_type=_FIELD),
    'learning_rate': Field(name='learning_rate',type=<class 'float'>,default=0.0005,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Learning rate used in training.'}),kw_only=False,_field_type=_FIELD),
    'log_interval': Field(name='log_interval',type=<class 'int'>,default=100,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between log intervals.'}),kw_only=False,_field_type=_FIELD),
    'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the model artifacts are stored.'}),kw_only=False,_field_type=_FIELD),
    'optimizer': Field(name='optimizer',type=<class 'str'>,default='adam',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Optimizer used during training.'}),kw_only=False,_field_type=_FIELD),
    'save_interval': Field(name='save_interval',type=<class 'int'>,default=1000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between model save intervals.'}),kw_only=False,_field_type=_FIELD),
    'training_name': Field(name='training_name',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Name used to identify the training.'}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Arguments related to PaccMann trainer.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(model_path, training_name, checkpoint_path=None, epochs=50, batch_size=256, learning_rate=0.0005, optimizer='adam', log_interval=100, save_interval=1000, eval_interval=500)
- __match_args__ = ('model_path', 'training_name', 'checkpoint_path', 'epochs', 'batch_size', 'learning_rate', 'optimizer', 'log_interval', 'save_interval', 'eval_interval')
- __module__ = 'gt4sd.training_pipelines.paccmann.core'
- __repr__()
Return repr(self).
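Because `train` takes plain `Dict[str, Any]` arguments, a dataclass like this one can be turned into a dictionary before the call. The sketch below assumes a local stand-in, `TrainingArgumentsSketch` (a hypothetical name), that copies the fields, defaults, and help strings listed above; it does not import gt4sd, and whether the library converts arguments this way is an assumption.

```python
from dataclasses import asdict, dataclass, field, fields
from typing import Optional


@dataclass
class TrainingArgumentsSketch:
    """Stand-in with the documented fields and defaults (not the gt4sd class)."""

    model_path: str = field(
        metadata={"help": "Path where the model artifacts are stored."}
    )
    training_name: str = field(
        metadata={"help": "Name used to identify the training."}
    )
    checkpoint_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to model checkpoint for weights initialization."},
    )
    epochs: int = field(default=50, metadata={"help": "Number of epochs."})
    batch_size: int = field(default=256, metadata={"help": "Size of the batch."})
    learning_rate: float = field(
        default=0.0005, metadata={"help": "Learning rate used in training."}
    )
    optimizer: str = field(
        default="adam", metadata={"help": "Optimizer used during training."}
    )
    log_interval: int = field(
        default=100, metadata={"help": "Number of steps between log intervals."}
    )
    save_interval: int = field(
        default=1000,
        metadata={"help": "Number of steps between model save intervals."},
    )
    eval_interval: int = field(
        default=500,
        metadata={"help": "Number of steps between evaluation intervals."},
    )


# Only the two fields without defaults are required; the rest can be overridden.
args = TrainingArgumentsSketch(
    model_path="/tmp/paccmann", training_name="demo", epochs=5
)

# Plain dictionary, in the shape the documented train() signature expects.
training_args = asdict(args)

# The help strings live in each field's metadata mapping.
help_by_field = {f.name: f.metadata["help"] for f in fields(args)}
```

Here `training_args` keeps the overridden `epochs` next to the untouched defaults, and `help_by_field` exposes the same help text that appears in `__dataclass_fields__`.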
- class PaccMannDataArguments(train_smiles_filepath, test_smiles_filepath, smiles_language_filepath='none', add_start_stop_token=True, selfies=True, num_workers=0, pin_memory=False, augment_smiles=False, canonical=False, kekulize=False, all_bonds_explicit=False, all_hs_explicit=False, remove_bonddir=False, remove_chirality=False)
Bases: TrainingPipelineArguments
Arguments related to PaccMann data loading.
- __name__ = 'PaccMannDataArguments'
- train_smiles_filepath: str
- test_smiles_filepath: str
- smiles_language_filepath: str = 'none'
- add_start_stop_token: bool = True
- selfies: bool = True
- num_workers: int = 0
- pin_memory: bool = False
- augment_smiles: bool = False
- canonical: bool = False
- kekulize: bool = False
- all_bonds_explicit: bool = False
- all_hs_explicit: bool = False
- remove_bonddir: bool = False
- remove_chirality: bool = False
- __annotations__ = {'add_start_stop_token': <class 'bool'>, 'all_bonds_explicit': <class 'bool'>, 'all_hs_explicit': <class 'bool'>, 'augment_smiles': <class 'bool'>, 'canonical': <class 'bool'>, 'kekulize': <class 'bool'>, 'num_workers': <class 'int'>, 'pin_memory': <class 'bool'>, 'remove_bonddir': <class 'bool'>, 'remove_chirality': <class 'bool'>, 'selfies': <class 'bool'>, 'smiles_language_filepath': <class 'str'>, 'test_smiles_filepath': <class 'str'>, 'train_smiles_filepath': <class 'str'>}
- __dataclass_fields__ = {
    'add_start_stop_token': Field(name='add_start_stop_token',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether start and stop token should be added.'}),kw_only=False,_field_type=_FIELD),
    'all_bonds_explicit': Field(name='all_bonds_explicit',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether all bonds are explicit.'}),kw_only=False,_field_type=_FIELD),
    'all_hs_explicit': Field(name='all_hs_explicit',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether all hydrogens are explicit.'}),kw_only=False,_field_type=_FIELD),
    'augment_smiles': Field(name='augment_smiles',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SMILES augmentation is used.'}),kw_only=False,_field_type=_FIELD),
    'canonical': Field(name='canonical',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SMILES canonicalization is used.'}),kw_only=False,_field_type=_FIELD),
    'kekulize': Field(name='kekulize',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SMILES kekulization is used.'}),kw_only=False,_field_type=_FIELD),
    'num_workers': Field(name='num_workers',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of workers used in data loading.'}),kw_only=False,_field_type=_FIELD),
    'pin_memory': Field(name='pin_memory',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether memory in the data loader is pinned.'}),kw_only=False,_field_type=_FIELD),
    'remove_bonddir': Field(name='remove_bonddir',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Remove bond directionality.'}),kw_only=False,_field_type=_FIELD),
    'remove_chirality': Field(name='remove_chirality',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Remove chirality.'}),kw_only=False,_field_type=_FIELD),
    'selfies': Field(name='selfies',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SELFIES representations are used.'}),kw_only=False,_field_type=_FIELD),
    'smiles_language_filepath': Field(name='smiles_language_filepath',type=<class 'str'>,default='none',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Optional SMILES language file.'}),kw_only=False,_field_type=_FIELD),
    'test_smiles_filepath': Field(name='test_smiles_filepath',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Testing file containing SMILES in .smi format.'}),kw_only=False,_field_type=_FIELD),
    'train_smiles_filepath': Field(name='train_smiles_filepath',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Training file containing SMILES in .smi format.'}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Arguments related to PaccMann data loading.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(train_smiles_filepath, test_smiles_filepath, smiles_language_filepath='none', add_start_stop_token=True, selfies=True, num_workers=0, pin_memory=False, augment_smiles=False, canonical=False, kekulize=False, all_bonds_explicit=False, all_hs_explicit=False, remove_bonddir=False, remove_chirality=False)
- __match_args__ = ('train_smiles_filepath', 'test_smiles_filepath', 'smiles_language_filepath', 'add_start_stop_token', 'selfies', 'num_workers', 'pin_memory', 'augment_smiles', 'canonical', 'kekulize', 'all_bonds_explicit', 'all_hs_explicit', 'remove_bonddir', 'remove_chirality')
- __module__ = 'gt4sd.training_pipelines.paccmann.core'
- __repr__()
Return repr(self).
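The `help` entries stored in each field's metadata suggest that these dataclasses are meant to back a command-line interface. One way such metadata can drive a parser is sketched below with the standard library only. This is an assumption about usage, not the gt4sd parser: `DataArgumentsSketch`, `str_to_bool`, and `build_parser` are hypothetical names, and the stand-in carries only a subset of the documented fields.

```python
import argparse
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class DataArgumentsSketch:
    """Stand-in with a subset of the documented fields (not the gt4sd class)."""

    train_smiles_filepath: str = field(
        metadata={"help": "Training file containing SMILES in .smi format."}
    )
    test_smiles_filepath: str = field(
        metadata={"help": "Testing file containing SMILES in .smi format."}
    )
    selfies: bool = field(
        default=True, metadata={"help": "Whether SELFIES representations are used."}
    )
    num_workers: int = field(
        default=0, metadata={"help": "Number of workers used in data loading."}
    )
    augment_smiles: bool = field(
        default=False, metadata={"help": "Whether SMILES augmentation is used."}
    )


def str_to_bool(value: str) -> bool:
    """Parse a command-line string into a boolean."""
    lowered = value.lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser(cls) -> argparse.ArgumentParser:
    """Create one --option per dataclass field, reusing its help metadata."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        # f.type is a real class here because annotations are not stringified.
        kwargs = {
            "help": f.metadata.get("help", ""),
            "type": str_to_bool if f.type is bool else f.type,
        }
        if f.default is MISSING:
            kwargs["required"] = True
        else:
            kwargs["default"] = f.default
        parser.add_argument(f"--{f.name}", **kwargs)
    return parser


namespace = build_parser(DataArgumentsSketch).parse_args(
    [
        "--train_smiles_filepath", "train.smi",
        "--test_smiles_filepath", "test.smi",
        "--selfies", "false",
    ]
)
data_args = DataArgumentsSketch(**vars(namespace))
```

Fields without a default become required options, and every other option falls back to the dataclass default, so `data_args.num_workers` stays `0` while `data_args.selfies` is switched off from the command line.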
- class PaccMannSavingArguments(model_path, training_name)
Bases: TrainingPipelineArguments
Saving arguments related to PaccMann trainer.
- __name__ = 'PaccMannSavingArguments'
- model_path: str
- training_name: str
- __annotations__ = {'model_path': <class 'str'>, 'training_name': <class 'str'>}
- __dataclass_fields__ = {
    'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the model artifacts are stored.'}),kw_only=False,_field_type=_FIELD),
    'training_name': Field(name='training_name',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Name used to identify the training.'}),kw_only=False,_field_type=_FIELD)}
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Saving arguments related to PaccMann trainer.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(model_path, training_name)
- __match_args__ = ('model_path', 'training_name')
- __module__ = 'gt4sd.training_pipelines.paccmann.core'
- __repr__()
Return repr(self).
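The `__dataclass_params__` shown above (`eq=True`, `unsafe_hash=False`, `frozen=False`) explain why `__eq__` is generated and `__hash__` is `None`: instances compare by field values and cannot be hashed. A small sketch with a local stand-in, `SavingArgumentsSketch` (a hypothetical name that uses the same default dataclass parameters), illustrates the behaviour.

```python
from dataclasses import dataclass


@dataclass  # defaults: eq=True, unsafe_hash=False, frozen=False
class SavingArgumentsSketch:
    """Stand-in with the two documented fields (not the gt4sd class)."""

    model_path: str
    training_name: str


first = SavingArgumentsSketch("/tmp/paccmann", "demo")
second = SavingArgumentsSketch("/tmp/paccmann", "demo")

# The generated __eq__ compares field values, not object identity.
same_values = first == second

# With eq=True and frozen=False, __hash__ is set to None.
try:
    hash(first)
    hashable = True
except TypeError:
    hashable = False

# The generated __repr__ lists the fields in declaration order.
text = repr(first)
```

In practice this means such argument objects should not be used as dictionary keys or set members; a tuple of their field values is one possible substitute.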