gt4sd.training_pipelines.torchdrug.core module¶
TorchDrug training utilities.
Summary¶
Classes:
- TorchDrugDataArguments – Arguments related to TorchDrug data loading.
- TorchDrugSavingArguments – Saving arguments related to the TorchDrug trainer.
- TorchDrugTrainingArguments – Arguments related to the TorchDrug trainer.
- TorchDrugTrainingPipeline – TorchDrug training pipelines.
Reference¶
- class TorchDrugTrainingPipeline[source]¶
Bases:
TrainingPipeline
TorchDrug training pipelines.
- train(training_args, model_args, dataset_args)[source]¶
Generic training function for launching a TorchDrug training.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Raises
NotImplementedError – the generic trainer does not implement the pipeline.
- Return type
None
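The generic `train` above only defines the interface; concrete pipelines are expected to override it. A minimal sketch of that pattern, using stand-in classes rather than the actual gt4sd implementations:

```python
from typing import Any, Dict


class TrainingPipeline:
    """Stand-in for gt4sd's TrainingPipeline base class (illustrative only)."""

    def train(
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        dataset_args: Dict[str, Any],
    ) -> None:
        # The generic pipeline only defines the interface and raises.
        raise NotImplementedError("Can't train an abstract pipeline.")


class MyTorchDrugPipeline(TrainingPipeline):
    """Hypothetical subclass providing a concrete training routine."""

    def train(self, training_args, model_args, dataset_args) -> None:
        # A real implementation would build the TorchDrug task, optimizer
        # and solver from the three argument dictionaries here.
        print(f"training '{training_args['training_name']}'")


pipeline = MyTorchDrugPipeline()
pipeline.train({"training_name": "demo"}, {}, {})
```

All three argument sets are passed as plain `Dict[str, Any]`, so they can come from dataclasses, CLI parsing, or config files alike.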
- class TorchDrugTrainingArguments(model_path, training_name, epochs=10, batch_size=16, learning_rate=1e-05, log_interval=100, gradient_interval=1, num_worker=0, task=None)[source]¶
Bases:
TrainingPipelineArguments
Arguments related to the TorchDrug trainer.
- model_path: str¶
Path where the model artifacts are stored.
- training_name: str¶
Name used to identify the training.
- epochs: int = 10¶
Number of epochs.
- batch_size: int = 16¶
Size of the batch.
- learning_rate: float = 1e-05¶
Learning rate used in training.
- log_interval: int = 100¶
Number of steps between log intervals.
- gradient_interval: int = 1¶
Gradient accumulation steps.
- num_worker: int = 0¶
Number of CPU workers per GPU.
- task: Optional[str] = None¶
Optimization task for goal-driven generation. Currently, TorchDrug only supports `plogp` and `qed`.
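Since `train` consumes plain dictionaries, a typical pattern is to build the arguments as a dataclass and convert with `dataclasses.asdict`. The sketch below uses a local replica of the documented fields for illustration; the real class lives in `gt4sd.training_pipelines.torchdrug.core`:

```python
from dataclasses import asdict, dataclass
from typing import Optional


# Local replica of the documented TorchDrugTrainingArguments fields
# (illustrative; not the actual gt4sd class).
@dataclass
class TorchDrugTrainingArguments:
    model_path: str
    training_name: str
    epochs: int = 10
    batch_size: int = 16
    learning_rate: float = 1e-05
    log_interval: int = 100
    gradient_interval: int = 1
    num_worker: int = 0
    task: Optional[str] = None


args = TorchDrugTrainingArguments(
    model_path="/tmp/models", training_name="demo-run", task="qed"
)
training_args = asdict(args)  # plain Dict[str, Any], as train() expects
print(training_args["epochs"])  # default survives: 10
print(training_args["task"])    # qed
```

Only `model_path` and `training_name` are required; everything else has the defaults shown in the signature above.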
- class TorchDrugDataArguments(dataset_name, file_path='', dataset_path='/home/runner/.gt4sd/data/torchdrug', target_field='', smiles_field='smiles', transform='lambda x: x', verbose=1, lazy=False, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, no_kekulization=False)[source]¶
Bases:
TrainingPipelineArguments
Arguments related to TorchDrug data loading.
- dataset_name: str¶
Identifier for the dataset. Has to be one of: 'bace', 'bbbp', 'cep', 'chembl', 'clintox', 'custom', 'delaney', 'freesolv', 'hiv', 'lipophilicity', 'malaria', 'moses', 'muv', 'opv', 'pcqm4m', 'pubchem', 'qm8', 'qm9', 'sider', 'tox21', 'toxcast', 'zinc250k', 'zinc2m'. Can either point to one of the predefined TorchDrug datasets or be `custom` if the user brings their own dataset. If `custom`, the arguments `file_path`, `target_field` and `smiles_field` have to be specified.
- file_path: str = ''¶
Ignored unless `dataset_name` is `custom`. In that case, a path to a .csv file containing the training data.
- dataset_path: str = '/home/runner/.gt4sd/data/torchdrug'¶
Path where the TorchDrug dataset will be stored. Ignored if `dataset_name` is `custom`.
- target_field: str = ''¶
Ignored unless `dataset_name` is `custom`. In that case, the name of the column containing the property to optimize. Currently, TorchDrug only supports `plogp` and `qed`.
- smiles_field: str = 'smiles'¶
Ignored unless `dataset_name` is `custom`. In that case, the name of the column containing the SMILES strings.
- transform: str = 'lambda x: x'¶
Optional data transformation function. Has to be a lambda function (written as a string) that operates on the batch dictionary. See the TorchDrug docs for details.
- verbose: int = 1¶
Output verbosity level for the dataset.
- lazy: bool = False¶
If set, molecules are processed in the dataloader. This is faster for setup but slower at training time.
- node_feature: str = 'default'¶
Node features (or node feature list) to extract.
- edge_feature: str = 'default'¶
Edge features (or edge feature list) to extract.
- graph_feature: Optional[str] = None¶
Graph features (or graph feature list) to extract.
- with_hydrogen: bool = False¶
Whether hydrogens are stored in the molecular graph.
- no_kekulization: bool = False¶
Whether SMILES kekulization is skipped. By default, kekulization is used.
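The custom-dataset rule above (a `custom` dataset requires `file_path`, `target_field` and `smiles_field`) and the string-valued `transform` can be sketched as follows. The validation helper is hypothetical, not part of gt4sd's API, and the `eval` of the transform string is one plausible way such a string is turned into a callable:

```python
# Hypothetical helper enforcing the documented consistency rule.
def validate_data_args(dataset_name: str, file_path: str,
                       target_field: str, smiles_field: str) -> None:
    if dataset_name == "custom":
        missing = [name for name, value in [
            ("file_path", file_path),
            ("target_field", target_field),
            ("smiles_field", smiles_field),
        ] if not value]
        if missing:
            raise ValueError(f"custom dataset requires: {', '.join(missing)}")


# A predefined dataset needs no extra fields.
validate_data_args("qm9", "", "", "")

# A custom dataset must name its CSV file and its columns.
validate_data_args("custom", "data.csv", "qed", "smiles")

# The transform is a lambda written as a string; evaluating it yields
# the callable applied to the batch dictionary.
transform = eval("lambda x: x")
assert transform({"smiles": "CCO"}) == {"smiles": "CCO"}
```

Passing the transform as a string keeps the whole configuration serializable (e.g. for CLI or JSON configs) at the cost of evaluating code from the config.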
- class TorchDrugSavingArguments(model_path, training_name, dataset_name, task=None, file_path='', epochs=10)[source]¶
Bases:
TrainingPipelineArguments
Saving arguments related to the TorchDrug trainer.
- model_path: str¶
Path where the model artifacts are stored.
- training_name: str¶
Name used to identify the training.
- dataset_name: str¶
Identifier for the dataset. Can either point to one of the predefined TorchDrug datasets or be `custom` if the user brings their own dataset.
- task: Optional[str] = None¶
Optimization task for goal-driven generation. Currently, TorchDrug only supports `plogp` and `qed`.
- file_path: str = ''¶
Ignored unless `dataset_name` is `custom`. In that case, a path to a .csv file containing the training data.
- epochs: int = 10¶
Number of epochs.
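Every field of the saving arguments already appears in the training or data arguments, so the saving configuration can be assembled from the other two. The merge below is illustrative (the field names come from the signatures above; the dict values are made up):

```python
# Argument dicts as they would be passed to train() (example values).
training_args = {"model_path": "/tmp/models", "training_name": "demo-run",
                 "epochs": 10, "task": "qed"}
dataset_args = {"dataset_name": "custom", "file_path": "data.csv"}

# Saving arguments reuse fields from both: model_path, training_name,
# epochs and task from the training side; dataset_name and file_path
# from the data side.
saving_args = {
    "model_path": training_args["model_path"],
    "training_name": training_args["training_name"],
    "dataset_name": dataset_args["dataset_name"],
    "task": training_args["task"],
    "file_path": dataset_args["file_path"],
    "epochs": training_args["epochs"],
}
print(sorted(saving_args))
```

The overlap means the saving step can locate and label the artifacts without re-reading any other configuration.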