Source code for gt4sd.frameworks.cgcnn.model

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Model module."""

from __future__ import division, print_function

from typing import Any, Dict

import torch
import torch.nn as nn


class ConvLayer(nn.Module):
    """Convolutional operation on graphs."""

    def __init__(self, atom_fea_len: int, nbr_fea_len: int):
        """Initialize ConvLayer.

        Args:
            atom_fea_len: int
                Number of atom hidden features.
            nbr_fea_len: int
                Number of bond features.
        """
        super(ConvLayer, self).__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        self.fc_full = nn.Linear(
            2 * self.atom_fea_len + self.nbr_fea_len, 2 * self.atom_fea_len
        )
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(2 * self.atom_fea_len)
        self.bn2 = nn.BatchNorm1d(self.atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(
        self,
        atom_in_fea: torch.Tensor,
        nbr_fea: torch.Tensor,
        nbr_fea_idx: torch.LongTensor,
    ) -> torch.Tensor:
        """Forward pass.

        N: Total number of atoms in the batch.
        M: Max number of neighbors.

        Args:
            atom_in_fea: torch.Tensor shape (N, atom_fea_len)
                Atom hidden features before convolution.
            nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
                Bond features of each atom's M neighbors.
            nbr_fea_idx: torch.LongTensor shape (N, M)
                Indices of M neighbors of each atom.

        Returns:
            atom_out_fea: torch.Tensor shape (N, atom_fea_len)
                Atom hidden features after convolution.
        """
        # TODO will there be problems with the index zero padding?
        N, M = nbr_fea_idx.shape
        # convolution: gather neighbor features and concatenate them with the
        # central atom's features and the bond features
        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
        total_nbr_fea = torch.cat(
            [
                atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
                atom_nbr_fea,
                nbr_fea,
            ],
            dim=2,
        )
        total_gated_fea = self.fc_full(total_nbr_fea)
        total_gated_fea = self.bn1(
            total_gated_fea.view(-1, self.atom_fea_len * 2)
        ).view(N, M, self.atom_fea_len * 2)
        # gated aggregation over neighbors, followed by a residual connection
        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
        nbr_filter = self.sigmoid(nbr_filter)
        nbr_core = self.softplus1(nbr_core)
        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
        nbr_sumed = self.bn2(nbr_sumed)
        out = self.softplus2(atom_in_fea + nbr_sumed)
        return out
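

# Illustrative usage sketch (not part of the original module): the shapes
# below are arbitrary assumptions chosen to exercise a single ConvLayer.
def _demo_conv_layer() -> None:
    layer = ConvLayer(atom_fea_len=8, nbr_fea_len=6)
    atom_in_fea = torch.randn(10, 8)  # N = 10 atoms in the batch
    nbr_fea = torch.randn(10, 4, 6)  # M = 4 neighbors per atom
    nbr_fea_idx = torch.randint(0, 10, (10, 4))  # neighbor indices, int64
    out = layer(atom_in_fea, nbr_fea, nbr_fea_idx)  # gated graph convolution
    assert out.shape == (10, 8)  # same shape as the input atom features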


class CrystalGraphConvNet(nn.Module):
    """Create a crystal graph convolutional neural network for predicting total material properties."""

    def __init__(
        self,
        orig_atom_fea_len: int,
        nbr_fea_len: int,
        atom_fea_len: int = 64,
        n_conv: int = 3,
        h_fea_len: int = 128,
        n_h: int = 1,
        classification: bool = False,
    ):
        """Initialize CrystalGraphConvNet.

        Args:
            orig_atom_fea_len: int
                Number of atom features in the input.
            nbr_fea_len: int
                Number of bond features.
            atom_fea_len: int
                Number of hidden atom features in the convolutional layers.
            n_conv: int
                Number of convolutional layers.
            h_fea_len: int
                Number of hidden features after pooling.
            n_h: int
                Number of hidden layers after pooling.
            classification: bool
                Whether the model performs classification (two output
                classes) instead of regression (a single output value).
        """
        super(CrystalGraphConvNet, self).__init__()
        self.classification = classification
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList(
            [
                ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len)
                for _ in range(n_conv)
            ]
        )
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        if n_h > 1:
            self.fcs = nn.ModuleList(
                [nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)]
            )
            self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])
        if self.classification:
            self.fc_out = nn.Linear(h_fea_len, 2)
        else:
            self.fc_out = nn.Linear(h_fea_len, 1)
        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim=1)
            self.dropout = nn.Dropout()

    def forward(
        self,
        atom_fea: torch.Tensor,
        nbr_fea: torch.Tensor,
        nbr_fea_idx: torch.LongTensor,
        crystal_atom_idx: torch.LongTensor,
    ) -> torch.Tensor:
        """Forward pass.

        N: Total number of atoms in the batch.
        M: Max number of neighbors.
        N0: Total number of crystals in the batch.

        Args:
            atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
                Atom features from atom type.
            nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
                Bond features of each atom's M neighbors.
            nbr_fea_idx: torch.LongTensor shape (N, M)
                Indices of M neighbors of each atom.
            crystal_atom_idx: list of torch.LongTensor of length N0
                Mapping from the crystal idx to atom idx.

        Returns:
            prediction: torch.Tensor shape (N0, 1) for regression, or
                (N0, 2) log-probabilities for classification.
        """
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        if hasattr(self, "fcs") and hasattr(self, "softpluses"):
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))
        out = self.fc_out(crys_fea)
        if self.classification:
            out = self.logsoftmax(out)
        return out

    def pooling(
        self, atom_fea: torch.Tensor, crystal_atom_idx: torch.LongTensor
    ) -> torch.Tensor:
        """Pool the atom features into crystal features.

        N: Total number of atoms in the batch.
        N0: Total number of crystals in the batch.

        Args:
            atom_fea: torch.Tensor shape (N, atom_fea_len)
                Atom feature vectors of the batch.
            crystal_atom_idx: list of torch.LongTensor of length N0
                Mapping from the crystal idx to atom idx.

        Returns:
            pooled features of shape (N0, atom_fea_len): the mean of the
                atom features belonging to each crystal.
        """
        assert (
            sum([len(idx_map) for idx_map in crystal_atom_idx])
            == atom_fea.data.shape[0]
        )
        # mean-pool the atom features belonging to each crystal
        summed_fea = [
            torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
            for idx_map in crystal_atom_idx
        ]
        return torch.cat(summed_fea, dim=0)
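

# Illustrative usage sketch (not part of the original module): all sizes are
# arbitrary assumptions; crystal_atom_idx partitions the 10 atoms of the
# batch into two crystals of 6 and 4 atoms.
def _demo_crystal_graph_conv_net() -> None:
    model = CrystalGraphConvNet(orig_atom_fea_len=92, nbr_fea_len=41)
    atom_fea = torch.randn(10, 92)  # N = 10 atoms in the batch
    nbr_fea = torch.randn(10, 12, 41)  # M = 12 neighbors per atom
    nbr_fea_idx = torch.randint(0, 10, (10, 12))  # neighbor indices, int64
    crystal_atom_idx = [torch.arange(0, 6), torch.arange(6, 10)]  # N0 = 2
    prediction = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
    assert prediction.shape == (2, 1)  # one regression value per crystal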


class Normalizer:
    """Normalize a Tensor and restore it later."""

    def __init__(self, tensor: torch.Tensor):
        """Initialize Normalizer.

        The tensor is taken as a sample to calculate the mean and std.
        """
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize a tensor.

        Args:
            tensor: tensor to be normalized.

        Returns:
            normalized tensor.
        """
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
        """Denormalize a tensor.

        Args:
            normed_tensor: tensor to be denormalized.

        Returns:
            denormalized tensor.
        """
        return normed_tensor * self.std + self.mean

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Return the state dict of the normalizer.

        Returns:
            dictionary including the used mean and std values.
        """
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load a state dict into the normalizer.

        Args:
            state_dict: dictionary with the mean and std values to be used
                for the normalization.
        """
        self.mean = state_dict["mean"]
        self.std = state_dict["std"]
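

# Illustrative usage sketch (not part of the original module): fit the
# normalizer on sample targets, normalize them for training, then restore
# predictions to the original scale.
def _demo_normalizer() -> None:
    sample_target = torch.randn(128, 1)  # hypothetical regression targets
    normalizer = Normalizer(sample_target)
    normed = normalizer.norm(sample_target)  # roughly zero mean, unit std
    restored = normalizer.denorm(normed)  # back to the original scale
    assert torch.allclose(restored, sample_target, atol=1e-4)
    state = normalizer.state_dict()  # {"mean": ..., "std": ...}
    normalizer.load_state_dict(state)  # e.g. when resuming from a checkpoint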