gt4sd.algorithms.generation.torchdrug.implementation module

Implementation details for TorchDrug generation algorithms.

Parts of the implementation are inspired by the TorchDrug generation tutorial: https://torchdrug.ai/docs/tutorials/generation.html.

Summary

Classes:

DummyDataset

A helper class to imitate a torchdrug dataset.

GAFGenerator

Interface for the GraphAF model as implemented in TorchDrug.

GCPNGenerator

Interface for the GCPN model as implemented in TorchDrug.

Generator

Implementation of a TorchDrug generator.

Reference

class DummyDataset(atom_types)

Bases: object

A helper class to imitate a torchdrug dataset.

__init__(atom_types)
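The atom_types values used by both generators below are atomic numbers (C, N, O, F, P, S, Cl, Br, I). The source does not say which attributes TorchDrug actually reads from the dummy dataset. The minimal sketch below assumes that consuming code only reads an `atom_types` attribute, and shows how a stand-in object can satisfy it through duck typing. `DummyDatasetSketch` and `atom_vocabulary` are hypothetical names and are not part of GT4SD or TorchDrug.

```python
from typing import Dict, List

# Atomic numbers from the documented atom_types defaults.
ATOM_TYPES: List[int] = [6, 7, 8, 9, 15, 16, 17, 35, 53]

# Element symbols for those atomic numbers.
SYMBOLS: Dict[int, str] = {
    6: "C", 7: "N", 8: "O", 9: "F", 15: "P",
    16: "S", 17: "Cl", 35: "Br", 53: "I",
}


class DummyDatasetSketch:
    """Hypothetical stand-in that only carries the atom types (assumption)."""

    def __init__(self, atom_types: List[int]) -> None:
        self.atom_types = atom_types


def atom_vocabulary(dataset: object) -> Dict[str, int]:
    """Hypothetical consumer: map element symbols to one-hot indices.

    Any object exposing an `atom_types` attribute works (duck typing).
    """
    return {SYMBOLS[z]: i for i, z in enumerate(getattr(dataset, "atom_types"))}


vocab = atom_vocabulary(DummyDatasetSketch(ATOM_TYPES))
print(vocab)  # {'C': 0, 'N': 1, 'O': 2, 'F': 3, 'P': 4, 'S': 5, 'Cl': 6, 'Br': 7, 'I': 8}
```

The vocabulary has 9 entries, the same as the input_dim documented for GAFGenerator. That the two are linked is an assumption of this sketch.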

class Generator(resources_path, atom_types, hidden_dims, input_dim, num_relation, batch_norm, device=None)

Bases: object

Implementation of a TorchDrug generator.

solver: Engine
task: Union[GCPNGeneration, AutoregressiveGeneration]
num_sample: int = 32
max_resample: int = 16
__init__(resources_path, atom_types, hidden_dims, input_dim, num_relation, batch_norm, device=None)

A TorchDrug generation algorithm.

Parameters
  • resources_path (str) – path to the cache.

  • atom_types (List[int]) – list of atom types, given as atomic numbers.

  • hidden_dims (List[int]) – list of hidden dimensions, one per layer.

  • input_dim (int) – dimension of the input node features.

  • num_relation (int) – number of relation (edge) types in the graph.

  • batch_norm (bool) – whether to use batch normalization.

  • device (Union[device, str, None]) – device on which inference runs, given either as a torch.device or as a string. If not provided, it is inferred.
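The device argument can be omitted, in which case it is inferred. The exact rule GT4SD applies is not documented here. The minimal sketch below shows one common policy and assumes that CUDA is preferred when it is available. `resolve_device` is a hypothetical helper. Its `cuda_available` flag stands in for a check such as `torch.cuda.is_available()`, so the sketch runs without PyTorch.

```python
from typing import Optional, Union


def resolve_device(
    device: Optional[Union[str, object]] = None, cuda_available: bool = False
) -> str:
    """Hypothetical helper: normalize the `device` argument to a device string.

    - None: inferred, preferring CUDA when it is available (assumed policy).
    - str: returned unchanged, e.g. "cpu" or "cuda:0".
    - any other object (e.g. a torch.device): converted with str().
    """
    if device is None:
        return "cuda" if cuda_available else "cpu"
    return str(device)


print(resolve_device())                     # cpu
print(resolve_device(cuda_available=True))  # cuda
print(resolve_device("cuda:1"))             # cuda:1
```

With PyTorch installed, `str(torch.device("cuda:0"))` yields `"cuda:0"`, so a torch.device passes through the last branch unchanged in meaning.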

load_model(resources_path)

Load a pretrained TorchDrug model.

sample()

Sample a molecule.

Return type

List[str]

Returns

a generated SMILES string wrapped into a list.
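sample() returns a single SMILES string wrapped in a list, so a caller that needs several distinct molecules has to call it repeatedly. The minimal sketch below shows such a loop. `StubGenerator` is a hypothetical stand-in with the same `sample() -> List[str]` shape, which lets the example run without GT4SD, TorchDrug or pretrained weights. The SMILES pool it draws from is made up for illustration.

```python
import random
from typing import List, Protocol, Set


class SupportsSample(Protocol):
    def sample(self) -> List[str]: ...


class StubGenerator:
    """Hypothetical stand-in: returns one SMILES wrapped in a list, like sample()."""

    POOL = ["CCO", "c1ccccc1", "CC(=O)O", "CN", "CCN"]  # illustrative only

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def sample(self) -> List[str]:
        return [self._rng.choice(self.POOL)]


def collect_unique(generator: SupportsSample, n: int, max_calls: int = 1000) -> List[str]:
    """Call sample() until n distinct SMILES are collected or max_calls is reached."""
    seen: Set[str] = set()
    unique: List[str] = []
    for _ in range(max_calls):
        if len(unique) >= n:
            break
        for smiles in generator.sample():  # unwrap the single-element list
            if smiles not in seen:
                seen.add(smiles)
                unique.append(smiles)
    return unique[:n]


molecules = collect_unique(StubGenerator(seed=42), n=3)
print(molecules)
```

If GT4SD is installed and a pretrained model is available, an instance of GCPNGenerator or GAFGenerator could presumably replace the stub, since both expose sample().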

class GCPNGenerator(resources_path)

Bases: Generator

Interface for the GCPN model as implemented in TorchDrug.

For details see: You, J. et al. (2018). Graph convolutional policy network for goal-directed molecular graph generation. Advances in Neural Information Processing Systems, 31.

input_dim: int = 18
num_relation: int = 3
batch_norm: bool = False
atom_types: List[int] = [6, 7, 8, 9, 15, 16, 17, 35, 53]
hidden_dims: List[int] = [256, 256, 256, 256]
__init__(resources_path)

A TorchDrug graph convolutional policy network (GCPN) generation algorithm.

Parameters

resources_path (str) – path to the cache.

task: GCPNGeneration
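hidden_dims gives one width per layer, so the GCPNGenerator defaults describe four layers and the GAFGenerator defaults describe three. The minimal sketch below assumes a plain stacked layout: the first layer maps input_dim features to hidden_dims[0], and each later layer maps the previous width to the next one. The cited TorchDrug tutorial passes these same arguments to an RGCN model. `layer_shapes` is a hypothetical helper, and it ignores how num_relation affects the actual weight shapes.

```python
from typing import List, Tuple


def layer_shapes(input_dim: int, hidden_dims: List[int]) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each layer, one per hidden_dims entry."""
    dims = [input_dim] + list(hidden_dims)
    return list(zip(dims[:-1], dims[1:]))


# Defaults documented for the two generators.
gcpn_layers = layer_shapes(18, [256, 256, 256, 256])
gaf_layers = layer_shapes(9, [256, 256, 256])
print(gcpn_layers)  # [(18, 256), (256, 256), (256, 256), (256, 256)]
print(gaf_layers)   # [(9, 256), (256, 256), (256, 256)]
```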

class GAFGenerator(resources_path)

Bases: Generator

Interface for the GraphAF model as implemented in TorchDrug.

For details see: Shi, C. et al. (2020). GraphAF: a flow-based autoregressive model for molecular graph generation. International Conference on Learning Representations (ICLR).

input_dim: int = 9
num_relation: int = 3
batch_norm: bool = True
atom_types: List[int] = [6, 7, 8, 9, 15, 16, 17, 35, 53]
hidden_dims: List[int] = [256, 256, 256]
__init__(resources_path)

A TorchDrug flow-based autoregressive graph generation algorithm (GAF).

Parameters

resources_path (str) – path to the cache.

task: AutoregressiveGeneration
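Both subclasses take only resources_path, so their architecture is fixed by the class attributes listed above. The minimal sketch below illustrates that configuration pattern. It assumes the subclasses forward these attributes to Generator.__init__; the source shows the signatures, not the bodies. `GeneratorSketch`, `GCPNSketch`, `GAFSketch` and `config_diff` are hypothetical stand-ins that only record configuration and load no model.

```python
from typing import Any, Dict, List, Optional, Tuple


class GeneratorSketch:
    """Hypothetical stand-in for Generator: records its configuration only."""

    def __init__(
        self,
        resources_path: str,
        atom_types: List[int],
        hidden_dims: List[int],
        input_dim: int,
        num_relation: int,
        batch_norm: bool,
        device: Optional[str] = None,
    ) -> None:
        # Copy lists so instances never share the mutable class-level defaults.
        self.config: Dict[str, Any] = {
            "resources_path": resources_path,
            "atom_types": list(atom_types),
            "hidden_dims": list(hidden_dims),
            "input_dim": input_dim,
            "num_relation": num_relation,
            "batch_norm": batch_norm,
            "device": device,
        }


class GCPNSketch(GeneratorSketch):
    # Defaults copied from the documented GCPNGenerator class attributes.
    input_dim: int = 18
    num_relation: int = 3
    batch_norm: bool = False
    atom_types: List[int] = [6, 7, 8, 9, 15, 16, 17, 35, 53]
    hidden_dims: List[int] = [256, 256, 256, 256]

    def __init__(self, resources_path: str) -> None:
        # Assumed pattern: forward the class-level defaults to the base class.
        super().__init__(
            resources_path,
            atom_types=self.atom_types,
            hidden_dims=self.hidden_dims,
            input_dim=self.input_dim,
            num_relation=self.num_relation,
            batch_norm=self.batch_norm,
        )


class GAFSketch(GeneratorSketch):
    # Defaults copied from the documented GAFGenerator class attributes.
    input_dim: int = 9
    num_relation: int = 3
    batch_norm: bool = True
    atom_types: List[int] = [6, 7, 8, 9, 15, 16, 17, 35, 53]
    hidden_dims: List[int] = [256, 256, 256]

    def __init__(self, resources_path: str) -> None:
        super().__init__(
            resources_path,
            atom_types=self.atom_types,
            hidden_dims=self.hidden_dims,
            input_dim=self.input_dim,
            num_relation=self.num_relation,
            batch_norm=self.batch_norm,
        )


def config_diff(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
    """Return the keys whose values differ between two configurations."""
    return {key: (a[key], b[key]) for key in a if a[key] != b[key]}


gcpn = GCPNSketch("/path/to/cache")  # placeholder path
gaf = GAFSketch("/path/to/cache")
diff = config_diff(gcpn.config, gaf.config)
print(sorted(diff))  # ['batch_norm', 'hidden_dims', 'input_dim']
```

In this sketch, the two generators share atom types and relation count and differ only in input width, depth and batch normalization.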