#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import copy
from copy import deepcopy
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import radius, radius_graph
from torch_geometric.utils import dense_to_sparse, to_dense_adj
from torch_scatter import scatter_add
from torch_sparse import coalesce
def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):
    """Build one feature vector per edge from its endpoints and its own features.

    Args:
        node_attr: Node feature tensor, indexed along its first dimension by
            the entries of ``edge_index``.
        edge_index: Index tensor whose rows 0 and 1 hold the two endpoints of
            every edge.
        edge_attr: Edge feature tensor; it must be concatenable with the
            endpoint product along the last dimension.

    Returns:
        The element-wise product of the two endpoint features, concatenated
        with ``edge_attr`` along the last dimension.
    """
    src, dst = edge_index[0], edge_index[1]
    endpoint_product = node_attr[src] * node_attr[dst]
    # (E, 2H) when node and edge features share the same width H.
    return torch.cat((endpoint_product, edge_attr), dim=-1)
def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):
    """Add edges between atoms that are up to ``order`` bonds apart.

    Args:
        num_nodes: Number of atoms.
        edge_index: Bond indices of the original graph.
        edge_type: Bond types of the original graph.
        order: Extension order.

    Returns:
        new_edge_index: Extended edge indices.
        new_edge_type: Extended edge types. Original bonds keep their type; a
            pair whose shortest path is ``k`` bonds (``1 < k <= order``) gets
            the type ``num_types + k - 1``.
    """

    def binarize(x):
        # Map every strictly positive entry to 1 and everything else to 0.
        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))

    def get_higher_order_adj_matrix(adj, order):
        """Compute the hop distance between every pair of atoms.

        Args:
            adj: (N, N) adjacency matrix of the bond graph.
            order: Largest hop distance to resolve.

        Returns:
            (N, N) matrix holding, for each pair, the smallest number of hops
            ``k`` in ``1..order`` that connects it, and 0 for the diagonal and
            for pairs that are further apart than ``order``.
        """
        identity = torch.eye(adj.size(0), dtype=torch.long, device=adj.device)
        # adj_mats[k] marks the pairs reachable in at most k hops.
        adj_mats = [identity, binarize(adj + identity)]
        for i in range(2, order + 1):
            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))

        order_mat = torch.zeros_like(adj)
        for i in range(1, order + 1):
            # Pairs that became reachable exactly at hop i.
            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i
        return order_mat

    # given from len(BOND_TYPES), where
    # BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
    # from rdkit.Chem.rdchem import BondType as BT
    num_types = 22

    N = num_nodes
    # Size the dense matrices from the atom count instead of the largest index
    # found in edge_index, so atoms without bonds are still represented.
    adj = to_dense_adj(edge_index, max_num_nodes=N).squeeze(0)
    adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)

    type_mat = to_dense_adj(
        edge_index, edge_attr=edge_type, max_num_nodes=N
    ).squeeze(0)  # (N, N)
    type_highorder = torch.where(
        adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order)
    )
    # A pair is either an original bond or a higher-order edge, never both.
    assert (type_mat * type_highorder == 0).all()
    type_new = type_mat + type_highorder

    new_edge_index, new_edge_type = dense_to_sparse(type_new)
    new_edge_index, new_edge_type = coalesce(
        new_edge_index, new_edge_type.long(), N, N
    )
    return new_edge_index, new_edge_type
