gt4sd.frameworks.gflownet.ml.models.mxmnet module
Summary
Classes:
BesselBasisLayer | Bessel Basis Layer.
EMA | EMA - Exponential Moving Average.
Envelope | Envelope.
Global_MP | Global message passing.
Local_MP | Local message passing.
MXMNet | MXMNet - Multiplex Molecular Graph Neural Network.
MXMNetConfig | MXMNet configuration.
MessagePassing | Message Passing Layer.
Res | Residual Block.
SiLU | SiLU Activation Function.
SphericalBasisLayer | Spherical Basis Layer.
Functions:
MLP | Multi-layer perceptron.
compute_idx | Compute the indices for the edges and angles.
mol2graph | Converts an RDKit molecule to a graph.
rdkit_conformation | A function that finds the lowest-energy conformation of a molecule.
Reference
- class MXMNetConfig(dim, n_layer, cutoff)
Bases: object
MXMNet configuration.
- __init__(dim, n_layer, cutoff)
Initialize MXMNet configuration.
- Parameters
dim – dimension of the input.
n_layer – number of layers.
cutoff – cutoff radius.
- class MXMNet(config, num_spherical=7, num_radial=6, envelope_exponent=5)
Bases: torch.nn.Module
MXMNet - Multiplex Molecular Graph Neural Network
- __init__(config, num_spherical=7, num_radial=6, envelope_exponent=5)
Construct an MXMNet.
Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/models and https://github.com/zetayue/MXMNet.
- Parameters
config (MXMNetConfig) – model configuration.
num_spherical – number of spherical harmonics to use.
num_radial – number of radial basis functions to use.
envelope_exponent – exponent of the envelope function.
- indices(edge_index, num_nodes)
Compute the edge and angle indices of the graph.
- Parameters
edge_index – edge index of the graph.
num_nodes – number of nodes in the graph.
- Returns
tuple of indices.
- forward(data)
Forward pass.
- Parameters
data – batch of data.
- Returns
globally pooled features.
- class EMA(model, decay)
Bases: object
EMA - Exponential Moving Average.
- __init__(model, decay)
Initialize the EMA.
- Parameters
model – model whose parameters are averaged.
decay – decay rate.
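The averaging rule can be illustrated in plain Python. This is a minimal sketch, not the module's code: `ScalarEMA` is a hypothetical stand-in that tracks a dict of named float parameters instead of a model's tensors, and it assumes the conventional update `shadow = decay * shadow + (1 - decay) * value`. Its method names mirror the `__call__`, `assign` and `resume` methods that `EMA` exposes, but their exact behaviour here is an assumption.

```python
from typing import Dict


class ScalarEMA:
    """Hypothetical stand-in for EMA: averages a dict of named float parameters."""

    def __init__(self, params: Dict[str, float], decay: float) -> None:
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must lie in [0, 1]")
        self.decay = decay
        self.shadow = dict(params)  # running averages, initialised to the current values
        self.backup: Dict[str, float] = {}  # live values saved by assign()

    def __call__(self, params: Dict[str, float]) -> None:
        # Blend the current values into the running averages.
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1.0 - self.decay) * value

    def assign(self, params: Dict[str, float]) -> None:
        # Swap the averaged values in (e.g. before evaluation), keeping the live ones.
        self.backup = dict(params)
        params.update(self.shadow)

    def resume(self, params: Dict[str, float]) -> None:
        # Restore the live values saved by assign() (e.g. to continue training).
        params.update(self.backup)


params = {"w": 0.0}
ema = ScalarEMA(params, decay=0.5)
params["w"] = 1.0  # pretend an optimizer step changed the weight
ema(params)
print(ema.shadow["w"])  # 0.5
ema.assign(params)
print(params["w"])  # 0.5
ema.resume(params)
print(params["w"])  # 1.0
```

A larger `decay` makes the average move more slowly, which smooths out step-to-step noise in the weights at the cost of lagging behind recent updates.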
- MLP(channels)
Multi-layer perceptron.
- Parameters
channels – list of channel sizes, from the input size to the output size.
- Return type
torch.nn.Sequential
- Returns
MLP model.
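How a `channels` list maps to layers can be checked without PyTorch. The sketch assumes that consecutive entries define one linear layer each, `Linear(channels[i - 1], channels[i])`, followed by a parameter-free activation; the helper names are hypothetical and only count shapes and parameters.

```python
from typing import List, Tuple


def mlp_layer_shapes(channels: List[int]) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each linear layer built from `channels`."""
    if len(channels) < 2:
        raise ValueError("channels needs at least an input and an output size")
    return [(channels[i - 1], channels[i]) for i in range(1, len(channels))]


def mlp_num_parameters(channels: List[int]) -> int:
    # Each Linear(in, out) holds in * out weights plus out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in mlp_layer_shapes(channels))


print(mlp_layer_shapes([64, 128, 64]))  # [(64, 128), (128, 64)]
print(mlp_num_parameters([64, 128, 64]))  # 16576
```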
- class Res(dim)
Bases: torch.nn.Module
Residual Block.
- forward(m)
Forward pass of the residual block.
Note
Although the forward pass is defined in this method, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently ignores them.
- compute_idx(pos, edge_index)
Compute the indices for the edges and angles.
- Parameters
pos – node positions.
edge_index – edge indices.
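The idea behind edge/angle indexing can be sketched in pure Python. This assumes the DimeNet-style convention: for every edge j→i, each incoming edge k→j with k ≠ i forms a triplet, and the angle is measured at j between the vectors j→i and j→k. `triplet_angles` is a hypothetical helper; the module's `compute_idx` may order, name or return its indices differently, and MXMNet also uses one-hop angles that this sketch does not cover.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def _sub(a: Vec3, b: Vec3) -> Vec3:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def _dot(a: Vec3, b: Vec3) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def _cross(a: Vec3, b: Vec3) -> Vec3:
    return (a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0])


def triplet_angles(
    pos: Sequence[Vec3], edge_index: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[List[int], List[int], List[float]]:
    """Enumerate triplets k -> j -> i (k != i) and the angle at j.

    Edge e points from src[e] to dst[e]. Returns (idx_kj, idx_ji, angles),
    where idx_kj and idx_ji index into the edge list.
    """
    src, dst = edge_index
    incoming: Dict[int, List[int]] = {}
    for e, target in enumerate(dst):
        incoming.setdefault(target, []).append(e)

    idx_kj: List[int] = []
    idx_ji: List[int] = []
    angles: List[float] = []
    for e_ji, (j, i) in enumerate(zip(src, dst)):
        for e_kj in incoming.get(j, []):
            k = src[e_kj]
            if k == i:  # skip the back-and-forth path i -> j -> i
                continue
            v_ji = _sub(pos[i], pos[j])
            v_jk = _sub(pos[k], pos[j])
            c = _cross(v_ji, v_jk)
            # atan2(|a x b|, a . b) is numerically safer than acos of the cosine.
            angles.append(math.atan2(math.sqrt(_dot(c, c)), _dot(v_ji, v_jk)))
            idx_kj.append(e_kj)
            idx_ji.append(e_ji)
    return idx_kj, idx_ji, angles


# Right angle at atom 0: bonds 0-1 and 0-2, each stored in both directions.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
edge_index = ([0, 1, 0, 2], [1, 0, 2, 0])
idx_kj, idx_ji, angles = triplet_angles(pos, edge_index)
print(idx_kj, idx_ji)  # [3, 1] [0, 2]
```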
- class BesselBasisLayer(num_radial, cutoff, envelope_exponent=6)
Bases: torch.nn.Module
Bessel Basis Layer.
- __init__(num_radial, cutoff, envelope_exponent=6)
Initialize Bessel basis layer.
Initialize Bessel basis layer.
- Parameters
num_radial – number of radial basis functions.
cutoff – cutoff radius.
envelope_exponent – envelope exponent.
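A radial Bessel basis of this kind can be sketched for a single distance in pure Python. The sketch assumes the layer follows DimeNet's basis, e_n(d) = sqrt(2/c) · sin(nπd/c) / d, multiplied by a smooth polynomial envelope u(d/c); it also assumes the envelope power is p = envelope_exponent + 1, as in PyTorch Geometric's DimeNet. Both helpers are hypothetical, and the real layer operates on tensors (possibly with learnable frequencies).

```python
import math
from typing import List


def polynomial_envelope(x: float, p: int) -> float:
    """Smooth cutoff on the scaled distance x = d / cutoff; exactly 0 for x >= 1."""
    if x >= 1.0:
        return 0.0
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * x**p + b * x ** (p + 1) + c * x ** (p + 2)


def bessel_basis(
    d: float, num_radial: int = 6, cutoff: float = 5.0, envelope_exponent: int = 6
) -> List[float]:
    """DimeNet-style radial basis sqrt(2/c) * sin(n*pi*d/c) / d * u(d/c), n = 1..num_radial."""
    if d <= 0.0:
        raise ValueError("distance must be positive")
    x = d / cutoff
    # Assumption: envelope power p = envelope_exponent + 1.
    u = polynomial_envelope(x, p=envelope_exponent + 1)
    norm = math.sqrt(2.0 / cutoff)
    return [norm * math.sin(n * math.pi * x) / d * u for n in range(1, num_radial + 1)]


values = bessel_basis(1.2)
print(len(values))  # 6
print(all(v == 0.0 for v in bessel_basis(5.0)))  # True: the envelope vanishes at the cutoff
```

The envelope is what lets interactions switch off smoothly at `cutoff`, so that atoms crossing the cutoff radius do not cause jumps in the learned energy.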
- class SiLU
Bases: torch.nn.Module
SiLU Activation Function.
- forward(input)
Forward pass: applies the SiLU activation elementwise to input.
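SiLU (also called swish) is x · sigmoid(x). A scalar version is sketched below for illustration; the hypothetical `silu` helper splits on the sign of x so that `math.exp` never overflows. PyTorch itself also ships this activation as `torch.nn.SiLU`.

```python
import math


def silu(x: float) -> float:
    """SiLU / swish: x * sigmoid(x), written to avoid overflow in exp."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)  # x < 0, so e is in (0, 1) and cannot overflow
    return x * e / (1.0 + e)


print(silu(0.0))  # 0.0
print([round(silu(v), 4) for v in (-2.0, 1.0, 3.0)])
```

Unlike ReLU, SiLU is smooth everywhere, which matters when forces are obtained by differentiating the network output with respect to positions.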
- class Envelope(exponent)
Bases: torch.nn.Module
Envelope.
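Assuming Envelope implements the smooth polynomial cutoff introduced with DimeNet (whose basis functions MXMNet builds on), it can be written as below, with c the cutoff radius and p an integer power derived from exponent. Whether p equals exponent or exponent + 1 in this class is not stated in this reference, and some implementations fold the 1/d factor of the radial basis into the envelope, i.e. compute u(x)/x instead.

```latex
u(x) = 1 - \frac{(p+1)(p+2)}{2}\,x^{p} + p(p+2)\,x^{p+1} - \frac{p(p+1)}{2}\,x^{p+2},
\qquad x = \frac{d}{c}, \quad 0 \le x < 1,
\qquad u(x) = 0 \ \text{for}\ x \ge 1.
```

The coefficients are chosen so that u(1) = u'(1) = u''(1) = 0: substituting x = 1 gives 1 - (p+1)(p+2)/2 + p(p+2) - p(p+1)/2 = 0, so the basis and its first two derivatives go to zero at the cutoff.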
- class SphericalBasisLayer(num_spherical, num_radial, cutoff=5.0, envelope_exponent=5)
Bases: torch.nn.Module
Spherical Basis Layer.
- __init__(num_spherical, num_radial, cutoff=5.0, envelope_exponent=5)
Initialize spherical basis layer.
Initialize spherical basis layer.
- Parameters
num_spherical – number of spherical harmonics.
num_radial – number of radial functions.
cutoff – cutoff radius.
envelope_exponent – envelope exponent.
- forward(dist, angle, idx_kj)
Forward pass.
- Parameters
dist – interatomic distances, one per edge.
angle – angles, one per triplet.
idx_kj – index of the edge k→j for each triplet.
- Returns
Spherical basis.
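Assuming this layer follows DimeNet's 2D spherical Fourier-Bessel basis, each triplet k→j→i receives the values below, where j_l is the spherical Bessel function of order l, z_{ln} is the n-th positive root of j_l, Y_l^0 is the spherical harmonic of degree l with m = 0, c is the cutoff, d_{kj} is the length of edge k→j and α_{kj,ji} is the angle at j.

```latex
\tilde{a}_{ln}(d_{kj}, \alpha_{kj,ji}) =
\sqrt{\frac{2}{c^{3}\, j_{l+1}^{2}(z_{ln})}}\;
j_{l}\!\left(\frac{z_{ln}}{c}\, d_{kj}\right) Y_{l}^{0}(\alpha_{kj,ji}),
\qquad l = 0, \dots, \mathrm{num\_spherical} - 1, \quad n = 1, \dots, \mathrm{num\_radial}.
```

Under this reading, the radial factor is evaluated on the distance of edge k→j, which is what idx_kj selects from the per-edge dist, and the angular factor on the per-triplet angle; DimeNet additionally multiplies by the envelope u(d_{kj}/c). Each triplet thus gets num_spherical × num_radial features.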
- class MessagePassing(aggr='add', flow='target_to_source', node_dim=0)
Bases: torch.nn.Module
Message Passing Layer.
\[\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),\]
where \(\square\) denotes a differentiable, permutation-invariant function, e.g., sum, mean or max, and \(\gamma_{\mathbf{\Theta}}\) and \(\phi_{\mathbf{\Theta}}\) denote differentiable functions such as MLPs.
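The update rule above can be traced with a toy, pure-Python round of message passing in which `message` plays the role of φ, the reduction plays □ and `update` plays γ. It is a hypothetical illustration on scalar features with an explicit list of (sender, receiver) pairs; the class itself works on tensors, aggregates with scatter operations and interprets edge_index according to `flow`. Giving nodes without incoming messages the value 0.0 is an assumption that mirrors what scatter-based aggregation typically does.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Reducer = Callable[[List[float]], float]

REDUCERS: Dict[str, Reducer] = {
    "add": sum,
    "mean": lambda msgs: sum(msgs) / len(msgs),
    "max": max,
}


def propagate(
    x: Sequence[float],
    edges: Sequence[Tuple[int, int]],
    aggr: str = "add",
    message: Callable[[float, float], float] = lambda x_i, x_j: x_j,
    update: Callable[[float, float], float] = lambda x_i, agg: agg,
) -> List[float]:
    """One round of message passing on scalar node features.

    edges are (j, i) pairs: node j sends a message to node i.
    """
    reduce = REDUCERS[aggr]
    inbox: Dict[int, List[float]] = {}
    for j, i in edges:
        # phi: build the message for edge (j, i)
        inbox.setdefault(i, []).append(message(x[i], x[j]))
    out = []
    for i, x_i in enumerate(x):
        msgs = inbox.get(i, [])
        # square: permutation-invariant reduction; nodes without messages get 0.0
        agg = reduce(msgs) if msgs else 0.0
        out.append(update(x_i, agg))  # gamma: combine with the node's own feature
    return out


x = [1.0, 2.0, 3.0]
edges = [(0, 2), (1, 2), (2, 0)]
print(propagate(x, edges, "add"))  # [3.0, 0.0, 3.0]
print(propagate(x, edges, "mean"))  # [3.0, 0.0, 1.5]
print(propagate(x, edges, "max"))  # [3.0, 0.0, 2.0]
```

Because the reduction ignores the order in which messages arrive, the result does not depend on how the edges are listed, which is the permutation invariance the formula requires of □.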
- __init__(aggr='add', flow='target_to_source', node_dim=0)
Initialize message passing layer.
- Parameters
aggr – the aggregation scheme to use ("add", "mean" or "max").
flow – the flow direction of message passing ("source_to_target" or "target_to_source").
node_dim – the axis along which to propagate.
- propagate(edge_index, size=None, **kwargs)
The initial call to start propagating messages.
- Parameters
edge_index (Tensor) – the indices of a general (sparse) assignment matrix with shape [N, M] (can be directed or undirected).
size (Optional[List[Tuple]]) – the size [N, M] of the assignment matrix. If set to None, the size will be automatically inferred and assumed to be quadratic.
**kwargs – any additional data which is needed to construct and aggregate messages, and to update node embeddings.
- message(x_j)
Constructs messages to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge \((j,i) \in \mathcal{E}\) if flow="source_to_target" and \((i,j) \in \mathcal{E}\) if flow="target_to_source". Can take any argument which was initially passed to propagate(). In addition, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- aggregate(inputs, index, dim_size)
Aggregates messages from neighbors as \(\square_{j \in \mathcal{N}(i)}\). By default, delegates the call to scatter functions that support the "add", "mean" and "max" operations, as specified in __init__() by the aggr argument.
- update(inputs)
Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().
- rdkit_conformation(mol, n=5, addHs=False)
A function that finds the lowest-energy conformation of a molecule.
- Parameters
mol – RDKit molecule object.
n – number of conformations to find.
addHs – whether to add hydrogens to the molecule.
- Returns
RDKit molecule object with the lowest-energy conformation, or None if no conformation is found.
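With RDKit, conformers are commonly generated with `AllChem.EmbedMultipleConfs(mol, numConfs=n)` and relaxed with a force field such as `AllChem.MMFFOptimizeMoleculeConfs(mol)`, which returns one `(not_converged, energy)` tuple per conformer. Whether `rdkit_conformation` uses MMFF or another force field, and whether it discards non-converged conformers, is not stated here. The selection step below is a hypothetical sketch of one plausible rule, kept RDKit-free so it runs on its own.

```python
from typing import Optional, Sequence, Tuple


def pick_lowest_energy(
    results: Sequence[Tuple[int, float]], conf_ids: Sequence[int]
) -> Optional[int]:
    """Return the id of the lowest-energy converged conformer, or None.

    results[k] is (not_converged, energy) for conformer conf_ids[k], the tuple
    format returned by RDKit's AllChem.MMFFOptimizeMoleculeConfs.
    """
    candidates = [
        (energy, cid)
        for (not_converged, energy), cid in zip(results, conf_ids)
        if not_converged == 0
    ]
    if not candidates:
        return None  # nothing usable: mirrors the "no conformation found" case
    return min(candidates)[1]


print(pick_lowest_energy([(0, 12.5), (1, 3.0), (0, 9.1)], [0, 1, 2]))  # 2
print(pick_lowest_energy([], []))  # None
```

In the first call the conformer with energy 3.0 is skipped because its optimisation did not converge, so the lowest converged energy (9.1, conformer 2) wins.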
- mol2graph(mol)
Converts an RDKit molecule to a graph.
- Parameters
mol – RDKit molecule.
- Returns
A graph with node features and edge features.
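One core step of such a conversion is turning bonds into a directed edge list. The sketch assumes each bond contributes both directions, a common convention for molecular graphs; with RDKit the atom pairs would come from `bond.GetBeginAtomIdx()` and `bond.GetEndAtomIdx()` for each `bond` in `mol.GetBonds()`. The helper is hypothetical, and it does not show the node and edge features that `mol2graph` actually produces.

```python
from typing import List, Sequence, Tuple


def bonds_to_edge_index(
    bonds: Sequence[Tuple[int, int]]
) -> Tuple[List[int], List[int]]:
    """Turn undirected bonds into a directed edge list with both directions."""
    src: List[int] = []
    dst: List[int] = []
    for a, b in bonds:
        src.extend([a, b])  # a -> b, then b -> a
        dst.extend([b, a])
    return src, dst


# Bonds of a three-atom chain, e.g. the heavy atoms of propane: C0-C1, C1-C2.
print(bonds_to_edge_index([(0, 1), (1, 2)]))  # ([0, 1, 1, 2], [1, 0, 2, 1])
```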
- class Global_MP(config)
Bases: MessagePassing
Global message passing.
- __init__(config)
Initializes the global message passing.
- Parameters
config – configuration.
- forward(h, edge_attr, edge_index)
Forward pass.
- message(x_i, x_j, edge_attr, edge_index, num_nodes)
Overrides MessagePassing.message(): constructs the messages to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\), taking the node features x_i and x_j, the edge attributes edge_attr, the edge_index and num_nodes.
- update(aggr_out)
Overrides MessagePassing.update(): updates the node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) from the aggregated messages aggr_out.
- class Local_MP(config)
Bases: torch.nn.Module
Local message passing.
- forward(h, rbf, sbf1, sbf2, idx_kj, idx_ji_1, idx_jj, idx_ji_2, edge_index, num_nodes=None)
Forward pass.