gt4sd.training_pipelines.torchdrug.gcpn.core module¶
TorchDrug GCPN training utilities.
Summary¶
Classes:

TorchDrugGCPNModelArguments | Arguments pertaining to model instantiation.
TorchDrugGCPNTrainingPipeline | TorchDrug GCPN training pipelines.
Reference¶
- class TorchDrugGCPNTrainingPipeline[source]¶
Bases:
TorchDrugTrainingPipeline
TorchDrug GCPN training pipelines.
- train(training_args, model_args, dataset_args)[source]¶
- Generic training function for a Graph Convolutional Policy Network (GCPN) model. For details see:
You, J. et al. (2018). Graph convolutional policy network for goal-directed molecular graph generation. Advances in Neural Information Processing Systems, 31.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
None
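A minimal usage sketch follows. The pipeline class and the train signature are as documented above; the model_args keys match the TorchDrugGCPNModelArguments fields below, while the specific training_args and dataset_args keys are illustrative assumptions, not a confirmed configuration:

from gt4sd.training_pipelines.torchdrug.gcpn.core import TorchDrugGCPNTrainingPipeline

pipeline = TorchDrugGCPNTrainingPipeline()

# The exact keys each dict accepts are defined by the corresponding
# TrainingPipelineArguments dataclasses; the training/dataset keys here
# are hypothetical placeholders.
training_args = {"epochs": 10, "batch_size": 32}  # hypothetical keys
model_args = {"hidden_dims": "[128, 128]", "criterion": "{'nll': 1.0}"}
dataset_args = {"dataset_name": "zinc250k"}  # hypothetical key

pipeline.train(
    training_args=training_args,
    model_args=model_args,
    dataset_args=dataset_args,
)  # returns None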
- __annotations__ = {}¶
- __doc__ = 'TorchDrug GCPN training pipelines.'¶
- __module__ = 'gt4sd.training_pipelines.torchdrug.gcpn.core'¶
- class TorchDrugGCPNModelArguments(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, readout='sum', max_edge_unroll=None, max_node=None, criterion="{'nll': 1.0}", hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)[source]¶
Bases:
TrainingPipelineArguments
Arguments pertaining to model instantiation.
- __name__ = 'TorchDrugGCPNModelArguments'¶
- hidden_dims: str = '[128, 128]'¶
- batch_norm: bool = False¶
- edge_input_dim: Optional[int] = None¶
- short_cut: bool = False¶
- activation: str = 'relu'¶
- concat_hidden: bool = False¶
- readout: str = 'sum'¶
- max_edge_unroll: Optional[int] = None¶
- max_node: Optional[int] = None¶
- criterion: str = "{'nll': 1.0}"¶
- hidden_dim_mlp: int = 128¶
- agent_update_interval: int = 10¶
- gamma: float = 0.9¶
- reward_temperature: float = 1.0¶
- baseline_momentum: float = 0.9¶
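Note that hidden_dims and criterion take string-encoded Python literals (a list of layer sizes and a dict of criterion weights, respectively), as the defaults above show. A minimal instantiation sketch under that assumption, with illustrative values:

from gt4sd.training_pipelines.torchdrug.gcpn.core import TorchDrugGCPNModelArguments

# Default likelihood training: the `nll` criterion with weight 1.0.
nll_args = TorchDrugGCPNModelArguments()

# Property optimization: per the criterion help text, `ppo` is the
# criterion for goal-directed training. Values here are illustrative.
ppo_args = TorchDrugGCPNModelArguments(
    hidden_dims="[256, 256, 256]",  # string-encoded list of hidden sizes
    criterion="{'ppo': 1.0}",       # string-encoded dict of criterion weights
    reward_temperature=2.0,         # larger -> higher mean reward
)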
- __annotations__ = {'activation': <class 'str'>, 'agent_update_interval': <class 'int'>, 'baseline_momentum': <class 'float'>, 'batch_norm': <class 'bool'>, 'concat_hidden': <class 'bool'>, 'criterion': <class 'str'>, 'edge_input_dim': typing.Optional[int], 'gamma': <class 'float'>, 'hidden_dim_mlp': <class 'int'>, 'hidden_dims': <class 'str'>, 'max_edge_unroll': typing.Optional[int], 'max_node': typing.Optional[int], 'readout': <class 'str'>, 'reward_temperature': <class 'float'>, 'short_cut': <class 'bool'>}¶
- __dataclass_fields__ = {'activation': Field(name='activation',type=<class 'str'>,default='relu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Activation function for RGCN'}),kw_only=False,_field_type=_FIELD), 'agent_update_interval': Field(name='agent_update_interval',type=<class 'int'>,default=10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Update the agent every n batches (similar to gradient accumulation)'}),kw_only=False,_field_type=_FIELD), 'baseline_momentum': Field(name='baseline_momentum',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Momentum for value function baseline'}),kw_only=False,_field_type=_FIELD), 'batch_norm': Field(name='batch_norm',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RGCN uses batch normalization'}),kw_only=False,_field_type=_FIELD), 'concat_hidden': Field(name='concat_hidden',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether hidden representations from all layers are concatenated'}),kw_only=False,_field_type=_FIELD), 'criterion': Field(name='criterion',type=<class 'str'>,default="{'nll': 1.0}",default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Training criterion. Available criteria are `nll` and `ppo` for regular training and property optimization respectively. If dict, the keys are criteria and the values are the corresponding weights. If list, both criteria are used with equal weights.'}),kw_only=False,_field_type=_FIELD), 'edge_input_dim': Field(name='edge_input_dim',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dimension of edge features'}),kw_only=False,_field_type=_FIELD), 'gamma': Field(name='gamma',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Reward discount rate'}),kw_only=False,_field_type=_FIELD), 'hidden_dim_mlp': Field(name='hidden_dim_mlp',type=<class 'int'>,default=128,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Hidden size of GCPN internal MLP.'}),kw_only=False,_field_type=_FIELD), 'hidden_dims': Field(name='hidden_dims',type=<class 'str'>,default='[128, 128]',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dimensionality of each hidden layer'}),kw_only=False,_field_type=_FIELD), 'max_edge_unroll': Field(name='max_edge_unroll',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Max node id difference. Inferred from training data if not provided.'}),kw_only=False,_field_type=_FIELD), 'max_node': Field(name='max_node',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Max number of nodes. Inferred from training data if not provided.'}),kw_only=False,_field_type=_FIELD), 'readout': Field(name='readout',type=<class 'str'>,default='sum',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'RGCN readout function. Either `sum` or `mean`.'}),kw_only=False,_field_type=_FIELD), 'reward_temperature': Field(name='reward_temperature',type=<class 'float'>,default=1.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Temperature for the reward (larger -> higher mean reward, lower -> higher maximal reward)'}),kw_only=False,_field_type=_FIELD), 'short_cut': Field(name='short_cut',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RGCN uses a short cut'}),kw_only=False,_field_type=_FIELD)}¶
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)¶
- __doc__ = 'Arguments pertaining to model instantiation.'¶
- __eq__(other)¶
Return self==value.
- __hash__ = None¶
- __init__(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, readout='sum', max_edge_unroll=None, max_node=None, criterion="{'nll': 1.0}", hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)¶
- __match_args__ = ('hidden_dims', 'batch_norm', 'edge_input_dim', 'short_cut', 'activation', 'concat_hidden', 'readout', 'max_edge_unroll', 'max_node', 'criterion', 'hidden_dim_mlp', 'agent_update_interval', 'gamma', 'reward_temperature', 'baseline_momentum')¶
- __module__ = 'gt4sd.training_pipelines.torchdrug.gcpn.core'¶
- __repr__()¶
Return repr(self).