#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import logging
from dataclasses import dataclass
from typing import Callable, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils import BaseOutput
from torch import Tensor, nn
from torch.nn import Embedding, Linear, Module, ModuleList, Sequential
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_sparse import SparseTensor
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


@dataclass
class MoleculeGNNOutput(BaseOutput):
"""Hidden states output. Output of last layer of model."""
sample: torch.FloatTensor


class MultiLayerPerceptron(nn.Module):
    """Multi-layer perceptron. Note that there is no activation or dropout in the last layer."""

    def __init__(
self,
input_dim: int,
hidden_dims: List[int],
activation: str = "relu",
dropout: float = 0,
) -> None:
"""Initialize multi-layer perceptron.
Args:
input_dim: input dimension
hidden_dim: hidden dimensions
activation: activation function
dropout: dropout rate
"""
super(MultiLayerPerceptron, self).__init__()
self.dims = [input_dim] + hidden_dims
if isinstance(activation, str):
self.activation = getattr(F, activation)
else:
            logger.warning(
                f"Activation {activation} is not a string and will be ignored."
            )
            self.activation = None
if dropout > 0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None # type: ignore
self.layers = nn.ModuleList()
for i in range(len(self.dims) - 1):
self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x: input tensor of shape (batch_size, input_dim)
Returns:
output mlp.
"""
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
if self.activation:
x = self.activation(x)
if self.dropout:
x = self.dropout(x)
return x
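

# Illustrative usage sketch (an assumption, not part of the module): a 2-layer MLP
# mapping 16-dim inputs to 32-dim outputs with ReLU and dropout.
#
#   mlp = MultiLayerPerceptron(16, [64, 32], activation="relu", dropout=0.1)
#   out = mlp(torch.randn(8, 16))  # -> shape (8, 32); last layer has no activation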


class ShiftedSoftplus(torch.nn.Module):
"""Shifted softplus activation function."""

    def __init__(self) -> None:
        super().__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.softplus(x) - self.shift
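

# Note: softplus(0) = log(2), so the shifted activation passes through the origin,
# e.g. ShiftedSoftplus()(torch.zeros(1)) -> tensor([0.]).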


class CFConv(MessagePassing):
    """Continuous-filter convolution (CFConv) layer from SchNet."""

    def __init__(
self,
in_channels: int,
out_channels: int,
num_filters: int,
mlp: Callable,
cutoff: float,
smooth: bool,
) -> None:
"""Construct a CFConv layer.
Args:
in_channels: size of each input.
out_channels: size of each output.
num_filters: number of filters.
mlp: mlp hidden dimensions.
cutoff: cutoff distance.
smooth: whether to use smooth cutoff.
"""
super(CFConv, self).__init__(aggr="add")
self.lin1 = Linear(in_channels, num_filters, bias=False)
self.lin2 = Linear(num_filters, out_channels)
self.nn = mlp
self.cutoff = cutoff
self.smooth = smooth
self.reset_parameters()

    def reset_parameters(self) -> None:
"""Initialize parameters."""
torch.nn.init.xavier_uniform_(self.lin1.weight)
torch.nn.init.xavier_uniform_(self.lin2.weight)
self.lin2.bias.data.fill_(0)

    def forward(
self, x: torch.Tensor, edge_index, edge_length, edge_attr
) -> torch.Tensor:
"""Forward pass.
Args:
x: input tensor.
edge_index: edge indices.
edge_length: edge lengths.
edge_attr: edge attributes.
Returns:
output tensor.
"""
if self.smooth:
C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)
C = (
C * (edge_length <= self.cutoff) * (edge_length >= 0.0)
) # Modification: cutoff
else:
C = (edge_length <= self.cutoff).float()
W = self.nn(edge_attr) * C.view(-1, 1)
x = self.lin1(x)
x = self.propagate(edge_index, x=x, W=W)
x = self.lin2(x)
return x

    def message(self, x_j: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
        """Scale neighbor features x_j by the per-edge filter weights W."""
        return x_j * W
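

# In effect, CFConv computes x_i' = lin2(sum_j lin1(x_j) * W(e_ij)), where the
# per-edge filter W(e_ij) is the edge-MLP output masked by the (optionally smooth)
# cutoff weight C(d_ij).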


class InteractionBlock(torch.nn.Module):
    """Interaction block."""

    def __init__(
self,
hidden_channels: int,
num_gaussians: int,
num_filters: int,
cutoff: float,
smooth: bool,
) -> None:
"""Construct an interaction block.
Args:
hidden_channels: number of hidden channels.
num_gaussians: number of gaussians.
num_filters: number of filters.
cutoff: cutoff distance.
smooth: whether to use smooth cutoff.
"""
        super().__init__()
mlp = Sequential(
Linear(num_gaussians, num_filters),
ShiftedSoftplus(),
Linear(num_filters, num_filters),
)
self.conv = CFConv(
hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth
)
self.act = ShiftedSoftplus()
self.lin = Linear(hidden_channels, hidden_channels)

    def forward(
self, x: torch.Tensor, edge_index, edge_length, edge_attr
) -> torch.Tensor:
"""Forward pass.
Args:
x: input tensor.
edge_index: edge indices.
edge_length: edge lengths.
edge_attr: edge attributes.
Returns:
output tensor.
"""
x = self.conv(x, edge_index, edge_length, edge_attr)
x = self.act(x)
x = self.lin(x)
return x
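

# An InteractionBlock is CFConv -> ShiftedSoftplus -> Linear; SchNetEncoder below
# stacks num_interactions of these blocks with residual updates (h = h + block(...)).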


class SchNetEncoder(Module):
    """SchNet encoder."""

    def __init__(
self,
hidden_channels: int = 128,
num_filters: int = 128,
num_interactions: int = 6,
edge_channels: int = 100,
cutoff: float = 10.0,
smooth: bool = False,
) -> None:
"""Construct a SchNet encoder.
Args:
hidden_channels: number of hidden channels.
num_filters: number of filters.
num_interactions: number of interactions.
edge_channels: number of edge channels.
cutoff: cutoff distance.
smooth: whether to use smooth cutoff.
"""
super().__init__()
self.hidden_channels = hidden_channels
self.num_filters = num_filters
self.num_interactions = num_interactions
self.cutoff = cutoff
self.embedding = Embedding(100, hidden_channels, max_norm=10.0)
self.interactions = ModuleList()
for _ in range(num_interactions):
block = InteractionBlock(
hidden_channels, edge_channels, num_filters, cutoff, smooth
)
self.interactions.append(block)

    def forward(
self,
z: torch.Tensor,
edge_index: torch.Tensor,
edge_length: torch.Tensor,
edge_attr: torch.Tensor,
embed_node: bool = True,
) -> torch.Tensor:
"""Forward pass.
Args:
z: input tensor.
edge_index: edge indices.
edge_length: edge lengths.
edge_attr: edge attributes.
embed_node: whether to embed node.
Returns:
output tensor.
"""
if embed_node:
assert z.dim() == 1 and z.dtype == torch.long
h = self.embedding(z)
else:
h = z
for interaction in self.interactions:
h = h + interaction(h, edge_index, edge_length, edge_attr)
return h
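

# Illustrative usage sketch (assumed shapes, not part of the module):
#
#   encoder = SchNetEncoder(hidden_channels=128, edge_channels=100)
#   # z: atom types, long tensor of shape (num_nodes,); edge_index: (2, num_edges);
#   # edge_length: (num_edges, 1); edge_attr: (num_edges, 100), e.g. Gaussian-expanded distances
#   h = encoder(z, edge_index, edge_length, edge_attr)  # -> (num_nodes, 128)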


class GINEConv(MessagePassing):
    """GINEConv layer.

    Custom variant of the graph isomorphism operator from the paper "How Powerful
    are Graph Neural Networks?" (https://arxiv.org/abs/1810.00826). Note that this
    implementation adds the option of a custom activation.
    """

    def __init__(
self,
mlp: Callable,
eps: float = 0.0,
train_eps: bool = False,
activation: str = "softplus",
**kwargs,
) -> None:
"""Construct a GINEConv layer.
Args:
mlp: MLP.
eps: epsilon.
train_eps: whether to train epsilon.
activation: activation function.
"""
super(GINEConv, self).__init__(aggr="add", **kwargs)
self.nn = mlp
self.initial_eps = eps
if isinstance(activation, str):
self.activation = getattr(F, activation)
else:
self.activation = None
        if train_eps:
            self.eps = torch.nn.Parameter(torch.tensor([eps]))
        else:
            self.register_buffer("eps", torch.tensor([eps]))

    def forward(
self,
x: Union[Tensor, OptPairTensor],
edge_index: Adj,
edge_attr: OptTensor = None,
size: Size = None,
) -> torch.Tensor:
"""Forward pass.
Args:
x: input tensor.
edge_index: edge indices.
edge_attr: edge attributes.
size: size.
Returns:
output tensor.
"""
if isinstance(x, torch.Tensor):
x = (x, x)
        # Node and edge feature dimensionalities need to match.
if isinstance(edge_index, torch.Tensor):
assert edge_attr is not None
assert x[0].size(-1) == edge_attr.size(-1)
elif isinstance(edge_index, SparseTensor):
assert x[0].size(-1) == edge_index.size(-1)
# propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
x_r = x[1]
if x_r is not None:
out += (1 + self.eps) * x_r
return self.nn(out)

    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
"""Message function.
Args:
x_j: input tensor.
edge_attr: edge attributes.
Returns:
message passing aggregation.
"""
if self.activation:
return self.activation(x_j + edge_attr)
else:
return x_j + edge_attr

    def __repr__(self):
        return "{}(nn={})".format(self.__class__.__name__, self.nn)


class GINEncoder(torch.nn.Module):
    """GIN encoder."""

    def __init__(
self,
hidden_dim: int,
num_convs: int = 3,
activation: str = "relu",
short_cut: bool = True,
concat_hidden: bool = False,
) -> None:
"""Construct a GIN encoder.
Args:
hidden_dim: number of hidden channels.
num_convs: number of convolutions.
activation: activation function.
short_cut: whether to use short cut.
concat_hidden: whether to concatenate hidden.
"""
super().__init__()
self.hidden_dim = hidden_dim
self.num_convs = num_convs
self.short_cut = short_cut
self.concat_hidden = concat_hidden
self.node_emb = nn.Embedding(100, hidden_dim)
if isinstance(activation, str):
self.activation = getattr(F, activation)
else:
self.activation = None
self.convs = nn.ModuleList()
for i in range(self.num_convs):
self.convs.append(
GINEConv(
MultiLayerPerceptron(
hidden_dim, [hidden_dim, hidden_dim], activation=activation
),
activation=activation,
)
)

    def forward(
self, z: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor
) -> torch.Tensor:
"""Forward pass.
args:
z: input tensor.
edge_index: edge indices.
edge_attr: edge attributes.
returns:
graph with node feature.
"""
node_attr = self.node_emb(z) # (num_node, hidden)
hiddens = []
conv_input = node_attr # (num_node, hidden)
for conv_idx, conv in enumerate(self.convs):
hidden = conv(conv_input, edge_index, edge_attr)
if conv_idx < len(self.convs) - 1 and self.activation is not None:
hidden = self.activation(hidden)
assert hidden.shape == conv_input.shape
if self.short_cut and hidden.shape == conv_input.shape:
hidden += conv_input
hiddens.append(hidden)
conv_input = hidden
if self.concat_hidden:
node_feature = torch.cat(hiddens, dim=-1)
else:
node_feature = hiddens[-1]
return node_feature
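

# Illustrative usage sketch (assumed shapes, not part of the module):
#
#   encoder = GINEncoder(hidden_dim=128)
#   # z: atom types, long tensor of shape (num_nodes,); edge_attr: (num_edges, 128)
#   node_feature = encoder(z, edge_index, edge_attr)  # -> (num_nodes, 128)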


class MLPEdgeEncoder(Module):
    """MLP edge encoder."""

    def __init__(self, hidden_dim: int = 100, activation: str = "relu") -> None:
        """Construct an MLP edge encoder.

        Args:
            hidden_dim: number of hidden channels.
            activation: activation function.
        """
super().__init__()
self.hidden_dim = hidden_dim
self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)
self.mlp = MultiLayerPerceptron(
1, [self.hidden_dim, self.hidden_dim], activation=activation
)

    @property
    def out_channels(self) -> int:
        return self.hidden_dim

    def forward(
self, edge_length: torch.Tensor, edge_type: torch.Tensor
) -> torch.Tensor:
"""
Args:
edge_length: The length of edges.
edge_type: The type of edges.
Returns:
the output representation of edges.
"""
        d_emb = self.mlp(edge_length)  # (num_edges, hidden_dim)
        edge_attr = self.bond_emb(edge_type)  # (num_edges, hidden_dim)
        return d_emb * edge_attr  # (num_edges, hidden_dim)
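

# Illustrative usage sketch (assumed shapes, not part of the module):
#
#   enc = MLPEdgeEncoder(hidden_dim=100)
#   # edge_length: float tensor of shape (num_edges, 1); edge_type: long tensor (num_edges,)
#   edge_attr = enc(edge_length, edge_type)  # -> (num_edges, 100)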