Source code for gt4sd.frameworks.gflownet.envs.mol_building_env

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
from typing import List, Tuple

import networkx as nx
import numpy as np
import rdkit.Chem as Chem
import torch
import torch_geometric.data as gd
from rdkit.Chem import Mol
from rdkit.Chem.rdchem import BondType, ChiralType

from ..envs.graph_building_env import (
    Graph,
    GraphAction,
    GraphActionType,
    GraphBuildingEnvContext,
)


[docs]class MolBuildingEnvContext(GraphBuildingEnvContext): """A context environment for building molecular graphs."""
[docs] def __init__( self, atoms: List[str] = ["H", "C", "N", "O", "F"], num_cond_dim: int = 32, device="cpu", ) -> None: """Initialize a generic context environment for molecules. A specification of what is being generated for a GraphBuildingEnv. This context specifies how to create molecules atom-by-atom (and attribute-by-attribute). Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/envs Args: atoms: basic building blocks. In principle we can infer this information from the dataset. num_cond_dim: number of conditional dimensions.s device: device to use (cpu, cuda). """ # idx 0 has to coincide with the default value self.atom_attr_values = { "v": atoms, "chi": [ ChiralType.CHI_UNSPECIFIED, ChiralType.CHI_TETRAHEDRAL_CW, ChiralType.CHI_TETRAHEDRAL_CCW, ], "charge": [0, 1, -1], "expl_H": list(range(4)), # TODO: check what is the actual range of this "no_impl": [False, True], } self.atom_attr_defaults = { k: self.atom_attr_values[k][0] for k in self.atom_attr_values } # The size of the input vector for each atom self.atom_attr_size = sum(len(i) for i in self.atom_attr_values.values()) self.atom_attrs = sorted(self.atom_attr_values.keys()) # The beginning position within the input vector of each attribute self.atom_attr_slice = [0] + list( np.cumsum([len(self.atom_attr_values[i]) for i in self.atom_attrs]) ) # The beginning position within the logit vector of each attribute self.atom_attr_logit_slice = { k: s for k, s in zip( self.atom_attrs, [0] + list( np.cumsum( [len(self.atom_attr_values[i]) - 1 for i in self.atom_attrs] ) ), ) } # The attribute and value each logit dimension maps back to self.atom_attr_logit_map = [ (k, v) for k in self.atom_attrs if k != "v" # index 0 is skipped because it is the default value for v in self.atom_attr_values[k][1:] ] self.bond_attr_values = { "type": [ BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC, ], } self.bond_attr_defaults = { k: self.bond_attr_values[k][0] for k in self.bond_attr_values } self.bond_attr_size = sum(len(i) for i in self.bond_attr_values.values()) self.bond_attrs = sorted(self.bond_attr_values.keys()) self.bond_attr_slice = [0] + list( np.cumsum([len(self.bond_attr_values[i]) for i in self.bond_attrs]) ) self.bond_attr_logit_slice = { k: s for k, s in zip( self.bond_attrs, [0] + list( np.cumsum( [len(self.bond_attr_values[i]) - 1 for i in self.bond_attrs] ) ), ) } self.bond_attr_logit_map = [ (k, v) for k in self.bond_attrs for v in self.bond_attr_values[k][1:] ] # These values are used by Models to know how many inputs/logits to produce self.num_new_node_values = len(atoms) self.num_node_attr_logits = len(self.atom_attr_logit_map) self.num_node_dim = self.atom_attr_size + 1 self.num_edge_attr_logits = len(self.bond_attr_logit_map) self.num_edge_dim = self.bond_attr_size self.num_cond_dim = num_cond_dim # Order in which models have to output logits self.action_type_order = [ GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.SetNodeAttr, GraphActionType.AddEdge, GraphActionType.SetEdgeAttr, ] self.device = device
[docs] def aidx_to_graph_action( self, g: gd.Data, action_idx: Tuple[int, int, int] ) -> GraphAction: """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction. Args: g: The graph to act on. action_idx: The action index. Raises: ValueError: If the action index is invalid. Returns: The action corresponding to the action index. """ act_type, act_row, act_col = [int(i) for i in action_idx] t = self.action_type_order[act_type] if t is GraphActionType.Stop: return GraphAction(t) elif t is GraphActionType.AddNode: return GraphAction( t, source=act_row, value=self.atom_attr_values["v"][act_col] ) elif t is GraphActionType.SetNodeAttr: attr, val = self.atom_attr_logit_map[act_col] return GraphAction(t, source=act_row, attr=attr, value=val) elif t is GraphActionType.AddEdge: a, b = g.non_edge_index[:, act_row] return GraphAction(t, source=a.item(), target=b.item()) # Edges are duplicated to get undirected GNN, deduplicated for logits elif t is GraphActionType.SetEdgeAttr: a, b = g.edge_index[:, act_row * 2] attr, val = self.bond_attr_logit_map[act_col] return GraphAction( t, source=a.item(), target=b.item(), attr=attr, value=val ) else: raise ValueError(f"Unknown action type: {t}")
[docs] def graph_action_to_aidx( self, g: gd.Data, action: GraphAction ) -> Tuple[int, int, int]: """Translate a GraphAction to an index tuple. Args: g: The graph to act on. action: The action to translate. Returns: The index corresponding to the action. """ if action.action is GraphActionType.Stop: row = 0 col = 0 elif action.action is GraphActionType.AddNode: row = action.source # type: ignore col = self.atom_attr_values["v"].index(action.value) elif action.action is GraphActionType.SetNodeAttr: row = action.source # type: ignore # - 1 because the default is index 0 col = ( self.atom_attr_values[action.attr].index(action.value) # type: ignore - 1 + self.atom_attr_logit_slice[action.attr] # type: ignore ) elif action.action is GraphActionType.AddEdge: # Here we have to retrieve the index in non_edge_index of an edge (s,t) # that's also possibly in the reverse order (t,s). # That's definitely not too efficient, can we do better? row = ( ( g.non_edge_index.T == torch.tensor([(action.source, action.target)]) ).prod(1) + ( g.non_edge_index.T == torch.tensor([(action.target, action.source)]) ).prod(1) ).argmax() col = 0 elif action.action is GraphActionType.SetEdgeAttr: # Here the edges are duplicated, both (i,j) and (j,i) are in edge_index # so no need for a double check. # row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) + # (g.edge_index.T == torch.tensor([(action.target, action.source)])).prod(1)).argmax() row = ( (g.edge_index.T == torch.tensor([(action.source, action.target)])) .prod(1) .argmax() ) # Because edges are duplicated but logits aren't, divide by two row = row.div(2, rounding_mode="floor") # type: ignore col = ( self.bond_attr_values[action.attr].index(action.value) # type: ignore - 1 + self.bond_attr_logit_slice[action.attr] # type: ignore ) type_idx = self.action_type_order.index(action.action) row = int(row) col = int(col) return (type_idx, row, col)
[docs] def graph_to_data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance. Args: g: Networkx Graph to convert. Returns: torch geometric Data instance. """ x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim)) x[0, -1] = len(g.nodes) == 0 for i, n in enumerate(g.nodes): ad = g.nodes[n] for k, sl in zip(self.atom_attrs, self.atom_attr_slice): idx = self.atom_attr_values[k].index(ad[k]) if k in ad else 0 x[i, sl + idx] = 1 edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim)) for i, e in enumerate(g.edges): ad = g.edges[e] for k, sl in zip(self.bond_attrs, self.bond_attr_slice): idx = self.bond_attr_values[k].index(ad[k]) if k in ad else 0 edge_attr[i * 2, sl + idx] = 1 edge_attr[i * 2 + 1, sl + idx] = 1 edge_index = torch.tensor( [e for i, j in g.edges for e in [(i, j), (j, i)]], dtype=torch.long ) edge_index = edge_index.reshape((-1, 2)).T gc = nx.complement(g) non_edge_index = torch.tensor([i for i in gc.edges], dtype=torch.long) if len(non_edge_index.shape) == 2: non_edge_index = non_edge_index.T non_edge_index = non_edge_index.reshape((2, -1)) return gd.Data(x, edge_index, edge_attr, non_edge_index=non_edge_index)
[docs] def collate(self, graphs: List[gd.Data]): """Batch Data instances. Args: graphs: List of Data instances. Returns: Batch of Data instances. """ return gd.Batch.from_data_list( graphs, follow_batch=["edge_index", "non_edge_index"] )
[docs] def mol_to_graph(self, mol: Mol) -> Graph: """Convert an RDMol to a Graph. Args: mol: RDKit molecule format. Returns: Graph format. """ g = Graph() # Only set an attribute tag if it is not the default attribute for a in mol.GetAtoms(): attrs = { "chi": a.GetChiralTag(), "charge": a.GetFormalCharge(), "expl_H": a.GetNumExplicitHs(), "no_impl": a.GetNoImplicit(), } g.add_node( a.GetIdx(), v=a.GetSymbol(), **{ attr: val for attr, val in attrs.items() if val != self.atom_attr_defaults[attr] }, ) for b in mol.GetBonds(): attrs = {"type": b.GetBondType()} g.add_edge( b.GetBeginAtomIdx(), b.GetEndAtomIdx(), **{ attr: val for attr, val in attrs.items() if val != self.bond_attr_defaults[attr] }, ) return g
[docs] def graph_to_mol(self, g: Graph) -> Mol: """Convert a Graph to an RDKit molecule. Args: g: Graph format. Returns: RDKit molecule format. """ mp = Chem.RWMol() mp.BeginBatchEdit() for i in range(len(g.nodes)): d = g.nodes[i] a = Chem.Atom(d["v"]) if "chi" in d: a.SetChiralTag(d["chi"]) if "charge" in d: a.SetFormalCharge(d["charge"]) if "expl_H" in d: a.SetNumExplicitHs(d["expl_H"]) if "no_impl" in d: a.SetNoImplicit(d["no_impl"]) mp.AddAtom(a) for e in g.edges: d = g.edges[e] mp.AddBond(e[0], e[1], d.get("type", BondType.SINGLE)) mp.CommitBatchEdit() Chem.SanitizeMol(mp) return mp
[docs] def is_sane(self, g: Graph) -> bool: """Check if a graph is sane. Args: g: Graph format. Returns: True if sane, False otherwise. """ try: mol = self.graph_to_mol(g) assert Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is not None except Exception: return False if mol is None: return False return True