Source code for gt4sd.algorithms.generation.diffusion.geodiff.model.layers

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import logging
from dataclasses import dataclass
from typing import Callable, List, Union

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils import BaseOutput
from torch import Tensor, nn
from torch.nn import Embedding, Linear, Module, ModuleList, Sequential
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_sparse import SparseTensor

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


@dataclass
class MoleculeGNNOutput(BaseOutput):
    """Hidden states output. Output of last layer of model."""

    sample: torch.FloatTensor

class MultiLayerPerceptron(nn.Module):
    """Multi-layer Perceptron. Note there is no activation or dropout in the last layer."""

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        activation: str = "relu",
        dropout: float = 0,
    ) -> None:
        """Initialize the multi-layer perceptron.

        Args:
            input_dim: input dimension.
            hidden_dims: hidden dimensions.
            activation: activation function.
            dropout: dropout rate.
        """
        super(MultiLayerPerceptron, self).__init__()

        self.dims = [input_dim] + hidden_dims
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            logger.info(
                f"Activation {activation} is not a string and is ignored; no activation is applied"
            )
            self.activation = None
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None  # type: ignore

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: input tensor of shape (batch_size, input_dim).

        Returns:
            output of the MLP.
        """
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                if self.activation:
                    x = self.activation(x)
                if self.dropout:
                    x = self.dropout(x)
        return x

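# Illustrative usage sketch (not part of the original module; the sizes below are
# arbitrary assumptions). With hidden_dims=[32, 8] two linear layers are built,
# and activation/dropout are applied after every layer except the last.
example_mlp = MultiLayerPerceptron(
    input_dim=16, hidden_dims=[32, 8], activation="relu", dropout=0.1
)
example_mlp_out = example_mlp(torch.randn(4, 16))  # shape: (4, 8)
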
class ShiftedSoftplus(torch.nn.Module):
    """Shifted softplus activation function."""

    def __init__(self) -> None:
        """Initialize the shifted softplus, using log(2) as the shift."""
        super(ShiftedSoftplus, self).__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the shifted softplus: softplus(x) - log(2)."""
        return F.softplus(x) - self.shift

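# Illustrative sketch (not part of the original module). The shift equals log(2),
# so the activation passes through the origin: softplus(0) = log(2), hence
# ShiftedSoftplus()(0) is 0 up to floating-point precision.
example_act = ShiftedSoftplus()
example_act_out = example_act(torch.zeros(1))  # approximately tensor([0.])
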
class CFConv(MessagePassing):
    """CFConv layer."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_filters: int,
        mlp: Callable,
        cutoff: float,
        smooth: bool,
    ) -> None:
        """Construct a CFConv layer.

        Args:
            in_channels: size of each input.
            out_channels: size of each output.
            num_filters: number of filters.
            mlp: filter-generating network mapping edge attributes to filters.
            cutoff: cutoff distance.
            smooth: whether to use smooth cutoff.
        """
        super(CFConv, self).__init__(aggr="add")
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = mlp
        self.cutoff = cutoff
        self.smooth = smooth

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters."""
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(
        self, x: torch.Tensor, edge_index, edge_length, edge_attr
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: input tensor.
            edge_index: edge indices.
            edge_length: edge lengths.
            edge_attr: edge attributes.

        Returns:
            output tensor.
        """
        if self.smooth:
            C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)
            C = (
                C * (edge_length <= self.cutoff) * (edge_length >= 0.0)
            )  # Modification: cutoff
        else:
            C = (edge_length <= self.cutoff).float()
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j: torch.Tensor, W) -> torch.Tensor:
        """Compute messages as filtered neighbor features."""
        return x_j * W

class InteractionBlock(torch.nn.Module):
    """Interaction block."""

    def __init__(
        self,
        hidden_channels: int,
        num_gaussians: int,
        num_filters: int,
        cutoff: float,
        smooth: bool,
    ) -> None:
        """Construct an interaction block.

        Args:
            hidden_channels: number of hidden channels.
            num_gaussians: number of gaussians.
            num_filters: number of filters.
            cutoff: cutoff distance.
            smooth: whether to use smooth cutoff.
        """
        super(InteractionBlock, self).__init__()
        mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(
            hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth
        )
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

    def forward(
        self, x: torch.Tensor, edge_index, edge_length, edge_attr
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: input tensor.
            edge_index: edge indices.
            edge_length: edge lengths.
            edge_attr: edge attributes.

        Returns:
            output tensor.
        """
        x = self.conv(x, edge_index, edge_length, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x

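# Illustrative usage sketch (not part of the original module; graph sizes and
# feature dimensions are assumptions). edge_attr is expected to carry
# num_gaussians features per edge (e.g. a Gaussian expansion of edge lengths),
# and the block returns updated node features of size hidden_channels.
example_block = InteractionBlock(
    hidden_channels=8, num_gaussians=4, num_filters=16, cutoff=10.0, smooth=True
)
example_x = torch.randn(5, 8)  # 5 nodes with 8 features each
example_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 edges
example_edge_length = torch.rand(3, 1) * 10.0  # distances within the cutoff
example_edge_attr = torch.randn(3, 4)  # num_gaussians features per edge
example_block_out = example_block(
    example_x, example_edge_index, example_edge_length, example_edge_attr
)  # shape: (5, 8)
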
class SchNetEncoder(Module):
    """SchNet encoder."""

    def __init__(
        self,
        hidden_channels: int = 128,
        num_filters: int = 128,
        num_interactions: int = 6,
        edge_channels: int = 100,
        cutoff: float = 10.0,
        smooth: bool = False,
    ) -> None:
        """Construct a SchNet encoder.

        Args:
            hidden_channels: number of hidden channels.
            num_filters: number of filters.
            num_interactions: number of interactions.
            edge_channels: number of edge channels.
            cutoff: cutoff distance.
            smooth: whether to use smooth cutoff.
        """
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.cutoff = cutoff

        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(
                hidden_channels, edge_channels, num_filters, cutoff, smooth
            )
            self.interactions.append(block)

    def forward(
        self,
        z: torch.Tensor,
        edge_index: torch.Tensor,
        edge_length: torch.Tensor,
        edge_attr: torch.Tensor,
        embed_node: bool = True,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            z: input tensor.
            edge_index: edge indices.
            edge_length: edge lengths.
            edge_attr: edge attributes.
            embed_node: whether to embed node.

        Returns:
            output tensor.
        """
        if embed_node:
            assert z.dim() == 1 and z.dtype == torch.long
            h = self.embedding(z)
        else:
            h = z
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_length, edge_attr)
        return h

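# Illustrative usage sketch (not part of the original module; sizes are
# assumptions). With embed_node=True, z holds integer atom types that are
# embedded before the interaction blocks; edge_attr must provide edge_channels
# features per edge.
example_schnet = SchNetEncoder(
    hidden_channels=16, num_filters=16, num_interactions=2, edge_channels=4
)
example_h = example_schnet(
    torch.randint(0, 100, (5,)),  # atom types for 5 nodes
    torch.tensor([[0, 1, 2], [1, 2, 3]]),  # edge_index for 3 edges
    torch.rand(3, 1) * 10.0,  # edge lengths
    torch.randn(3, 4),  # edge_channels features per edge
)  # shape: (5, 16)
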
class GINEConv(MessagePassing):
    """GINEConv layer.

    Custom class of the graph isomorphism operator from the "How Powerful are
    Graph Neural Networks?" paper (https://arxiv.org/abs/1810.00826). Note that
    this implementation has the added option of a custom activation.
    """

    def __init__(
        self,
        mlp: Callable,
        eps: float = 0.0,
        train_eps: bool = False,
        activation: str = "softplus",
        **kwargs,
    ) -> None:
        """Construct a GINEConv layer.

        Args:
            mlp: MLP.
            eps: epsilon.
            train_eps: whether to train epsilon.
            activation: activation function.
        """
        super(GINEConv, self).__init__(aggr="add", **kwargs)
        self.nn = mlp
        self.initial_eps = eps

        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = None

        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer("eps", torch.Tensor([eps]))

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: input tensor.
            edge_index: edge indices.
            edge_attr: edge attributes.
            size: size.

        Returns:
            output tensor.
        """
        if isinstance(x, torch.Tensor):
            x = (x, x)

        # Node and edge feature dimensionalities need to match.
        if isinstance(edge_index, torch.Tensor):
            assert edge_attr is not None
            assert x[0].size(-1) == edge_attr.size(-1)
        elif isinstance(edge_index, SparseTensor):
            assert x[0].size(-1) == edge_index.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        x_r = x[1]
        if x_r is not None:
            out += (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Message function.

        Args:
            x_j: input tensor.
            edge_attr: edge attributes.

        Returns:
            message passing aggregation.
        """
        if self.activation:
            return self.activation(x_j + edge_attr)
        else:
            return x_j + edge_attr

    def __repr__(self):
        return "{}(nn={})".format(self.__class__.__name__, self.nn)

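# Illustrative usage sketch (not part of the original module; sizes are
# assumptions). Node and edge features must share the same dimensionality,
# since the message is activation(x_j + edge_attr).
example_gine_conv = GINEConv(
    MultiLayerPerceptron(8, [16, 8], activation="softplus"), activation="softplus"
)
example_gine_out = example_gine_conv(
    torch.randn(5, 8),  # node features
    torch.tensor([[0, 1, 2], [1, 2, 3]]),  # edge_index
    torch.randn(3, 8),  # edge features with the same size as node features
)  # shape: (5, 8)
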
class GINEncoder(torch.nn.Module):
    """GIN encoder."""

    def __init__(
        self,
        hidden_dim: int,
        num_convs: int = 3,
        activation: str = "relu",
        short_cut: bool = True,
        concat_hidden: bool = False,
    ) -> None:
        """Construct a GIN encoder.

        Args:
            hidden_dim: number of hidden channels.
            num_convs: number of convolutions.
            activation: activation function.
            short_cut: whether to use shortcut connections.
            concat_hidden: whether to concatenate hidden states.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_convs = num_convs
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.node_emb = nn.Embedding(100, hidden_dim)

        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = None

        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            self.convs.append(
                GINEConv(
                    MultiLayerPerceptron(
                        hidden_dim, [hidden_dim, hidden_dim], activation=activation
                    ),
                    activation=activation,
                )
            )

    def forward(
        self, z: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            z: input tensor.
            edge_index: edge indices.
            edge_attr: edge attributes.

        Returns:
            node features.
        """
        node_attr = self.node_emb(z)  # (num_node, hidden)

        hiddens = []
        conv_input = node_attr  # (num_node, hidden)

        for conv_idx, conv in enumerate(self.convs):
            hidden = conv(conv_input, edge_index, edge_attr)
            if conv_idx < len(self.convs) - 1 and self.activation is not None:
                hidden = self.activation(hidden)
            assert hidden.shape == conv_input.shape
            if self.short_cut and hidden.shape == conv_input.shape:
                hidden += conv_input

            hiddens.append(hidden)
            conv_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]

        return node_feature

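# Illustrative usage sketch (not part of the original module; sizes are
# assumptions). z holds integer atom types, and edge_attr must already have
# hidden_dim features per edge (e.g. produced by MLPEdgeEncoder below).
example_gin_encoder = GINEncoder(hidden_dim=16, num_convs=2)
example_node_feature = example_gin_encoder(
    torch.randint(0, 100, (5,)),  # atom types for 5 nodes
    torch.tensor([[0, 1, 2], [1, 2, 3]]),  # edge_index for 3 edges
    torch.randn(3, 16),  # hidden_dim features per edge
)  # shape: (5, 16)
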
class MLPEdgeEncoder(Module):
    """MLP edge encoder."""

    def __init__(self, hidden_dim: int = 100, activation: str = "relu") -> None:
        """Construct an MLP edge encoder.

        Args:
            hidden_dim: number of hidden channels.
            activation: activation function.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)
        self.mlp = MultiLayerPerceptron(
            1, [self.hidden_dim, self.hidden_dim], activation=activation
        )

    @property
    def out_channels(self):
        return self.hidden_dim

    def forward(
        self, edge_length: torch.Tensor, edge_type: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            edge_length: the length of edges.
            edge_type: the type of edges.

        Returns:
            the output representation of edges.
        """
        d_emb = self.mlp(edge_length)  # (num_edge, hidden_dim)
        edge_attr = self.bond_emb(edge_type)  # (num_edge, hidden_dim)
        return d_emb * edge_attr  # (num_edge, hidden)

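# Illustrative usage sketch (not part of the original module; sizes are
# assumptions). Edge lengths are expected with a trailing feature dimension of
# size 1; the output combines a distance embedding with a learned bond-type
# embedding.
example_edge_encoder = MLPEdgeEncoder(hidden_dim=16)
example_edge_repr = example_edge_encoder(
    torch.rand(3, 1),  # edge lengths, shape (num_edge, 1)
    torch.randint(0, 100, (3,)),  # integer bond types
)  # shape: (3, 16)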