def _extend_to_radius_graph(
    pos,
    edge_index,
    edge_type,
    cutoff,
    batch,
    unspecified_type_number=0,
    is_sidechain=None,
):
    """Merge the given edges with edges between atoms closer than ``cutoff``.

    Args:
        pos: Atom positions; the first dimension is the number of atoms N.
        edge_index: Indices of the existing edges.
        edge_type: 1-D tensor holding the type of every existing edge.
        cutoff: Radius handed to the neighbour search.
        batch: Per-atom batch assignment handed to the neighbour search.
        unspecified_type_number: Type value assigned to the radius edges.
        is_sidechain: Optional per-atom mask. When given, radius edges are
            only searched between the masked atoms and all atoms, and are
            added in both directions.

    Returns:
        new_edge_index: Indices of the merged edge set.
        new_edge_type: Types of the merged edges. Values of edges present in
            both graphs are summed, so with the default
            ``unspecified_type_number=0`` existing edge types are unchanged.
    """
    assert edge_type.dim() == 1
    N = pos.size(0)
    adj_size = torch.Size([N, N])

    # torch.sparse_coo_tensor replaces the deprecated legacy
    # torch.sparse.LongTensor(indices, values, size) constructor.
    bgraph_adj = torch.sparse_coo_tensor(edge_index, edge_type.long(), adj_size)

    if is_sidechain is None:
        rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)
    else:
        # fetch sidechain and its batch index
        is_sidechain = is_sidechain.bool()
        dummy_index = torch.arange(pos.size(0), device=pos.device)
        sidechain_pos = pos[is_sidechain]
        sidechain_index = dummy_index[is_sidechain]
        sidechain_batch = batch[is_sidechain]

        assign_index = radius(
            x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch
        )
        r_edge_index_x = assign_index[1]
        # Map positions within the side-chain subset back to atom indices.
        r_edge_index_y = sidechain_index[assign_index[0]]

        rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y))  # (2, E)
        rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x))  # (2, E)
        rgraph_edge_index = torch.cat(
            (rgraph_edge_index1, rgraph_edge_index2), dim=-1
        )  # (2, 2E)
        # delete self loop
        rgraph_edge_index = rgraph_edge_index[
            :, (rgraph_edge_index[0] != rgraph_edge_index[1])
        ]

    rgraph_values = torch.full(
        (rgraph_edge_index.size(1),),
        unspecified_type_number,
        dtype=torch.long,
        device=pos.device,
    )
    rgraph_adj = torch.sparse_coo_tensor(rgraph_edge_index, rgraph_values, adj_size)

    # Sparse (N, N); coalesce sums the values of duplicate indices.
    composed_adj = (bgraph_adj + rgraph_adj).coalesce()

    new_edge_index = composed_adj.indices()
    new_edge_type = composed_adj.values().long()
    return new_edge_index, new_edge_type
def extend_graph_order_radius(
    num_nodes,
    pos,
    edge_index,
    edge_type,
    batch,
    order=3,
    cutoff=10.0,
    extend_order=True,
    extend_radius=True,
    is_sidechain=None,
):
    """Optionally extend a bond graph by bond order and then by spatial radius.

    Args:
        num_nodes: Number of atoms, forwarded to the order extension.
        pos: Atom positions, forwarded to the radius extension.
        edge_index: Edge indices of the input graph.
        edge_type: Edge types of the input graph.
        batch: Per-atom batch assignment, forwarded to the radius extension.
        order: Extension order used when ``extend_order`` is set.
        cutoff: Radius used when ``extend_radius`` is set.
        extend_order: Whether to run the order extension.
        extend_radius: Whether to run the radius extension.
        is_sidechain: Optional mask forwarded to the radius extension.

    Returns:
        The resulting ``(edge_index, edge_type)`` pair; the inputs are returned
        unchanged when both extensions are disabled.
    """
    if extend_order:
        edge_index, edge_type = _extend_graph_order(
            num_nodes=num_nodes,
            edge_index=edge_index,
            edge_type=edge_type,
            order=order,
        )

    if not extend_radius:
        return edge_index, edge_type

    # The radius step runs on the output of the order step, when that was run.
    radius_index, radius_type = _extend_to_radius_graph(
        pos=pos,
        edge_index=edge_index,
        edge_type=edge_type,
        cutoff=cutoff,
        batch=batch,
        is_sidechain=is_sidechain,
    )
    return radius_index, radius_type
def get_distance(pos, edge_index):
    """Return the length of every edge.

    Args:
        pos: Position tensor, indexed along its first dimension by the entries
            of ``edge_index``.
        edge_index: Index tensor whose rows 0 and 1 hold the two endpoints of
            every edge.

    Returns:
        The norm, taken over the last dimension, of the difference between the
        two endpoint positions of each edge.
    """
    src, dst = edge_index[0], edge_index[1]
    displacement = pos[src] - pos[dst]
    return displacement.norm(dim=-1)
