Source code for gt4sd.algorithms.generation.diffusion.geodiff.model.utils

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import copy
from copy import deepcopy

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import radius, radius_graph
from torch_geometric.utils import dense_to_sparse, to_dense_adj
from torch_scatter import scatter_add
from torch_sparse import coalesce


def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):
    """Build one feature vector per edge from its endpoint nodes and edge attributes.

    Args:
        node_attr: per-node features, indexed along the first dimension by the
            node ids stored in ``edge_index``.
        edge_index: edge endpoints; row 0 and row 1 hold the two node ids of
            each edge.
        edge_attr: per-edge features, concatenated along the last dimension.

    Returns:
        The element-wise product of the two endpoint features concatenated
        with ``edge_attr`` along the last dimension. The original comment gives
        the shape as (E, 2H), which presumes ``edge_attr`` has the same width
        as ``node_attr``.
    """
    source_feat = node_attr[edge_index[0]]
    target_feat = node_attr[edge_index[1]]
    # The product is symmetric in the two endpoints.
    endpoint_product = source_feat * target_feat
    return torch.cat([endpoint_product, edge_attr], dim=-1)
def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):
    """Add typed edges between atoms that are up to ``order`` bonds apart.

    Args:
        num_nodes: Number of atoms.
        edge_index: Bond indices of the original graph.
        edge_type: Bond types of the original graph.
        order: Extension order.

    Returns:
        new_edge_index: Extended edge indices.
        new_edge_type: Extended edge types. Original bonds keep their type;
            an added edge of graph distance ``k > 1`` gets type
            ``num_types + k - 1``.

    Raises:
        AssertionError: if an original bond coincides with a higher-order edge.
    """

    def binarize(x):
        # Map every strictly positive entry to 1 and everything else to 0.
        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))

    def get_higher_order_adj_matrix(adj, order):
        """Compute, for every pair of nodes, the hop count up to ``order``.

        Args:
            adj: dense adjacency matrix of the original graph.
            order: largest hop count to resolve.

        Returns:
            A matrix shaped like ``adj`` whose entry is ``k`` when two nodes
            first become reachable within ``k`` hops (1 <= k <= order), and 0
            on the diagonal and for pairs further apart than ``order``.
        """
        identity = torch.eye(adj.size(0), dtype=torch.long, device=adj.device)
        # adj_mats[k] marks the pairs reachable within k hops (self included).
        adj_mats = [identity, binarize(adj + identity)]

        for i in range(2, order + 1):
            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))

        order_mat = torch.zeros_like(adj)
        for i in range(1, order + 1):
            # Pairs that became reachable exactly at hop i.
            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i

        return order_mat

    num_types = 22
    # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
    # from rdkit.Chem.rdchem import BondType as BT
    N = num_nodes
    # NOTE(review): to_dense_adj is called without max_num_nodes, so the dense
    # matrices are presumably sized by the largest index in edge_index rather
    # than by N -- confirm against the torch_geometric documentation.
    adj = to_dense_adj(edge_index).squeeze(0)
    adj_order = get_higher_order_adj_matrix(adj, order)

    type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0)
    type_highorder = torch.where(
        adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order)
    )
    # An entry must be either an original bond or a higher-order edge, never both.
    assert (type_mat * type_highorder == 0).all()
    type_new = type_mat + type_highorder

    new_edge_index, new_edge_type = dense_to_sparse(type_new)
    new_edge_index, new_edge_type = coalesce(
        new_edge_index, new_edge_type.long(), N, N
    )

    return new_edge_index, new_edge_type


