gt4sd.training_pipelines.torchdrug.graphaf.core module

TorchDrug GraphAF training utilities.

Summary

Classes:

TorchDrugGraphAFModelArguments

Arguments pertaining to model instantiation.

TorchDrugGraphAFTrainingPipeline

TorchDrug GraphAF training pipelines.

Reference

class TorchDrugGraphAFTrainingPipeline[source]

Bases: TorchDrugTrainingPipeline

TorchDrug GraphAF training pipelines.

train(training_args, model_args, dataset_args)[source]
Generic training function for training a GraphAF model. For details see:

Shi, Chence, et al. “GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation”. International Conference on Learning Representations (ICLR), 2020.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Return type

None

__annotations__ = {}
__doc__ = 'TorchDrug GraphAF training pipelines.'
__module__ = 'gt4sd.training_pipelines.torchdrug.graphaf.core'
class TorchDrugGraphAFModelArguments(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, num_node_flow_layers=12, num_edge_flow_layers=12, no_edge=False, readout='sum', max_edge_unroll=12, max_node=38, criterion="{'nll': 1.0}", num_node_sample=-1, num_edge_sample=-1, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)[source]

Bases: TrainingPipelineArguments

Arguments pertaining to model instantiation.

__name__ = 'TorchDrugGraphAFModelArguments'
hidden_dims: str = '[128, 128]'
batch_norm: bool = False
edge_input_dim: Optional[int] = None
short_cut: bool = False
activation: str = 'relu'
concat_hidden: bool = False
num_node_flow_layers: int = 12
num_edge_flow_layers: int = 12
no_edge: bool = False
readout: str = 'sum'
max_edge_unroll: int = 12
max_node: int = 38
criterion: str = "{'nll': 1.0}"
num_node_sample: int = -1
num_edge_sample: int = -1
agent_update_interval: int = 10
gamma: float = 0.9
reward_temperature: float = 1.0
baseline_momentum: float = 0.9
__annotations__ = {'activation': <class 'str'>, 'agent_update_interval': <class 'int'>, 'baseline_momentum': <class 'float'>, 'batch_norm': <class 'bool'>, 'concat_hidden': <class 'bool'>, 'criterion': <class 'str'>, 'edge_input_dim': typing.Optional[int], 'gamma': <class 'float'>, 'hidden_dims': <class 'str'>, 'max_edge_unroll': <class 'int'>, 'max_node': <class 'int'>, 'no_edge': <class 'bool'>, 'num_edge_flow_layers': <class 'int'>, 'num_edge_sample': <class 'int'>, 'num_node_flow_layers': <class 'int'>, 'num_node_sample': <class 'int'>, 'readout': <class 'str'>, 'reward_temperature': <class 'float'>, 'short_cut': <class 'bool'>}
__dataclass_fields__ = {'activation': Field(name='activation',type=<class 'str'>,default='relu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Activation function for RGCN'}),kw_only=False,_field_type=_FIELD), 'agent_update_interval': Field(name='agent_update_interval',type=<class 'int'>,default=10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Update the agent every n batches (similar to gradient accumulation)'}),kw_only=False,_field_type=_FIELD), 'baseline_momentum': Field(name='baseline_momentum',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Momentum for value function baseline'}),kw_only=False,_field_type=_FIELD), 'batch_norm': Field(name='batch_norm',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RGCN uses batch normalization'}),kw_only=False,_field_type=_FIELD), 'concat_hidden': Field(name='concat_hidden',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether hidden representations from all layers are concatenated'}),kw_only=False,_field_type=_FIELD), 'criterion': Field(name='criterion',type=<class 'str'>,default="{'nll': 1.0}",default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'training criterion. Available criteria are `nll` and `ppo` for regular training and property optimization respectively. If dict, the keys are criterions and values are the corresponding weights. 
If list, both criteria are used with equal weights.'}),kw_only=False,_field_type=_FIELD), 'edge_input_dim': Field(name='edge_input_dim',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dimension of edge features'}),kw_only=False,_field_type=_FIELD), 'gamma': Field(name='gamma',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Reward discount rate'}),kw_only=False,_field_type=_FIELD), 'hidden_dims': Field(name='hidden_dims',type=<class 'str'>,default='[128, 128]',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dimensionality of each hidden layer'}),kw_only=False,_field_type=_FIELD), 'max_edge_unroll': Field(name='max_edge_unroll',type=<class 'int'>,default=12,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'max node id difference. Inferred from training data if not provided'}),kw_only=False,_field_type=_FIELD), 'max_node': Field(name='max_node',type=<class 'int'>,default=38,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'max number of node. Inferred from training data if not provided.'}),kw_only=False,_field_type=_FIELD), 'no_edge': Field(name='no_edge',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to use edge features in the edge GraphAF model. 
Per default, edges are used.'}),kw_only=False,_field_type=_FIELD), 'num_edge_flow_layers': Field(name='num_edge_flow_layers',type=<class 'int'>,default=12,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of layers in the edge flow GraphAF model'}),kw_only=False,_field_type=_FIELD), 'num_edge_sample': Field(name='num_edge_sample',type=<class 'int'>,default=-1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of edge samples per graph.'}),kw_only=False,_field_type=_FIELD), 'num_node_flow_layers': Field(name='num_node_flow_layers',type=<class 'int'>,default=12,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of layers in the node flow GraphAF model'}),kw_only=False,_field_type=_FIELD), 'num_node_sample': Field(name='num_node_sample',type=<class 'int'>,default=-1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of node samples per graph.'}),kw_only=False,_field_type=_FIELD), 'readout': Field(name='readout',type=<class 'str'>,default='sum',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'RGCN Readout function. 
Either `sum` or `mean`'}),kw_only=False,_field_type=_FIELD), 'reward_temperature': Field(name='reward_temperature',type=<class 'float'>,default=1.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Temperature for the reward (larger -> higher mean reward)lower -> higher maximal reward.'}),kw_only=False,_field_type=_FIELD), 'short_cut': Field(name='short_cut',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RGCN uses a short cut'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments pertaining to model instantiation.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, num_node_flow_layers=12, num_edge_flow_layers=12, no_edge=False, readout='sum', max_edge_unroll=12, max_node=38, criterion="{'nll': 1.0}", num_node_sample=-1, num_edge_sample=-1, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)
__match_args__ = ('hidden_dims', 'batch_norm', 'edge_input_dim', 'short_cut', 'activation', 'concat_hidden', 'num_node_flow_layers', 'num_edge_flow_layers', 'no_edge', 'readout', 'max_edge_unroll', 'max_node', 'criterion', 'num_node_sample', 'num_edge_sample', 'agent_update_interval', 'gamma', 'reward_temperature', 'baseline_momentum')
__module__ = 'gt4sd.training_pipelines.torchdrug.graphaf.core'
__repr__()

Return repr(self).