gt4sd.training_pipelines.torchdrug.gcpn.core module

TorchDrug GCPN training utilities.

Summary

Classes:

TorchDrugGCPNModelArguments

Arguments pertaining to model instantiation.

TorchDrugGCPNTrainingPipeline

TorchDrug GCPN training pipelines.

Reference

class TorchDrugGCPNTrainingPipeline[source]

Bases: TorchDrugTrainingPipeline

TorchDrug GCPN training pipelines.

train(training_args, model_args, dataset_args)[source]
Generic function for training a Graph Convolutional Policy Network
(GCPN) model. For details, see:

You, J. et al. (2018). Graph convolutional policy network for goal-directed molecular graph generation. Advances in Neural Information Processing Systems, 31.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Return type

None

__annotations__ = {}
__doc__ = 'TorchDrug GCPN training pipelines.'
__module__ = 'gt4sd.training_pipelines.torchdrug.gcpn.core'
class TorchDrugGCPNModelArguments(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, readout='sum', max_edge_unroll=None, max_node=None, criterion="{'nll': 1.0}", hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)[source]

Bases: TrainingPipelineArguments

Arguments pertaining to model instantiation.

__name__ = 'TorchDrugGCPNModelArguments'
hidden_dims: str = '[128, 128]'
batch_norm: bool = False
edge_input_dim: Optional[int] = None
short_cut: bool = False
activation: str = 'relu'
concat_hidden: bool = False
readout: str = 'sum'
max_edge_unroll: Optional[int] = None
max_node: Optional[int] = None
criterion: str = "{'nll': 1.0}"
hidden_dim_mlp: int = 128
agent_update_interval: int = 10
gamma: float = 0.9
reward_temperature: float = 1.0
baseline_momentum: float = 0.9
__annotations__ = {'activation': <class 'str'>, 'agent_update_interval': <class 'int'>, 'baseline_momentum': <class 'float'>, 'batch_norm': <class 'bool'>, 'concat_hidden': <class 'bool'>, 'criterion': <class 'str'>, 'edge_input_dim': typing.Optional[int], 'gamma': <class 'float'>, 'hidden_dim_mlp': <class 'int'>, 'hidden_dims': <class 'str'>, 'max_edge_unroll': typing.Optional[int], 'max_node': typing.Optional[int], 'readout': <class 'str'>, 'reward_temperature': <class 'float'>, 'short_cut': <class 'bool'>}
__dataclass_fields__ = {'activation': Field(name='activation',type=<class 'str'>,default='relu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Activation function for RGCN'}),kw_only=False,_field_type=_FIELD), 'agent_update_interval': Field(name='agent_update_interval',type=<class 'int'>,default=10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Update the agent every n batches (similar to gradient accumulation)'}),kw_only=False,_field_type=_FIELD), 'baseline_momentum': Field(name='baseline_momentum',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Momentum for value function baseline'}),kw_only=False,_field_type=_FIELD), 'batch_norm': Field(name='batch_norm',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RGCN uses batch normalization'}),kw_only=False,_field_type=_FIELD), 'concat_hidden': Field(name='concat_hidden',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether hidden representations from all layers are concatenated'}),kw_only=False,_field_type=_FIELD), 'criterion': Field(name='criterion',type=<class 'str'>,default="{'nll': 1.0}",default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'training criterion. Available criteria are `nll` and `ppo` for regular training and property optimization respectively. If dict, the keys are criterions and values are the corresponding weights. 
If list, both criteria are used with equal weights.'}),kw_only=False,_field_type=_FIELD), 'edge_input_dim': Field(name='edge_input_dim',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dimension of edge features'}),kw_only=False,_field_type=_FIELD), 'gamma': Field(name='gamma',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Reward discount rate'}),kw_only=False,_field_type=_FIELD), 'hidden_dim_mlp': Field(name='hidden_dim_mlp',type=<class 'int'>,default=128,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Hidden size of GCPN internal MLP.'}),kw_only=False,_field_type=_FIELD), 'hidden_dims': Field(name='hidden_dims',type=<class 'str'>,default='[128, 128]',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dimensionality of each hidden layer'}),kw_only=False,_field_type=_FIELD), 'max_edge_unroll': Field(name='max_edge_unroll',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'max node id difference. Inferred from training data if not provided'}),kw_only=False,_field_type=_FIELD), 'max_node': Field(name='max_node',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'max number of node. 
Inferred from training data if not provided.'}),kw_only=False,_field_type=_FIELD), 'readout': Field(name='readout',type=<class 'str'>,default='sum',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'RGCN Readout function. Either `sum` or `mean`'}),kw_only=False,_field_type=_FIELD), 'reward_temperature': Field(name='reward_temperature',type=<class 'float'>,default=1.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Temperature for the reward (larger -> higher mean reward)lower -> higher maximal reward'}),kw_only=False,_field_type=_FIELD), 'short_cut': Field(name='short_cut',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RGCN uses a short cut'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments pertaining to model instantiation.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, readout='sum', max_edge_unroll=None, max_node=None, criterion="{'nll': 1.0}", hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)
__match_args__ = ('hidden_dims', 'batch_norm', 'edge_input_dim', 'short_cut', 'activation', 'concat_hidden', 'readout', 'max_edge_unroll', 'max_node', 'criterion', 'hidden_dim_mlp', 'agent_update_interval', 'gamma', 'reward_temperature', 'baseline_momentum')
__module__ = 'gt4sd.training_pipelines.torchdrug.gcpn.core'
__repr__()

Return repr(self).