gt4sd.training_pipelines.paccmann.core module

PaccMann training utilities.

Summary

Classes:

PaccMannDataArguments

Arguments related to PaccMann data loading.

PaccMannSavingArguments

Saving arguments related to PaccMann trainer.

PaccMannTrainingArguments

Arguments related to PaccMann trainer.

PaccMannTrainingPipeline

PyTorch Lightning training pipelines.

Reference

class PaccMannTrainingPipeline[source]

Bases: TrainingPipeline

PyTorch Lightning training pipelines.

train(training_args, model_args, dataset_args)[source]

Generic training function for PaccMann training.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Raises

NotImplementedError – the generic trainer does not implement the pipeline.

Return type

None

__annotations__ = {}
__doc__ = 'PyTorch Lightning training pipelines.'
__module__ = 'gt4sd.training_pipelines.paccmann.core'
class PaccMannTrainingArguments(model_path, training_name, checkpoint_path=None, epochs=50, batch_size=256, learning_rate=0.0005, optimizer='adam', log_interval=100, save_interval=1000, eval_interval=500)[source]

Bases: TrainingPipelineArguments

Arguments related to PaccMann trainer.

__name__ = 'PaccMannTrainingArguments'
model_path: str
training_name: str
checkpoint_path: Optional[str] = None
epochs: int = 50
batch_size: int = 256
learning_rate: float = 0.0005
optimizer: str = 'adam'
log_interval: int = 100
save_interval: int = 1000
eval_interval: int = 500
__annotations__ = {'batch_size': <class 'int'>, 'checkpoint_path': typing.Optional[str], 'epochs': <class 'int'>, 'eval_interval': <class 'int'>, 'learning_rate': <class 'float'>, 'log_interval': <class 'int'>, 'model_path': <class 'str'>, 'optimizer': <class 'str'>, 'save_interval': <class 'int'>, 'training_name': <class 'str'>}
__dataclass_fields__ = {'batch_size': Field(name='batch_size',type=<class 'int'>,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Size of the batch.'}),kw_only=False,_field_type=_FIELD), 'checkpoint_path': Field(name='checkpoint_path',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to model checkpoint for weights initialization. Leave None if you want to train a model from scratch'}),kw_only=False,_field_type=_FIELD), 'epochs': Field(name='epochs',type=<class 'int'>,default=50,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of epochs.'}),kw_only=False,_field_type=_FIELD), 'eval_interval': Field(name='eval_interval',type=<class 'int'>,default=500,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between evaluation intervals.'}),kw_only=False,_field_type=_FIELD), 'learning_rate': Field(name='learning_rate',type=<class 'float'>,default=0.0005,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Learning rate used in training.'}),kw_only=False,_field_type=_FIELD), 'log_interval': Field(name='log_interval',type=<class 'int'>,default=100,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between log intervals.'}),kw_only=False,_field_type=_FIELD), 'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the model artifacts are stored.'}),kw_only=False,_field_type=_FIELD), 'optimizer': Field(name='optimizer',type=<class 'str'>,default='adam',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Optimizer used during training.'}),kw_only=False,_field_type=_FIELD), 'save_interval': Field(name='save_interval',type=<class 'int'>,default=1000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between model save intervals.'}),kw_only=False,_field_type=_FIELD), 'training_name': Field(name='training_name',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Name used to identify the training.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments related to PaccMann trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(model_path, training_name, checkpoint_path=None, epochs=50, batch_size=256, learning_rate=0.0005, optimizer='adam', log_interval=100, save_interval=1000, eval_interval=500)
__match_args__ = ('model_path', 'training_name', 'checkpoint_path', 'epochs', 'batch_size', 'learning_rate', 'optimizer', 'log_interval', 'save_interval', 'eval_interval')
__module__ = 'gt4sd.training_pipelines.paccmann.core'
__repr__()

Return repr(self).

class PaccMannDataArguments(train_smiles_filepath, test_smiles_filepath, smiles_language_filepath='none', add_start_stop_token=True, selfies=True, num_workers=0, pin_memory=False, augment_smiles=False, canonical=False, kekulize=False, all_bonds_explicit=False, all_hs_explicit=False, remove_bonddir=False, remove_chirality=False)[source]

Bases: TrainingPipelineArguments

Arguments related to PaccMann data loading.

__name__ = 'PaccMannDataArguments'
train_smiles_filepath: str
test_smiles_filepath: str
smiles_language_filepath: str = 'none'
add_start_stop_token: bool = True
selfies: bool = True
num_workers: int = 0
pin_memory: bool = False
augment_smiles: bool = False
canonical: bool = False
kekulize: bool = False
all_bonds_explicit: bool = False
all_hs_explicit: bool = False
remove_bonddir: bool = False
remove_chirality: bool = False
__annotations__ = {'add_start_stop_token': <class 'bool'>, 'all_bonds_explicit': <class 'bool'>, 'all_hs_explicit': <class 'bool'>, 'augment_smiles': <class 'bool'>, 'canonical': <class 'bool'>, 'kekulize': <class 'bool'>, 'num_workers': <class 'int'>, 'pin_memory': <class 'bool'>, 'remove_bonddir': <class 'bool'>, 'remove_chirality': <class 'bool'>, 'selfies': <class 'bool'>, 'smiles_language_filepath': <class 'str'>, 'test_smiles_filepath': <class 'str'>, 'train_smiles_filepath': <class 'str'>}
__dataclass_fields__ = {'add_start_stop_token': Field(name='add_start_stop_token',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether start and stop token should be added.'}),kw_only=False,_field_type=_FIELD), 'all_bonds_explicit': Field(name='all_bonds_explicit',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether all bonds are explicit.'}),kw_only=False,_field_type=_FIELD), 'all_hs_explicit': Field(name='all_hs_explicit',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether all hydrogens are explicit.'}),kw_only=False,_field_type=_FIELD), 'augment_smiles': Field(name='augment_smiles',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SMILES augmentation is used.'}),kw_only=False,_field_type=_FIELD), 'canonical': Field(name='canonical',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SMILES canonicalization is used.'}),kw_only=False,_field_type=_FIELD), 'kekulize': Field(name='kekulize',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SMILES kekulization is used.'}),kw_only=False,_field_type=_FIELD), 'num_workers': Field(name='num_workers',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of workers used in data loading.'}),kw_only=False,_field_type=_FIELD), 'pin_memory': Field(name='pin_memory',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether memory in the data loader is pinned.'}),kw_only=False,_field_type=_FIELD), 'remove_bonddir': Field(name='remove_bonddir',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Remove bond directionality.'}),kw_only=False,_field_type=_FIELD), 'remove_chirality': Field(name='remove_chirality',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Remove chirality.'}),kw_only=False,_field_type=_FIELD), 'selfies': Field(name='selfies',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether SELFIES representations are used.'}),kw_only=False,_field_type=_FIELD), 'smiles_language_filepath': Field(name='smiles_language_filepath',type=<class 'str'>,default='none',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Optional SMILES language file.'}),kw_only=False,_field_type=_FIELD), 'test_smiles_filepath': Field(name='test_smiles_filepath',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Testing file containing SMILES in .smi format.'}),kw_only=False,_field_type=_FIELD), 'train_smiles_filepath': Field(name='train_smiles_filepath',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Training file containing SMILES in .smi format.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments related to PaccMann data loading.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(train_smiles_filepath, test_smiles_filepath, smiles_language_filepath='none', add_start_stop_token=True, selfies=True, num_workers=0, pin_memory=False, augment_smiles=False, canonical=False, kekulize=False, all_bonds_explicit=False, all_hs_explicit=False, remove_bonddir=False, remove_chirality=False)
__match_args__ = ('train_smiles_filepath', 'test_smiles_filepath', 'smiles_language_filepath', 'add_start_stop_token', 'selfies', 'num_workers', 'pin_memory', 'augment_smiles', 'canonical', 'kekulize', 'all_bonds_explicit', 'all_hs_explicit', 'remove_bonddir', 'remove_chirality')
__module__ = 'gt4sd.training_pipelines.paccmann.core'
__repr__()

Return repr(self).

class PaccMannSavingArguments(model_path, training_name)[source]

Bases: TrainingPipelineArguments

Saving arguments related to PaccMann trainer.

__name__ = 'PaccMannSavingArguments'
model_path: str
training_name: str
__annotations__ = {'model_path': <class 'str'>, 'training_name': <class 'str'>}
__dataclass_fields__ = {'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where the model artifacts are stored.'}),kw_only=False,_field_type=_FIELD), 'training_name': Field(name='training_name',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Name used to identify the training.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Saving arguments related to PaccMann trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(model_path, training_name)
__match_args__ = ('model_path', 'training_name')
__module__ = 'gt4sd.training_pipelines.paccmann.core'
__repr__()

Return repr(self).