gt4sd.frameworks.gflownet.ml.models.mxmnet module

Summary

Classes:

BesselBasisLayer

Bessel Basis Layer.

EMA

EMA - Exponential Moving Average.

Envelope

Envelope.

Global_MP

Global message passing.

Local_MP

Local message passing.

MXMNet

MXMNet - Multiplex Molecular Graph Neural Network

MXMNetConfig

MXMNet configuration.

MessagePassing

Message Passing Layer.

Res

Residual Block.

SiLU

SiLU Activation Function.

SphericalBasisLayer

Spherical Basis Layer.

Functions:

Jn

Jn_zeros

MLP

Multi-layer perceptron.

associated_legendre_polynomials

bessel_basis

compute_idx

Compute the indices for the edges and angles.

mol2graph

Converts an RDKit molecule to a graph.

rdkit_conformation

Finds the lowest-energy conformation of a molecule.

real_sph_harm

silu

sph_harm_prefactor

spherical_bessel_formulas

Reference

class MXMNetConfig(dim, n_layer, cutoff)[source]

Bases: object

MXMNet configuration.

__init__(dim, n_layer, cutoff)[source]

Initialize MXMNet configuration.

Parameters
  • dim – dimension of the input.

  • n_layer – number of layers.

  • cutoff – cutoff radius.
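As a rough illustration of how the three constructor arguments fit together, the sketch below uses a hypothetical stand-in class, `ConfigSketch`, rather than the real `MXMNetConfig`. It assumes the configuration simply stores its arguments as attributes of the same name; the numeric values are illustrative only.

```python
class ConfigSketch:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    def __init__(self, dim, n_layer, cutoff):
        self.dim = dim          # dimension of the input
        self.n_layer = n_layer  # number of layers
        self.cutoff = cutoff    # cutoff radius


# Illustrative values, not defaults taken from the library.
config = ConfigSketch(dim=128, n_layer=6, cutoff=5.0)
```

An object of this kind is what `MXMNet` expects as its `config` argument.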


class MXMNet(config, num_spherical=7, num_radial=6, envelope_exponent=5)[source]

Bases: Module

MXMNet - Multiplex Molecular Graph Neural Network

__init__(config, num_spherical=7, num_radial=6, envelope_exponent=5)[source]

Construct an MXMNet.

Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/models and https://github.com/zetayue/MXMNet.

Parameters
  • config (MXMNetConfig) – model configuration

  • num_spherical – number of spherical harmonics to use.

  • num_radial – number of radial basis functions to use.

  • envelope_exponent – exponent of the envelope function.

init()[source]

indices(edge_index, num_nodes)[source]

Compute indices.

Parameters
  • edge_index – edge index of the graph.

  • num_nodes – number of nodes in the graph.

Returns

tuple of indices.

forward(data)[source]

Forward pass.

Parameters

data – batch of data.

Returns

globally pooled features.

class EMA(model, decay)[source]

Bases: object

EMA - Exponential Moving Average.

__init__(model, decay)[source]

Initialize EMA.

Parameters
  • model – model whose parameters are averaged.

  • decay – decay rate.

__call__(model, num_updates=99999)[source]

Call self as a function.

assign(model)[source]
resume(model)[source]
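
The idea behind an exponential moving average of parameters can be sketched without any deep-learning library. The toy class below, `ScalarEMA`, is hypothetical and works on a plain dictionary of floats instead of a model. The update rule and the warm-up of the decay through `num_updates` are assumptions based on common EMA implementations, not a statement of what this module does internally.

```python
class ScalarEMA:
    """Toy EMA over a dict of float parameters (stand-in for model weights)."""

    def __init__(self, params, decay):
        self.decay = decay
        self.shadow = dict(params)  # running averages
        self.backup = {}            # raw values saved by assign()

    def __call__(self, params, num_updates=99999):
        # Assumed warm-up: a smaller effective decay early in training.
        decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates))
        for name, value in params.items():
            self.shadow[name] = (1.0 - decay) * value + decay * self.shadow[name]

    def assign(self, params):
        # Swap the averaged values in, e.g. before evaluation.
        for name in params:
            self.backup[name] = params[name]
            params[name] = self.shadow[name]

    def resume(self, params):
        # Restore the raw values saved by assign().
        for name in params:
            params[name] = self.backup[name]


params = {"w": 0.0}
ema = ScalarEMA(params, decay=0.5)
params["w"] = 1.0       # pretend an optimizer step changed the weight
ema(params)             # shadow becomes 0.5 * 1.0 + 0.5 * 0.0
ema.assign(params)
averaged = params["w"]  # averaged value used for evaluation
ema.resume(params)
restored = params["w"]  # raw value restored for further training
```

In this reading, `assign` and `resume` bracket an evaluation phase: the averaged weights are swapped in, then the raw weights are put back.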

MLP(channels)[source]

Multi-layer perceptron.

Parameters

channels – list of number of channels.

Return type

Sequential

Returns

MLP model.
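
To make the `channels` argument concrete, the sketch below lists the layer shapes that a channel list implies. It assumes each consecutive pair of entries defines one linear layer, which is the usual reading of such a list. The helper `layer_shapes` is hypothetical, and the activation placed between layers is not modelled.

```python
def layer_shapes(channels):
    """Return the (in_features, out_features) pair of each implied linear layer."""
    return list(zip(channels[:-1], channels[1:]))


shapes = layer_shapes([128, 64, 1])  # two layers: 128 -> 64 and 64 -> 1
```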

class Res(dim)[source]

Bases: Module

Residual Block.

__init__(dim)[source]

Initialize residual block.

Parameters

dim – dimension of the layer.

forward(m)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

compute_idx(pos, edge_index)[source]

Compute the indices for the edges and angles.

Parameters
  • pos – node positions.

  • edge_index – edge indices.

Jn(r, n)[source]
Jn_zeros(n, k)[source]
spherical_bessel_formulas(n)[source]
bessel_basis(n, k)[source]
sph_harm_prefactor(k, m)[source]
associated_legendre_polynomials(k, zero_m_only=True)[source]
real_sph_harm(k, zero_m_only=True, spherical_coordinates=True)[source]
class BesselBasisLayer(num_radial, cutoff, envelope_exponent=6)[source]

Bases: Module

Bessel Basis Layer.

__init__(num_radial, cutoff, envelope_exponent=6)[source]

Initialize Bessel basis layer.

Parameters
  • num_radial – number of radial basis functions.

  • cutoff – cutoff radius.

  • envelope_exponent – envelope exponent.

reset_parameters()[source]
forward(dist)[source]

Forward pass.

Parameters

dist – distance matrix.

Returns

Bessel basis.
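
For intuition, a radial Bessel basis of the kind used in DimeNet-style models can be written for a single distance in a few lines. The sketch assumes the form sqrt(2/c) · sin(nπd/c) / d for n = 1..num_radial and leaves out the envelope that the layer also takes as a parameter. The function name `bessel_rbf` is hypothetical, and the layer itself may scale or parameterize the frequencies differently.

```python
import math


def bessel_rbf(dist, num_radial, cutoff):
    """Evaluate num_radial sine-based radial functions at one distance (dist > 0)."""
    norm = math.sqrt(2.0 / cutoff)
    return [
        norm * math.sin(n * math.pi * dist / cutoff) / dist
        for n in range(1, num_radial + 1)
    ]


features = bessel_rbf(dist=1.5, num_radial=6, cutoff=5.0)
at_cutoff = bessel_rbf(dist=5.0, num_radial=6, cutoff=5.0)  # all close to zero
```

Every basis function vanishes at the cutoff radius, which is what makes this family convenient for distance-limited interactions.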

class SiLU[source]

Bases: Module

SiLU Activation Function.

__init__()[source]

Initialize the SiLU activation function.

forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

silu(input)[source]
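
SiLU is defined as x · sigmoid(x). The scalar sketch below shows that definition in plain Python; the function in this module operates on tensors, and the name `silu_scalar` is used here only to mark the example as a stand-in.

```python
import math


def silu_scalar(x):
    """SiLU for a single float: x * sigmoid(x) = x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


zero = silu_scalar(0.0)  # exactly 0.0
one = silu_scalar(1.0)   # roughly 0.731
```
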
class Envelope(exponent)[source]

Bases: Module

Envelope.

__init__(exponent)[source]

Initialize envelope.

Parameters

exponent – exponent of the envelope.

forward(x)[source]

Forward pass.

Parameters

x – input.

Returns

Envelope of x.
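
One common choice for such an envelope is the polynomial from DimeNet, u(d) = 1 − (p+1)(p+2)/2 · d^p + p(p+2) · d^(p+1) − p(p+1)/2 · d^(p+2), which falls smoothly from 1 at d = 0 to 0 at d = 1. The sketch below evaluates it for a scaled distance. It assumes p equals the `exponent` argument, and the function name `envelope` is hypothetical. Implementations often return u(d)/d instead, so the exact form used by this layer may differ.

```python
def envelope(d, exponent=5):
    """Polynomial envelope for a scaled distance d = distance / cutoff."""
    p = exponent
    a = -(p + 1) * (p + 2) / 2.0
    b = p * (p + 2)
    c = -p * (p + 1) / 2.0
    if d >= 1.0:
        return 0.0  # beyond the cutoff the envelope is zero
    return 1.0 + a * d**p + b * d ** (p + 1) + c * d ** (p + 2)


start = envelope(0.0)   # 1.0
middle = envelope(0.5)  # 0.7734375
end = envelope(1.0)     # 0.0
```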

class SphericalBasisLayer(num_spherical, num_radial, cutoff=5.0, envelope_exponent=5)[source]

Bases: Module

Spherical Basis Layer.

__init__(num_spherical, num_radial, cutoff=5.0, envelope_exponent=5)[source]

Initialize spherical basis layer.

Parameters
  • num_spherical – number of spherical harmonics.

  • num_radial – number of radial functions.

  • cutoff – cutoff radius.

  • envelope_exponent – envelope exponent.

forward(dist, angle, idx_kj)[source]

Forward pass.

Parameters
  • dist – distance matrix.

  • angle – angle matrix.

  • idx_kj – index matrix.

Returns

Spherical basis.

class MessagePassing(aggr='add', flow='target_to_source', node_dim=0)[source]

Bases: Module

Message Passing Layer.

\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),\]

where \(\square\) denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs.

__init__(aggr='add', flow='target_to_source', node_dim=0)[source]

Initialize message passing layer.

Parameters
  • aggr – the aggregation scheme to use (add, mean, max).

  • flow – the flow direction of message passing (source_to_target, target_to_source).

  • node_dim – the axis along which to propagate.

__set_size__(size, index, tensor)[source]
__collect__(edge_index, size, kwargs)[source]
__distribute__(params, kwargs)[source]
propagate(edge_index, size=None, **kwargs)[source]

The initial call to start propagating messages.

Parameters
  • edge_index (Tensor) – the indices of a general (sparse) assignment matrix with shape [N, M] (can be directed or undirected).

  • size (Optional[List[Tuple], None]) – the size [N, M] of the assignment matrix. If set to None, the size will be automatically inferred and assumed to be quadratic.

  • **kwargs – any additional data which is needed to construct and aggregate messages, and to update node embeddings.

message(x_j)[source]

Constructs messages to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in \((j,i) \in \mathcal{E}\) if flow="source_to_target" and \((i,j) \in \mathcal{E}\) if flow="target_to_source". Can take any argument which was initially passed to propagate(). In addition, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.

aggregate(inputs, index, dim_size)[source]

Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\). By default, delegates call to scatter functions that support “add”, “mean” and “max” operations specified in __init__() by the aggr argument.

update(inputs)[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().
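
The message, aggregate and update steps can be traced on a tiny graph in plain Python. The sketch assumes scalar node features, the default aggr='add' and flow='target_to_source' (so each edge (i, j) sends a message from j to i), and an identity update. The function `propagate_sum` is hypothetical and only mirrors the structure of the formula, not the tensor-based implementation.

```python
def propagate_sum(x, edge_index):
    """Sum-aggregate neighbour features along edges (i, j), from j to i."""
    row, col = edge_index
    out = [0.0] * len(x)
    for i, j in zip(row, col):
        message = x[j]     # phi: here simply the neighbour feature x_j
        out[i] += message  # aggregation: sum over j in N(i)
    return out             # gamma: identity update


x = [1.0, 2.0, 3.0]
# Edges (0,1), (0,2), (1,0), (2,0): node 0 is connected to nodes 1 and 2.
edge_index = ([0, 0, 1, 2], [1, 2, 0, 0])
result = propagate_sum(x, edge_index)  # [5.0, 1.0, 1.0]
```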

rdkit_conformation(mol, n=5, addHs=False)[source]

Finds the lowest-energy conformation of a molecule.

Parameters
  • mol – RDKit molecule object.

  • n – Number of conformations to find.

  • addHs – Whether to add hydrogens to the molecule.

Returns

RDKit molecule object with the lowest-energy conformation, or None if no conformation is found.

mol2graph(mol)[source]

Converts an RDKit molecule to a graph.

Parameters

mol – RDKit molecule.

Returns

A graph with node features and edge features.

class Global_MP(config)[source]

Bases: MessagePassing

Global message passing.

__init__(config)[source]

Initializes the global message passing.

Parameters

config – configuration.

forward(h, edge_attr, edge_index)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

message(x_i, x_j, edge_attr, edge_index, num_nodes)[source]

Constructs messages to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in \((j,i) \in \mathcal{E}\) if flow="source_to_target" and \((i,j) \in \mathcal{E}\) if flow="target_to_source". Can take any argument which was initially passed to propagate(). In addition, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g., x_i and x_j.

update(aggr_out)[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().

class Local_MP(config)[source]

Bases: Module

Local message passing.

__init__(config)[source]

Initialize local message passing.

Parameters

config – configuration.

forward(h, rbf, sbf1, sbf2, idx_kj, idx_ji_1, idx_jj, idx_ji_2, edge_index, num_nodes=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.