def graph_field_network(score_d, pos, edge_index, edge_length):
    """Turn per-edge scores into per-atom position scores.

    Transformation to make the epsilon predicted from the diffusion model
    roto-translational equivariant. See equations 5-7 of the GeoDiff paper
    https://arxiv.org/pdf/2203.02923.pdf

    Args:
        score_d: Per-edge score; it must broadcast against the per-edge
            displacement vectors.
        pos: Atom positions; the first dimension is the number of atoms N.
        edge_index: Index tensor whose rows 0 and 1 hold the two endpoints of
            every edge.
        edge_length: Per-edge length used to normalise the displacement; it
            must broadcast against the per-edge displacement vectors.

    Returns:
        Tensor with N rows holding, for every atom, the sum of the scaled
        directions of the edges it starts minus those of the edges it ends.
    """
    num_atoms = pos.size(0)
    src, dst = edge_index[0], edge_index[1]

    # Displacement divided by the edge length, one row per edge: (E, 3).
    direction = (1.0 / edge_length) * (pos[src] - pos[dst])
    edge_term = direction * score_d

    from_src = scatter_add(edge_term, src, dim=0, dim_size=num_atoms)
    from_dst = scatter_add(-edge_term, dst, dim=0, dim_size=num_atoms)
    return from_src + from_dst  # (N, 3)
def clip_norm(vec, limit, p=2):
    """Rescale vectors whose norm exceeds ``limit`` so that it equals ``limit``.

    Args:
        vec: Tensor of vectors stored along the last dimension.
        limit: Largest norm allowed; vectors at or below it are left unchanged.
        p: Order of the norm used for the comparison and the rescaling.

    Returns:
        A tensor shaped like ``vec`` in which every vector with a ``p``-norm
        above ``limit`` has been scaled by ``limit / norm``.
    """
    # Honour the requested order; it used to be fixed to 2 regardless of `p`.
    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
    scale = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
    return vec * scale
def is_local_edge(edge_type):
    """Flag the edges whose type is strictly positive.

    Args:
        edge_type: Edge type value or tensor of edge types.

    Returns:
        The result of comparing ``edge_type`` against zero: true where the
        type is greater than 0.
    """
    local_mask = edge_type > 0
    return local_mask
def repeat_data(data: Data, num_repeat: int) -> Batch:
    """Batch together independent copies of a single graph.

    Args:
        data: A `torch_geometric.data.Data` object.
        num_repeat: Number of copies to place in the batch.

    Returns:
        batch: A batch made of `num_repeat` deep copies of `data`.
    """
    # Each entry is deep-copied on its own so the copies share no state.
    copies = list(map(copy.deepcopy, [data] * num_repeat))
    return Batch.from_data_list(copies)
def repeat_batch(batch: Batch, num_repeat: int) -> Batch:
    """Batch together repeated copies of every graph of an existing batch.

    Args:
        batch: A `torch_geometric.data.Batch` object.
        num_repeat: Number of times the graphs of `batch` are repeated.

    Returns:
        batch: A batch holding `num_repeat` deep copies of the graph list of
            `batch`, one full copy of the list after the other.
    """
    graphs = batch.to_data_list()
    # The whole list is deep-copied once per repetition.
    repeated = [
        graph for _ in range(num_repeat) for graph in copy.deepcopy(graphs)
    ]
    return Batch.from_data_list(repeated)
def set_rdmol_positions(rdkit_mol, pos):
    """Return a copy of a molecule with new atom positions.

    Args:
        rdkit_mol: An `rdkit.Chem.rdchem.Mol` object; it is not modified.
        pos: (N_atoms, 3)

    Returns:
        mol: A deep copy of `rdkit_mol` with the positions set to `pos`.
    """
    clone = deepcopy(rdkit_mol)
    # The in-place variant does the writing; only the clone is touched.
    set_rdmol_positions_(clone, pos)
    return clone
def set_rdmol_positions_(mol, pos):
    """Overwrite the atom positions of a molecule in place.

    Args:
        mol: An `rdkit.Chem.rdchem.Mol` object; its conformer with id 0 is
            modified.
        pos: (N_atoms, 3)

    Returns:
        mol: The same `mol` object that was passed in (not a copy), with the
            positions of its first ``pos.shape[0]`` atoms set to `pos`.
    """
    num_atoms = pos.shape[0]
    if num_atoms == 0:
        # Nothing to write; the conformer is not even looked up.
        return mol

    # Look the conformer up once instead of once per atom.
    conformer = mol.GetConformer(0)
    for i in range(num_atoms):
        conformer.SetAtomPosition(i, pos[i].tolist())
    return mol