def _extend_to_radius_graph(
    pos,
    edge_index,
    edge_type,
    cutoff,
    batch,
    unspecified_type_number=0,
    is_sidechain=None,
):
    """Merge the typed input edges with edges between atoms closer than ``cutoff``.

    Args:
        pos: atom positions; the first dimension indexes atoms.
        edge_index: indices of the already typed edges.
        edge_type: 1-D tensor with the type of every edge in ``edge_index``.
        cutoff: radius used for the neighbour search.
        batch: per-atom graph assignment, forwarded to the neighbour search.
        unspecified_type_number: type value given to edges found by the
            neighbour search.
        is_sidechain: optional per-atom mask. When given, only pairs with at
            least one masked atom are searched, and each found pair is added in
            both directions.

    Returns:
        new_edge_index: indices of the merged, coalesced edge set.
        new_edge_type: types of the merged edges as a long tensor. Types of
            duplicate entries are summed by the coalescing step, so an input
            edge that is also a radius edge gets
            ``edge_type + unspecified_type_number``.

    Raises:
        AssertionError: if ``edge_type`` is not one-dimensional.
    """
    assert edge_type.dim() == 1
    N = pos.size(0)

    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.LongTensor
    # constructor; dtype=torch.long keeps the integer value type.
    bgraph_adj = torch.sparse_coo_tensor(
        edge_index, edge_type, (N, N), dtype=torch.long
    )

    if is_sidechain is None:
        rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)
    else:
        # fetch sidechain and its batch index
        is_sidechain = is_sidechain.bool()
        dummy_index = torch.arange(pos.size(0), device=pos.device)
        sidechain_pos = pos[is_sidechain]
        sidechain_index = dummy_index[is_sidechain]
        sidechain_batch = batch[is_sidechain]

        assign_index = radius(
            x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch
        )
        r_edge_index_x = assign_index[1]
        # Row 0 indexes into the sidechain subset; map it back to global atom ids.
        r_edge_index_y = sidechain_index[assign_index[0]]

        rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y))  # (2, E)
        rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x))  # (2, E)
        rgraph_edge_index = torch.cat(
            (rgraph_edge_index1, rgraph_edge_index2), dim=-1
        )  # (2, 2E)
        # delete self loop
        rgraph_edge_index = rgraph_edge_index[
            :, (rgraph_edge_index[0] != rgraph_edge_index[1])
        ]

    rgraph_values = (
        torch.ones(rgraph_edge_index.size(1), dtype=torch.long, device=pos.device)
        * unspecified_type_number
    )
    rgraph_adj = torch.sparse_coo_tensor(
        rgraph_edge_index, rgraph_values, (N, N), dtype=torch.long
    )

    # Coalescing merges duplicate index pairs by summing their values.
    composed_adj = (bgraph_adj + rgraph_adj).coalesce()
    new_edge_index = composed_adj.indices()
    new_edge_type = composed_adj.values().long()

    return new_edge_index, new_edge_type
def extend_graph_order_radius(
    num_nodes,
    pos,
    edge_index,
    edge_type,
    batch,
    order=3,
    cutoff=10.0,
    extend_order=True,
    extend_radius=True,
    is_sidechain=None,
):
    """Optionally extend a graph by bond order and then by spatial radius.

    Args:
        num_nodes: number of atoms, forwarded to the order extension.
        pos: atom positions, forwarded to the radius extension.
        edge_index: edge indices of the input graph.
        edge_type: edge types of the input graph.
        batch: per-atom graph assignment, forwarded to the radius extension.
        order: extension order used when ``extend_order`` is set.
        cutoff: radius used when ``extend_radius`` is set.
        extend_order: whether to run ``_extend_graph_order``.
        extend_radius: whether to run ``_extend_to_radius_graph``.
        is_sidechain: optional mask forwarded to the radius extension.

    Returns:
        The resulting ``(edge_index, edge_type)`` pair. The inputs are returned
        unchanged when both flags are false.
    """
    current_index, current_type = edge_index, edge_type

    if extend_order:
        current_index, current_type = _extend_graph_order(
            num_nodes=num_nodes,
            edge_index=current_index,
            edge_type=current_type,
            order=order,
        )

    # The radius step runs on whatever the order step produced.
    if extend_radius:
        current_index, current_type = _extend_to_radius_graph(
            pos=pos,
            edge_index=current_index,
            edge_type=current_type,
            cutoff=cutoff,
            batch=batch,
            is_sidechain=is_sidechain,
        )

    return current_index, current_type
def get_distance(pos, edge_index):
    """Return the length of every edge.

    Args:
        pos: node positions, indexed along the first dimension by node id.
        edge_index: edge endpoints; row 0 and row 1 hold the two node ids of
            each edge.

    Returns:
        The norm, taken over the last dimension, of the difference between the
        positions of the two endpoints of each edge.
    """
    start_pos = pos[edge_index[0]]
    end_pos = pos[edge_index[1]]
    displacement = start_pos - end_pos
    return displacement.norm(dim=-1)
