gt4sd.algorithms.generation.diffusion.geodiff.core module
Summary
Classes:
GeoDiffPipeline | Pipeline for molecular conformation generation using GeoDiff.
Reference
- class GeoDiffPipeline(model_name_or_path, params_json=None)
Bases: object
Pipeline for molecular conformation generation using GeoDiff. The pipeline defined here differs slightly from the pipeline used in diffusers.
GeoDiff: a Geometric Diffusion Model for Molecular Conformation Generation, Minkai Xu, Lantao Yu, Yang Song, Chence Shi, Stefano Ermon, Jian Tang - https://arxiv.org/abs/2203.02923
- __init__(model_name_or_path, params_json=None)
- GeoDiff pipeline for molecular conformation generation. Code adapted from the Colab notebook written by Nathan Lambert: https://colab.research.google.com/drive/1pLYYWQhdLuv1q-JtEHGZybxp2RBF8gPs#scrollTo=-3-P4w5sXkRU
- Parameters
model_name_or_path (str) – pretrained model name or path to model directory.
params_json (Optional[str]) – parameters as a JSON file. Defaults to None, i.e., use the default configuration.
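A minimal sketch of preparing a `params_json` file with the standard library. The helper `write_params_json` is hypothetical and not part of gt4sd. The keys in `example_params` are also hypothetical, because this reference does not list the parameter names `GeoDiffPipeline` actually reads. Check the gt4sd source before relying on any of them.

```python
import json
from pathlib import Path
from typing import Any, Dict


def write_params_json(params: Dict[str, Any], path: str) -> str:
    """Serialize `params` to a JSON file at `path` and return the path.

    Hypothetical helper for illustration; not part of gt4sd.
    """
    target = Path(path)
    # Create parent directories so the write does not fail on a fresh path.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(params, indent=2))
    return str(target)


# Hypothetical keys for illustration only; the real configuration keys may differ.
example_params = {"num_inference_steps": 1000, "seed": 42}

# Usage (requires gt4sd and a model checkpoint; the path is a placeholder):
# pipeline = GeoDiffPipeline(
#     "path/to/model_dir",
#     params_json=write_params_json(example_params, "geodiff_params.json"),
# )
```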
- to(device='cuda')
Move the model to a device.
- Parameters
device (str) – device to move the model to. Defaults to "cuda".
- Return type
None
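Because `device` defaults to "cuda", calling `to()` without arguments on a machine with no usable CUDA GPU will typically fail in PyTorch. A minimal sketch of choosing a device first, assuming PyTorch is the backend. The helper `pick_device` is hypothetical and not part of gt4sd. If `torch` cannot be imported, it falls back to "cpu".

```python
def pick_device(prefer: str = "cuda") -> str:
    """Return `prefer` when usable, otherwise "cpu".

    Hypothetical helper; only the "cuda" case is checked against torch.
    """
    if prefer != "cuda":
        # Non-CUDA requests (e.g. "cpu") are passed through unchanged.
        return prefer
    try:
        import torch
    except ImportError:
        # Without PyTorch, CUDA cannot be used.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


# Usage (requires gt4sd and a loaded pipeline):
# pipeline.to(pick_device())
```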
- classmethod from_pretrained(model_name_or_path, params_json=None)
Load a pretrained model.
- Parameters
model_name_or_path (str) – pretrained model name or path to model directory.
params_json (Optional[str]) – path to model config. Defaults to None.
- Return type
GeoDiffPipeline
- Returns
a GeoDiff pipeline.
- __call__(batch_size, prompt)
Generate conformations for a molecule.
- Parameters
batch_size (int) – number of samples to generate.
prompt (Dict[str, Any]) – torch_geometric.data.Data object containing the molecular graph in 2D format, given to the model as conditioning.
- Return type
Dict[str, List[Mol]]
- Returns
a dict containing a list of postprocessed generated conformations.
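`batch_size` sets how many conformations one call generates, and larger batches typically need more GPU memory. To collect more conformations than fit in one batch, you can call the pipeline several times with smaller batches. A minimal sketch of the split, assuming each call generates exactly `batch_size` conformations as the parameter description states. `split_into_batches` is a hypothetical helper, and the commented loop does not assume any particular key in the returned dict.

```python
from typing import List


def split_into_batches(total: int, max_batch_size: int) -> List[int]:
    """Split `total` samples into batch sizes of at most `max_batch_size`."""
    if total < 0:
        raise ValueError("total must be non-negative")
    if max_batch_size <= 0:
        raise ValueError("max_batch_size must be positive")
    full, rest = divmod(total, max_batch_size)
    # Full batches first, then one smaller batch for the remainder (if any).
    return [max_batch_size] * full + ([rest] if rest else [])


# Usage (requires gt4sd, a loaded pipeline and a 2D prompt):
# outputs = [pipeline(batch_size=n, prompt=prompt) for n in split_into_batches(50, 16)]
```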
- postprocess_output(results)
Postprocess the output of the diffusion pipeline.
- Parameters
results (List[Data]) – list of torch_geometric.data.Data objects containing the molecular graph in 3D format.
- Return type
Tuple[List[Mol], List[Mol]]
- Returns
a tuple with a list of postprocessed generated conformations and a list of postprocessed original conformations.
- visualize_2d_input(data)
Visualize the 2D input.
- Parameters
data (Data) – torch_geometric.data.Data object containing the molecular graph in 2D format.
- Return type
None
- visualize_3d(mols_gen)
Visualize the 3D output.
- Parameters
mols_gen (List[Mol]) – list of generated conformations.
- Return type
None