#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
from typing import Any, Callable, Dict, List, Tuple, Union, Optional
import numpy as np
import torch
import torch.nn as nn
import torch_geometric.data as gd
from rdkit.Chem.rdchem import Mol as RDMol
from gt4sd.frameworks.gflownet.dataloader.dataset import (
    FlatRewards,
    GFlowNetDataset,
    GFlowNetTask,
    RewardScalar,
)
from gt4sd.frameworks.gflownet.ml.models.mxmnet import (
    HAR2EV,
    MXMNet,
    MXMNetConfig,
    mol2graph,
)
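
# QM9 property columns exposed as dataset features: rotational constants
# (rA, rB, rC), dipole moment (mu), polarizability (alpha), frontier orbital
# energies (homo, lumo, gap), electronic spatial extent (r2), zero-point
# vibrational energy (zpve), thermodynamic energies (U0, U, H, G) and heat
# capacity (Cv). "gap" is the default optimization target below.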
PROPERTIES: List[str] = [
    "rA",
    "rB",
    "rC",
    "mu",
    "alpha",
    "homo",
    "lumo",
    "gap",
    "r2",
    "zpve",
    "U0",
    "U",
    "H",
    "G",
    "Cv",
]


class QM9Dataset(GFlowNetDataset):
    """QM9 dataset compatible with gflownet."""

    def __init__(
        self,
        h5_file: str,
        target: str = "gap",
        properties: List[str] = PROPERTIES,
    ) -> None:
        """Initialize QM9 dataset.

        Args:
            h5_file: path to the h5 file containing the dataset.
            target: target property to optimize and build the reward.
            properties: list of properties to use as features.
        """
        super().__init__(
            h5_file=h5_file,
            target=target,
            properties=properties,
        )
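
# Illustrative usage sketch; the h5 path below is a hypothetical placeholder
# for a locally available QM9 file, and `target` selects which of PROPERTIES
# drives the reward:
#
#     dataset = QM9Dataset(h5_file="qm9.h5", target="gap")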


def thermometer(
    v: torch.Tensor, n_bins: int = 50, vmin: int = 0, vmax: int = 1
) -> torch.Tensor:
    """Compute a thermometer encoding of the given values.

    Args:
        v: tensor of values to encode.
        n_bins: number of bins to use.
        vmin: minimum value of the range.
        vmax: maximum value of the range.

    Returns:
        tensor with the thermometer encoding.
    """
    bins = torch.linspace(vmin, vmax, n_bins)
    gap = bins[1] - bins[0]
    return (v[..., None] - bins.reshape((1,) * v.ndim + (-1,))).clamp(
        0, gap.item()
    ) / gap
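
# Worked example (illustrative): with vmin=0, vmax=1 and n_bins=5 the bins are
# [0.0, 0.25, 0.5, 0.75, 1.0] and the bin width is 0.25, so
# thermometer(torch.tensor(0.4), n_bins=5) is approximately
# tensor([1.0, 0.6, 0.0, 0.0, 0.0]): bins more than one bin-width below the
# value saturate at 1, the bin just below it is filled fractionally, and bins
# above the value stay at 0.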


# define task
class QM9GapTask(GFlowNetTask):
    """QM9 task compatible with gflownet."""

    def __init__(
        self,
        configuration: Dict[str, Any],
        dataset: GFlowNetDataset,
        reward_model: Optional[nn.Module] = None,
        wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        """Initialize QM9 task.

        Code adapted from: https://github.com/recursionpharma/gflownet/blob/trunk/src/gflownet/tasks/qm9/qm9.py.

        Args:
            configuration: configuration of the task.
            dataset: dataset to use for the task.
            reward_model: model to use for the reward.
            wrap_model: function to wrap the model.
        """
        super().__init__(
            configuration=configuration,
            dataset=dataset,
            reward_model=reward_model,
            wrap_model=wrap_model,
        )

    def load_task_models(self) -> Dict[str, nn.Module]:
        """Loads the models for the task.

        Returns:
            dictionary of models.
        """
        gap_model = MXMNet(MXMNetConfig(128, 6, 5.0))
        try:
            # load pretrained proxy weights for the HOMO-LUMO gap, if available
            state_dict = torch.load("/ckpt/mxmnet_gap_model.pt")
            gap_model.load_state_dict(state_dict)
        except FileNotFoundError:
            # fall back to randomly initialized weights when no checkpoint is found
            pass
        gap_model.to(self.device)
        # gap_model = self._wrap_model(gap_model)
        return {"model_task": gap_model}

    def cond_info_to_reward(
        self, cond_info: Dict[str, torch.Tensor], _flat_reward: FlatRewards
    ) -> RewardScalar:
        """Compute the reward for a given conditional information.

        Args:
            cond_info: dictionary of conditional information.
            _flat_reward: flat reward.

        Returns:
            reward scalar.
        """
        if isinstance(_flat_reward, list):
            flat_reward = torch.tensor(_flat_reward)
        else:
            flat_reward = _flat_reward
        # temper the flat (proxy) reward with the conditional inverse temperature: R = r ** beta
        return RewardScalar(flat_reward ** cond_info["beta"])

    def compute_flat_rewards(
        self, mols: List[RDMol]
    ) -> Tuple[RewardScalar, torch.Tensor]:
        """Computes the flat rewards for a list of molecules.

        Args:
            mols: list of molecules.

        Returns:
            reward scalar and validity.
        """
        graphs = [mol2graph(i) for i in mols]  # type: ignore[attr-defined]
        is_valid = torch.tensor([i is not None for i in graphs]).bool()
        if not is_valid.any():
            return RewardScalar(torch.zeros((0,))), is_valid
        batch = gd.Batch.from_data_list([i for i in graphs if i is not None])
        batch.to(self.device)
        # predict the target property with the proxy model
        preds = self.model["model_task"](batch)
        # rescale predictions (HAR2EV is the Hartree/eV conversion factor) and
        # replace NaNs before mapping to rewards
        preds = preds.reshape((-1,)).data.cpu() / HAR2EV  # type: ignore[attr-defined]
        preds[preds.isnan()] = 1
        preds = self.flat_reward_transform(preds).clip(1e-4, 2)
        return RewardScalar(preds), is_valid
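
# Illustrative end-to-end sketch (not executed as part of the module). The h5
# path and the configuration dictionary are hypothetical placeholders; the full
# set of expected configuration keys is defined by GFlowNetTask:
#
#     from rdkit import Chem
#
#     dataset = QM9Dataset(h5_file="qm9.h5", target="gap")
#     task = QM9GapTask(configuration=configuration, dataset=dataset)
#     rewards, validity = task.compute_flat_rewards([Chem.MolFromSmiles("CCO")])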