def graph_field_network(score_d, pos, edge_index, edge_length):
    """Turn per-edge scores into per-node position scores.

    Transformation to make the epsilon predicted from the diffusion model
    roto-translational equivariant. See equations 5-7 of the GeoDiff Paper
    https://arxiv.org/pdf/2203.02923.pdf

    Args:
        score_d: per-edge score; multiplied with the per-edge direction.
        pos: node positions; the first dimension indexes nodes.
        edge_index: edge endpoints; row 0 and row 1 hold the two node ids of
            each edge.
        edge_length: per-edge length used to normalise the direction.
            Presumably shaped so it broadcasts against the (E, 3) position
            difference -- confirm against the caller.

    Returns:
        One row per node: the sum of the scaled edge directions over the edges
        where the node is the first endpoint, minus the same quantity over the
        edges where it is the second endpoint.
    """
    num_nodes = pos.size(0)
    first_node, second_node = edge_index[0], edge_index[1]

    # Derivative of the edge length with respect to the first endpoint.
    dd_dr = (1.0 / edge_length) * (pos[first_node] - pos[second_node])

    on_first = scatter_add(dd_dr * score_d, first_node, dim=0, dim_size=num_nodes)
    # The second endpoint receives the opposite contribution.
    on_second = scatter_add(-dd_dr * score_d, second_node, dim=0, dim_size=num_nodes)
    return on_first + on_second
def clip_norm(vec, limit, p=2):
    """Rescale vectors whose norm exceeds ``limit`` down to norm ``limit``.

    Args:
        vec: tensor of vectors; the norm is taken over the last dimension.
        limit: largest norm allowed after clipping.
        p: order of the norm. The original implementation accepted this
            argument but always used the 2-norm; it is now honoured.

    Returns:
        A tensor shaped like ``vec``. Vectors with norm at most ``limit`` are
        returned unchanged; longer vectors are scaled by ``limit / norm``.
    """
    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
    # Scale factor is 1 wherever the vector is already within the limit.
    scale = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
    return vec * scale
def is_local_edge(edge_type):
    """Flag the edges whose type is strictly positive.

    Args:
        edge_type: edge type value(s). Presumably type 0 marks edges that come
            only from a radius search -- verify against the caller.

    Returns:
        The result of ``edge_type > 0``.
    """
    local_mask = edge_type > 0
    return local_mask
def repeat_data(data: Data, num_repeat: int) -> Batch:
    """Build a batch holding ``num_repeat`` independent copies of one graph.

    Args:
        data: A `torch_geometric.data.Data` object.
        num_repeat: Number of deep copies to put into the batch.

    Returns:
        batch: A batch built from ``num_repeat`` deep copies of `data`.
    """
    copies = []
    for _ in range(num_repeat):
        # Each entry is its own deep copy, so the copies share no state.
        copies.append(copy.deepcopy(data))
    return Batch.from_data_list(copies)
def repeat_batch(batch: Batch, num_repeat: int) -> Batch:
    """Build a batch holding the graphs of ``batch`` repeated ``num_repeat`` times.

    Args:
        batch: A `torch_geometric.data.Batch` object.
        num_repeat: Number of times the whole list of graphs is repeated.

    Returns:
        batch: A new batch with the graphs of `batch` in their original order,
            repeated ``num_repeat`` times; every repetition is a deep copy.
    """
    graphs = batch.to_data_list()
    # The list is deep-copied as a whole once per repetition.
    repeated = [
        graph for _ in range(num_repeat) for graph in copy.deepcopy(graphs)
    ]
    return Batch.from_data_list(repeated)
def set_rdmol_positions(rdkit_mol, pos):
    """Return a copy of a molecule with new atom positions.

    Args:
        rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. It is not modified.
        pos: (N_atoms, 3)

    Returns:
        mol: A deep copy of `rdkit_mol` with the positions set to `pos`.
    """
    mol_copy = deepcopy(rdkit_mol)
    # The in-place variant works on the copy, leaving the input untouched.
    set_rdmol_positions_(mol_copy, pos)
    return mol_copy
def set_rdmol_positions_(mol, pos):
    """Set the atom positions of conformer 0 of a molecule in place.

    Args:
        mol: An `rdkit.Chem.rdchem.Mol` object. It is modified in place.
        pos: (N_atoms, 3); only ``pos.shape[0]`` and ``pos[i].tolist()`` are
            used.

    Returns:
        mol: The same `mol` object that was passed in (not a copy), with the
            positions of conformer 0 set to `pos`.
    """
    num_atoms = pos.shape[0]
    if num_atoms == 0:
        # Nothing to write; as before, the conformer is not even looked up.
        return mol

    # Look the conformer up once instead of once per atom.
    conformer = mol.GetConformer(0)
    for atom_idx in range(num_atoms):
        conformer.SetAtomPosition(atom_idx, pos[atom_idx].tolist())
    return mol