gt4sd.algorithms.generation.diffusion.geodiff.core module

Summary

Classes:

GeoDiffPipeline

Pipeline for molecular conformation generation using GeoDiff.

Reference

class GeoDiffPipeline(model_name_or_path, params_json=None)

Bases: object

Pipeline for molecular conformation generation using GeoDiff. The pipeline defined here differs slightly from the pipeline used in the Hugging Face diffusers library.

GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation, Minkai Xu, Lantao Yu, Yang Song, Chence Shi, Stefano Ermon, Jian Tang - https://arxiv.org/abs/2203.02923
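GeoDiff diffuses the atomic coordinates of a conformation. As a minimal, dependency-free sketch of the forward (noising) process, assuming a DDPM-style Gaussian kernel q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I) and a sigmoid beta schedule, the snippet below noises a toy two-atom conformation. The 5000 steps, the schedule endpoints (1e-7, 2e-3) and the helper names sigmoid_beta_schedule, alpha_bars and q_sample are illustrative assumptions, not part of the gt4sd API, and the sketch leaves out GeoDiff's equivariant denoising network entirely.

```python
import math
import random
from typing import List, Tuple

Coords = List[Tuple[float, ...]]


def sigmoid_beta_schedule(num_steps: int, beta_start: float = 1e-7, beta_end: float = 2e-3) -> List[float]:
    """Per-step noise variances beta_t rising along a sigmoid between beta_start and beta_end."""
    betas = []
    for t in range(num_steps):
        # Evenly spaced points on [-6, 6], squashed through the logistic sigmoid.
        x = -6.0 + 12.0 * t / max(num_steps - 1, 1)
        betas.append(beta_start + (beta_end - beta_start) / (1.0 + math.exp(-x)))
    return betas


def alpha_bars(betas: List[float]) -> List[float]:
    """Cumulative products abar_t = prod_{s <= t} (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample(x0: Coords, abar_t: float, rng: random.Random) -> Coords:
    """Sample x_t ~ N(sqrt(abar_t) * x0, (1 - abar_t) * I) independently per coordinate."""
    scale, sigma = math.sqrt(abar_t), math.sqrt(1.0 - abar_t)
    return [tuple(scale * c + sigma * rng.gauss(0.0, 1.0) for c in atom) for atom in x0]


betas = sigmoid_beta_schedule(5000)
abars = alpha_bars(betas)
x0 = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]  # toy two-atom "conformation" in angstroms
x_half = q_sample(x0, abars[2500], random.Random(0))
```

Because abar_t shrinks monotonically, later steps keep less of the original geometry and add more Gaussian noise; the reverse model is trained to undo exactly this corruption.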

__init__(model_name_or_path, params_json=None)
GeoDiff pipeline for molecular conformation generation. Code adapted from a Colab notebook written by Nathan Lambert:

https://colab.research.google.com/drive/1pLYYWQhdLuv1q-JtEHGZybxp2RBF8gPs#scrollTo=-3-P4w5sXkRU

Parameters
  • model_name_or_path (str) – pretrained model name or path to model directory.

  • params_json (Optional[str]) – path to a JSON file containing the pipeline parameters. Defaults to None, i.e., use the default configuration.
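params_json lets the caller override the default configuration from a file. A minimal sketch of the usual merge pattern (defaults first, then JSON overrides), assuming the file holds a flat JSON object; the keys in DEFAULT_PARAMS and the load_params helper are hypothetical and do not reflect GeoDiff's actual configuration:

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional

# Hypothetical defaults for illustration only; GeoDiff's real configuration keys may differ.
DEFAULT_PARAMS: Dict[str, Any] = {"num_steps": 1000, "seed": 42}


def load_params(params_json: Optional[str] = None) -> Dict[str, Any]:
    """Return a copy of the defaults, updated with the entries of an optional JSON file."""
    params = dict(DEFAULT_PARAMS)
    if params_json is not None:
        with open(params_json, "r", encoding="utf-8") as handle:
            params.update(json.load(handle))
    return params


# Usage: override a single default from a temporary JSON file.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "params.json")
    with open(path, "w", encoding="utf-8") as handle:
        json.dump({"num_steps": 50}, handle)
    params = load_params(path)

print(params)  # {'num_steps': 50, 'seed': 42}
```

Copying the defaults before updating keeps repeated calls independent of one another.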

to(device='cuda')

Move the model to a device.

Parameters

device (str) – device to move the model to. Defaults to 'cuda'.

Return type

None
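Since to() defaults to 'cuda', a script meant to also run on CPU-only machines may want to pick the device explicitly. A minimal sketch of such fallback logic, written as a pure function so it runs without PyTorch; resolve_device is a hypothetical helper, and in practice the cuda_available flag would come from torch.cuda.is_available():

```python
def resolve_device(requested: str = "cuda", cuda_available: bool = False) -> str:
    """Return the requested device, or "cpu" if a CUDA device was requested but none is available."""
    if requested.startswith("cuda") and not cuda_available:
        # Fall back to CPU on machines without a usable GPU.
        return "cpu"
    return requested
```

For example, pipeline.to(resolve_device("cuda", torch.cuda.is_available())) would move a loaded pipeline to the GPU when one is present and to the CPU otherwise.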

classmethod from_pretrained(model_name_or_path, params_json=None)

Load a pretrained model.

Parameters
  • model_name_or_path (str) – pretrained model name or path to model directory.

  • params_json (Optional[str]) – path to the model configuration JSON file. Defaults to None.

Return type

GeoDiffPipeline

Returns

a GeoDiff pipeline.
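The signatures of __init__ and from_pretrained are identical, which suggests (but the source does not confirm) that from_pretrained forwards its arguments to the constructor. The hypothetical PipelineSketch class below mirrors only the documented signatures to illustrate that classmethod pattern; it loads no weights. Its to() also follows the documented return type, None:

```python
from typing import Optional


class PipelineSketch:
    """Hypothetical stand-in that mirrors only the documented GeoDiffPipeline signatures."""

    def __init__(self, model_name_or_path: str, params_json: Optional[str] = None) -> None:
        self.model_name_or_path = model_name_or_path
        self.params_json = params_json
        self.device = "cpu"  # hypothetical initial device; no weights are loaded here

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, params_json: Optional[str] = None) -> "PipelineSketch":
        # Alternate constructor: calling cls(...) means subclasses get instances of themselves.
        return cls(model_name_or_path, params_json=params_json)

    def to(self, device: str = "cuda") -> None:
        # Mutates the instance in place and, like the documented method, returns None.
        self.device = device


pipeline = PipelineSketch.from_pretrained("path/to/model_dir")
pipeline.to("cpu")  # separate statement, since to() returns None
```

Because to() is documented as returning None, writing pipeline = GeoDiffPipeline.from_pretrained(...).to("cuda") would bind pipeline to None; loading and moving the model in two separate statements avoids that.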

__call__(batch_size, prompt)

Generate conformations for a molecule.

Parameters
  • batch_size (int) – number of samples to generate.

  • prompt (Dict[str, Any]) – torch_geometric.data.Data object containing the molecular graph in 2D format, which is given to the model as conditioning.

Return type

Dict[str, List[Mol]]

Returns

a dict containing a list of postprocessed generated conformations.
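__call__ produces batch_size conformations for one molecular graph. The generic denoising-diffusion sampling loop behind that kind of call can be sketched as follows, assuming a DDPM ancestral sampler that adds noise with variance beta_t at each step. This is a swapped-in textbook sampler, not GeoDiff's exact procedure, which predicts noise with an equivariant network conditioned on the molecular graph; zero_noise is a hypothetical placeholder for that trained network, so the outputs carry no chemical meaning:

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Coords = List[Tuple[float, ...]]
# Hypothetical noise-predictor interface: (x_t, t) -> predicted noise with the same shape as x_t.
NoiseModel = Callable[[Coords, int], Coords]


def ddpm_sample(
    model: NoiseModel,
    betas: Sequence[float],
    num_atoms: int,
    batch_size: int,
    rng: random.Random,
) -> List[Coords]:
    """Draw batch_size conformations with a generic DDPM ancestral sampler."""
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)

    samples: List[Coords] = []
    for _ in range(batch_size):
        # Start every sample from isotropic Gaussian noise in 3D.
        x: Coords = [tuple(rng.gauss(0.0, 1.0) for _ in range(3)) for _ in range(num_atoms)]
        for t in reversed(range(len(betas))):
            eps = model(x, t)
            coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
            # Estimated mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t).
            mean = [
                tuple((xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(xa, ea))
                for xa, ea in zip(x, eps)
            ]
            if t > 0:
                # Add fresh noise with variance beta_t on every step except the last.
                sigma = math.sqrt(betas[t])
                x = [tuple(m + sigma * rng.gauss(0.0, 1.0) for m in atom) for atom in mean]
            else:
                x = mean
        samples.append(x)
    return samples


def zero_noise(x: Coords, t: int) -> Coords:
    """Placeholder predictor that always predicts zero noise (stands in for a trained network)."""
    return [(0.0, 0.0, 0.0) for _ in x]


samples = ddpm_sample(zero_noise, [1e-4] * 10, num_atoms=4, batch_size=3, rng=random.Random(0))
```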

postprocess_output(results)

Postprocess the output of the diffusion pipeline.

Parameters

results (List[Data]) – list of torch_geometric.data.Data objects containing the molecular graph in 3D format.

Return type

Tuple[List[Mol], List[Mol]]

Returns

a tuple containing the list of postprocessed generated conformations and the list of postprocessed original conformations.
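postprocess_output returns the generated and the original conformations side by side, which makes a direct comparison natural. As an illustrative, alignment-free comparison (not necessarily the metric gt4sd or the GeoDiff paper reports), the sketch below computes the RMS difference between the interatomic distance lists of two conformations; because distances do not change under rotation or translation, no superposition step is needed. The helper names are hypothetical, and coordinates are assumed to be plain (x, y, z) tuples in the same atom order:

```python
import math
from itertools import combinations
from typing import List, Sequence, Tuple

Coords = Sequence[Tuple[float, float, float]]


def pairwise_distances(coords: Coords) -> List[float]:
    """All interatomic distances d_ij for i < j, in a fixed order."""
    return [math.dist(a, b) for a, b in combinations(coords, 2)]


def distance_rmsd(generated: Coords, reference: Coords) -> float:
    """RMS difference between the pairwise-distance lists of two conformations."""
    if len(generated) != len(reference):
        raise ValueError("conformations must have the same number of atoms")
    d_gen = pairwise_distances(generated)
    d_ref = pairwise_distances(reference)
    if not d_gen:  # zero or one atom: nothing to compare
        return 0.0
    return math.sqrt(sum((g - r) ** 2 for g, r in zip(d_gen, d_ref)) / len(d_gen))


reference = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
moved = [(5.0, 5.0, 5.0), (5.0, 6.0, 5.0), (4.0, 5.0, 5.0)]  # rotated 90 degrees and shifted
```

One caveat of this design choice: mirror images share all interatomic distances, so the comparison cannot tell enantiomeric conformations apart, whereas an RMSD after proper-rotation (Kabsch) alignment can.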

visualize_2d_input(data)

Visualize the 2D input.

Parameters

data (Data) – torch_geometric.data.Data object containing the molecular graph in 2D format.

Return type

None
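A torch_geometric Data object stores the bonds of a molecular graph in edge_index, a 2 x num_edges tensor in coordinate (COO) format in which an undirected bond usually appears once per direction. Before drawing a 2D graph, those directed edges are typically collapsed into unique bonds. The hypothetical helper unique_bonds below sketches that step on plain lists (for example data.edge_index.tolist()); it is not part of GeoDiffPipeline, and visualize_2d_input may proceed differently:

```python
from typing import List, Sequence, Tuple


def unique_bonds(edge_index: Sequence[Sequence[int]]) -> List[Tuple[int, int]]:
    """Collapse a 2 x num_edges COO edge list into sorted, deduplicated (i, j) bonds with i < j."""
    sources, targets = edge_index
    bonds = {
        (min(s, t), max(s, t))
        for s, t in zip(sources, targets)
        if s != t  # drop self-loops, which are not chemical bonds
    }
    return sorted(bonds)


# Heavy atoms of ethanol (C-C-O), with each bond stored in both directions.
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
bonds = unique_bonds(edge_index)
```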

visualize_3d(mols_gen)

Visualize the 3D output.

Parameters

mols_gen (List[Mol]) – list of generated conformations.

Return type

None
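One simple, library-free way to inspect generated conformations outside a notebook is to export them as a multi-frame XYZ file, a plain-text format that many molecular viewers can open. The sketch below assumes the element symbols and per-conformer coordinates have already been extracted (for an RDKit Mol, from its atoms and conformer positions); to_xyz_block is a hypothetical helper and is not claimed to reflect how visualize_3d works:

```python
from typing import List, Sequence, Tuple

Coords = Sequence[Tuple[float, float, float]]


def to_xyz_block(symbols: Sequence[str], conformers: Sequence[Coords], comment: str = "conformer") -> str:
    """Serialize several conformers of one molecule as concatenated XYZ frames."""
    frames: List[str] = []
    for index, coords in enumerate(conformers):
        if len(coords) != len(symbols):
            raise ValueError(f"conformer {index} has {len(coords)} atoms, expected {len(symbols)}")
        # XYZ frame: atom count, a comment line, then "element x y z" per atom.
        lines = [str(len(symbols)), f"{comment} {index}"]
        lines += [f"{symbol} {x:.4f} {y:.4f} {z:.4f}" for symbol, (x, y, z) in zip(symbols, coords)]
        frames.append("\n".join(lines))
    return "\n".join(frames) + "\n"


water = ["O", "H", "H"]
frames = [
    [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)],
    [(0.0, 0.0, 0.0), (0.95, 0.05, 0.0), (-0.25, 0.92, 0.0)],
]
xyz_text = to_xyz_block(water, frames)
```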
