gt4sd.training_pipelines.torchdrug.graphaf.core module¶
TorchDrug GraphAF training utilities.
Summary¶
Classes:

TorchDrugGraphAFModelArguments | Arguments pertaining to model instantiation.
TorchDrugGraphAFTrainingPipeline | TorchDrug GraphAF training pipelines.
Reference¶
- class TorchDrugGraphAFTrainingPipeline[source]¶
Bases:
TorchDrugTrainingPipeline
TorchDrug GraphAF training pipelines.
- train(training_args, model_args, dataset_args)[source]¶
- Generic training function for a GraphAF model. For details see:
Shi, Chence, et al. “GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation”. International Conference on Learning Representations (ICLR), 2020.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
None
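A call to `train` can be sketched as follows. The specific keys in each dict below are hypothetical placeholders (the accepted keys come from the corresponding `TrainingPipelineArguments` dataclasses), and the import path is assumed from the module name above; the actual pipeline call is shown commented out since it requires gt4sd and TorchDrug to be installed:

```python
# Hedged sketch: building the three argument dicts expected by `train`.
# The keys below are illustrative placeholders; the real accepted keys are
# defined by the corresponding argument dataclasses.
from typing import Any, Dict

training_args: Dict[str, Any] = {"epochs": 10, "batch_size": 16}  # hypothetical keys
model_args: Dict[str, Any] = {
    "hidden_dims": "[128, 128]",       # defaults taken from the class signature
    "criterion": "{'nll': 1.0}",
}
dataset_args: Dict[str, Any] = {"dataset_name": "qm9"}  # hypothetical key

# With gt4sd installed, the invocation would look like (not executed here):
# from gt4sd.training_pipelines.torchdrug.graphaf.core import (
#     TorchDrugGraphAFTrainingPipeline,
# )
# pipeline = TorchDrugGraphAFTrainingPipeline()
# pipeline.train(
#     training_args=training_args,
#     model_args=model_args,
#     dataset_args=dataset_args,
# )
```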
- class TorchDrugGraphAFModelArguments(hidden_dims='[128, 128]', batch_norm=False, edge_input_dim=None, short_cut=False, activation='relu', concat_hidden=False, num_node_flow_layers=12, num_edge_flow_layers=12, no_edge=False, readout='sum', max_edge_unroll=12, max_node=38, criterion="{'nll': 1.0}", num_node_sample=-1, num_edge_sample=-1, agent_update_interval=10, gamma=0.9, reward_temperature=1.0, baseline_momentum=0.9)[source]¶
Bases:
TrainingPipelineArguments
Arguments pertaining to model instantiation.
- hidden_dims: str = '[128, 128]'¶
Dimensionality of each hidden layer.
- batch_norm: bool = False¶
Whether the RGCN uses batch normalization.
- edge_input_dim: Optional[int] = None¶
Dimension of edge features.
- short_cut: bool = False¶
Whether the RGCN uses a short cut.
- activation: str = 'relu'¶
Activation function for the RGCN.
- concat_hidden: bool = False¶
Whether hidden representations from all layers are concatenated.
- num_node_flow_layers: int = 12¶
Number of layers in the node flow GraphAF model.
- num_edge_flow_layers: int = 12¶
Number of layers in the edge flow GraphAF model.
- no_edge: bool = False¶
Whether to skip edge features in the edge GraphAF model. By default, edges are used.
- readout: str = 'sum'¶
RGCN readout function, either `sum` or `mean`.
- max_edge_unroll: int = 12¶
Maximal node id difference. Inferred from the training data if not provided.
- max_node: int = 38¶
Maximal number of nodes. Inferred from the training data if not provided.
- criterion: str = "{'nll': 1.0}"¶
Training criterion. Available criteria are `nll` and `ppo`, for regular training and property optimization respectively. If a dict, the keys are criteria and the values are the corresponding weights. If a list, all criteria are used with equal weights.
- num_node_sample: int = -1¶
Number of node samples per graph.
- num_edge_sample: int = -1¶
Number of edge samples per graph.
- agent_update_interval: int = 10¶
Update the agent every n batches (similar to gradient accumulation).
- gamma: float = 0.9¶
Reward discount rate.
- reward_temperature: float = 1.0¶
Temperature for the reward (larger -> higher mean reward; lower -> higher maximal reward).
- baseline_momentum: float = 0.9¶
Momentum for the value function baseline.
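Note that the list- and dict-valued defaults (`hidden_dims`, `criterion`) are passed as strings. A minimal sketch of decoding such values, assuming they are parsed with `ast.literal_eval` or similar internally (an assumption; this page does not document the exact mechanism):

```python
# Sketch: decoding the string-encoded defaults of the model arguments.
# ASSUMPTION: the pipeline parses these with ast.literal_eval (or similar);
# the exact mechanism is not documented on this page.
import ast

hidden_dims_raw = "[128, 128]"   # default of `hidden_dims`
criterion_raw = "{'nll': 1.0}"   # default of `criterion`

hidden_dims = ast.literal_eval(hidden_dims_raw)  # -> [128, 128]
criterion = ast.literal_eval(criterion_raw)      # -> {'nll': 1.0}

# A property-optimization run would weight `ppo` instead, e.g.:
ppo_criterion = str({"ppo": 1.0})
```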