gt4sd.training_pipelines.torchdrug.core module

TorchDrug training utilities.

Summary

Classes:

TorchDrugDataArguments

Arguments related to TorchDrug data loading.

TorchDrugSavingArguments

Saving arguments related to the TorchDrug trainer.

TorchDrugTrainingArguments

Arguments related to the TorchDrug trainer.

TorchDrugTrainingPipeline

TorchDrug training pipelines.

Reference

class TorchDrugTrainingPipeline[source]

Bases: TrainingPipeline

TorchDrug training pipelines.

train(training_args, model_args, dataset_args)[source]

Generic training function for launching a TorchDrug training run.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Raises

NotImplementedError – the generic trainer does not implement the pipeline.

Return type

None
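
Since `train` on the base class raises NotImplementedError, a concrete TorchDrug pipeline subclass is used in practice. Below is a minimal sketch of the call pattern with placeholder argument values and a hypothetical `pipeline` instance; the concrete subclass and the model argument keys are assumptions, not part of this page:

    from typing import Any, Dict

    # Placeholder argument dictionaries; the expected keys mirror the
    # TorchDrug*Arguments dataclasses documented on this page.
    training_args: Dict[str, Any] = {
        "model_path": "/tmp/torchdrug-models",  # placeholder output directory
        "training_name": "demo",                # placeholder run identifier
        "epochs": 5,
    }
    model_args: Dict[str, Any] = {}  # model-specific keys come from a separate dataclass
    dataset_args: Dict[str, Any] = {"dataset_name": "qm9"}

    # `pipeline` would be an instance of a concrete subclass of TorchDrugTrainingPipeline;
    # calling train on TorchDrugTrainingPipeline itself raises NotImplementedError.
    # pipeline.train(
    #     training_args=training_args,
    #     model_args=model_args,
    #     dataset_args=dataset_args,
    # )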

class TorchDrugTrainingArguments(model_path, training_name, epochs=10, batch_size=16, learning_rate=1e-05, log_interval=100, gradient_interval=1, num_worker=0, task=None)[source]

Bases: TrainingPipelineArguments

Arguments related to the TorchDrug trainer.

model_path: str
    Path where the model artifacts are stored.

training_name: str
    Name used to identify the training.

epochs: int = 10
    Number of epochs.

batch_size: int = 16
    Size of the batch.

learning_rate: float = 1e-05
    Learning rate used in training.

log_interval: int = 100
    Number of steps between log intervals.

gradient_interval: int = 1
    Gradient accumulation steps.

num_worker: int = 0
    Number of CPU workers per GPU.

task: Optional[str] = None
    Optimization task for goal-driven generation. Currently, TorchDrug only supports `plogp` and `qed`.
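
A minimal sketch of configuring the trainer arguments and turning them into the plain dictionary that `train` expects; the paths and values below are placeholders:

    from dataclasses import asdict

    from gt4sd.training_pipelines.torchdrug.core import TorchDrugTrainingArguments

    # model_path and training_name have no defaults and must be provided.
    args = TorchDrugTrainingArguments(
        model_path="/tmp/torchdrug-models",  # placeholder output directory
        training_name="gcpn-qed-demo",       # placeholder run identifier
        epochs=20,
        batch_size=32,
        learning_rate=1e-4,
        task="qed",                          # goal-driven generation target
    )
    training_args = asdict(args)  # plain dict, as consumed by the pipeline's train()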

class TorchDrugDataArguments(dataset_name, file_path='', dataset_path='/home/runner/.gt4sd/data/torchdrug', target_field='', smiles_field='smiles', transform='lambda x: x', verbose=1, lazy=False, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, no_kekulization=False)[source]

Bases: TrainingPipelineArguments

Arguments related to TorchDrug data loading.

dataset_name: str
    Identifier for the dataset. Has to be one of `bace`, `bbbp`, `custom`, `cep`, `chembl`, `clintox`, `delaney`, `freesolv`, `hiv`, `lipophilicity`, `malaria`, `moses`, `muv`, `opv`, `pcqm4m`, `pubchem`, `qm8`, `qm9`, `sider`, `tox21`, `toxcast`, `zinc250k` or `zinc2m`. Can either point to one of the predefined TorchDrug datasets or be `custom` if the user brings their own dataset. If `custom`, the arguments `file_path`, `target_field` and `smiles_field` have to be specified.

file_path: str = ''
    Ignored unless `dataset_name` is `custom`. In that case it is the path to a .csv file containing the training data.

dataset_path: str = '/home/runner/.gt4sd/data/torchdrug'
    Path where the TorchDrug dataset will be stored. Ignored if `dataset_name` is `custom`.

target_field: str = ''
    Ignored unless `dataset_name` is `custom`. In that case it is the name of the column containing the property to optimize. Currently TorchDrug only supports `plogp` and `qed`.

smiles_field: str = 'smiles'
    Ignored unless `dataset_name` is `custom`. In that case it is the name of the column containing the SMILES strings.

transform: str = 'lambda x: x'
    Optional data transformation function. Has to be a lambda function (written as a string) that operates on the batch dictionary. See the TorchDrug docs for details.

verbose: int = 1
    Output verbosity level for the dataset.

lazy: bool = False
    If set, molecules are processed in the dataloader. This is faster for setup but slower at training time.

node_feature: str = 'default'
    Node features (or node feature list) to extract.

edge_feature: str = 'default'
    Edge features (or edge feature list) to extract.

graph_feature: Optional[str] = None
    Graph features (or graph feature list) to extract.

with_hydrogen: bool = False
    Whether hydrogens are stored in the molecular graph.

no_kekulization: bool = False
    If set, SMILES kekulization is skipped. By default, kekulization is applied.
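
For a user-provided dataset, `dataset_name='custom'` activates the `file_path`, `target_field` and `smiles_field` fields. A minimal sketch, where the CSV path and column values are hypothetical:

    from dataclasses import asdict

    from gt4sd.training_pipelines.torchdrug.core import TorchDrugDataArguments

    dataset_args = asdict(
        TorchDrugDataArguments(
            dataset_name="custom",
            file_path="/tmp/my_molecules.csv",  # hypothetical CSV with one molecule per row
            smiles_field="smiles",              # column holding the SMILES strings
            target_field="qed",                 # property column; TorchDrug supports `plogp` and `qed`
            transform="lambda x: x",            # identity transform, passed as a string
        )
    )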

class TorchDrugSavingArguments(model_path, training_name, dataset_name, task=None, file_path='', epochs=10)[source]

Bases: TrainingPipelineArguments

Saving arguments related to the TorchDrug trainer.

model_path: str
    Path where the model artifacts are stored.

training_name: str
    Name used to identify the training.

dataset_name: str
    Identifier for the dataset (see TorchDrugDataArguments for the supported values).

task: Optional[str] = None
    Optimization task for goal-driven generation. Currently, TorchDrug only supports `plogp` and `qed`.

file_path: str = ''
    Ignored unless `dataset_name` is `custom`. In that case it is the path to a .csv file containing the training data.

epochs: int = 10
    Number of epochs.
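
The saving arguments mirror a subset of the training and data arguments so that the saving step knows where, and under which name, artifacts should be stored. A minimal sketch with placeholder values matching the earlier examples:

    from dataclasses import asdict

    from gt4sd.training_pipelines.torchdrug.core import TorchDrugSavingArguments

    saving_args = asdict(
        TorchDrugSavingArguments(
            model_path="/tmp/torchdrug-models",  # same path used for training
            training_name="gcpn-qed-demo",       # same run identifier used for training
            dataset_name="qm9",
            task="qed",
            epochs=20,
        )
    )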