#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Implementation of PaccMann^GP conditional generator."""
import json
import logging
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from paccmann_chemistry.models.vae import StackGRUDecoder, StackGRUEncoder, TeacherVAE
from paccmann_chemistry.utils.search import SamplingSearch
from paccmann_gp.affinity_minimization import AffinityMinimization
from paccmann_gp.callable_minimization import CallableMinimization
from paccmann_gp.combined_minimization import CombinedMinimization
from paccmann_gp.gp_optimizer import GPOptimizer
from paccmann_gp.mw_minimization import MWMinimization
from paccmann_gp.qed_minimization import QEDMinimization
from paccmann_gp.sa_minimization import SAMinimization
from paccmann_gp.smiles_generator import SmilesGenerator
from paccmann_predictor.models import MODEL_FACTORY
from pytoda.proteins.protein_language import ProteinLanguage
from pytoda.smiles.smiles_language import SMILESLanguage

from ....frameworks.torch import device_claim

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

MINIMIZATION_FUNCTIONS = {
"qed": QEDMinimization,
"sa": SAMinimization,
"molwt": MWMinimization,
"affinity": AffinityMinimization,
"callable": CallableMinimization,
}
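
# Illustrative shape of a target specification consumed by
# GPConditionalGenerator.target_to_minimization_function below. Keys must match
# the entries of MINIMIZATION_FUNCTIONS and the optional "weight" is handled in
# this module; any other per-function parameter is forwarded to the respective
# paccmann_gp minimization class (the affinity entry below is a hypothetical
# placeholder):
#   {"qed": {"weight": 1.0}, "affinity": {"protein": "<sequence>", "weight": 2.0}}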


class GPConditionalGenerator:
"""Conditional generator as implemented in https://doi.org/10.1021/acs.jcim.1c00889."""

    def __init__(
self,
resources_path: str,
temperature: float = 1.4,
generated_length: int = 100,
batch_size: int = 32,
limit: float = 5.0,
acquisition_function: str = "EI",
number_of_steps: int = 32,
number_of_initial_points: int = 16,
initial_point_generator: str = "random",
seed: int = 42,
number_of_optimization_rounds: int = 1,
sampling_variance: float = 0.1,
samples_for_evaluation: int = 4,
maximum_number_of_sampling_steps: int = 32,
device: Optional[Union[torch.device, str]] = None,
) -> None:
"""Initialize the conditional generator.
Args:
resources_path: directory where to find models and parameters.
temperature: temperature parameter for the softmax sampling in decoding. Defaults to 1.4.
generated_length: maximum length in tokens of the generated molcules (relates to the SMILES length). Defaults to 100.
batch_size: batch size used for the generative model sampling. Defaults to 16.
limit: hypercube limits in the latent space. Defaults to 5.0.
acquisition_function: acquisition function used in the Gaussian process. Defaults to "EI". More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html.
number_of_steps: number of steps for an optmization round. Defaults to 32.
number_of_initial_points: number of initial points evaluated. Defaults to 16.
initial_point_generator: scheme to generate initial points. Defaults to "random". More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html.
seed: seed used for random number generation in the optimizer. Defaults to 42.
number_of_optimization_rounds: maximum number of optimization rounds. Defaults to 1.
sampling_variance: variance of the Gaussian noise applied during sampling from the optimal point. Defaults to 0.1.
samples_for_evaluation: number of samples averaged for each minimization function evaluation. Defaults to 4.
maximum_number_of_sampling_steps: maximum number of sampling steps in an optmization round. Defaults to 32.
device: . Defaults to None, a.k.a, picking a default one ("gpu" if present, "cpu" otherwise).
"""
# device
self.device = device_claim(device)
# setting sampling parameters
self.temperature = temperature
self.generated_length = generated_length
self.batch_size = batch_size
# setting VAE parameters
self.svae_params = dict()
with open(os.path.join(resources_path, "vae_model_params.json"), "r") as f:
self.svae_params.update(json.load(f))
smiles_language = SMILESLanguage.load(
os.path.join(resources_path, "selfies_language.pkl")
)
        # initialize the encoder, the decoder and the teacher VAE
self.gru_encoder = StackGRUEncoder(self.svae_params).to(self.device)
self.gru_decoder = StackGRUDecoder(self.svae_params).to(self.device)
self.gru_vae = TeacherVAE(self.gru_encoder, self.gru_decoder)
self.gru_vae.load_state_dict(
torch.load(
os.path.join(resources_path, "vae_weights.pt"),
map_location=self.device,
)
)
self.gru_vae._associate_language(smiles_language)
self.gru_vae.eval()
self.smiles_generator = SmilesGenerator(
self.gru_vae,
search=SamplingSearch(temperature=self.temperature),
generated_length=self.generated_length,
)
self.latent_dim = self.gru_decoder.latent_dim
# setting affinity predictor parameters
with open(os.path.join(resources_path, "mca_model_params.json")) as f:
self.predictor_params = json.load(f)
self.affinity_predictor = MODEL_FACTORY["bimodal_mca"](
self.predictor_params
).to(self.device)
self.affinity_predictor.load(
os.path.join(resources_path, "mca_weights.pt"),
map_location=self.device,
)
affinity_protein_language = ProteinLanguage.load(
os.path.join(resources_path, "protein_language.pkl")
)
affinity_smiles_language = SMILESLanguage.load(
os.path.join(resources_path, "smiles_language.pkl")
)
self.affinity_predictor._associate_language(affinity_smiles_language)
self.affinity_predictor._associate_language(affinity_protein_language)
self.affinity_predictor.eval()
# setting optimizer parameters
self.limit = limit
self.acquisition_function = acquisition_function
self.number_of_initial_points = number_of_initial_points
if number_of_steps < self.number_of_initial_points:
logger.warning(
"number of initial points is larger than number of steps "
f"({self.number_of_initial_points}/{number_of_steps}). "
f"Resetting number of steps to {self.number_of_initial_points}."
)
self.number_of_steps = self.number_of_initial_points
else:
self.number_of_steps = number_of_steps
self.initial_point_generator = initial_point_generator
self.seed = None if seed == -1 else seed
self.set_seed()
self.number_of_optimization_rounds = number_of_optimization_rounds
self.sampling_variance = sampling_variance
self.samples_for_evaluation = samples_for_evaluation
self.maximum_number_of_sampling_steps = maximum_number_of_sampling_steps

    def target_to_minimization_function(
self, target: Union[Dict[str, Dict[str, Any]], str]
) -> CombinedMinimization:
"""Use the target to configure a minimization function.
Args:
target: dictionary or JSON string describing the optimization target.
Returns:
a minimization function.
"""
if isinstance(target, str):
target_dictionary = json.loads(target)
elif isinstance(target, dict):
target_dictionary = deepcopy(target)
else:
raise ValueError(
f"{target} of type {type(target)} is not supported: provide 'str' or 'Dict[str, Dict[str, Any]]'"
)
minimization_functions = []
weights = []
for minimization_function_name, parameters in target_dictionary.items():
weight = 1.0
if "weight" in parameters:
weight = parameters.pop("weight")
function_parameters = {
**parameters,
**{
"batch_size": self.samples_for_evaluation,
"smiles_decoder": self.smiles_generator,
},
}
minimization_function = MINIMIZATION_FUNCTIONS[minimization_function_name]
if minimization_function_name == "affinity":
function_parameters["affinity_predictor"] = self.affinity_predictor
minimization_functions.append(minimization_function(**function_parameters))
weights.append(weight)
return CombinedMinimization(
minimization_functions=minimization_functions,
batch_size=1,
function_weights=weights,
)

    def set_seed(self):
"""Set the seed for the random number generators."""
if self.seed is None:
return
np.random.seed(self.seed)
torch.manual_seed(self.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(self.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

    def generate_batch(self, target: Any) -> List[str]:
"""Generate molecules given a target.
Args:
target: dictionary or JSON string describing the optimization target.
Returns:
a list of molecules as SMILES string.
"""
# even if no seed is set, we want to avoid redundancy over multiple calls (using Knuth multiplicative hashing)
opt_seed = (self.seed or 42) * 2654435761 % 2**32
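        # e.g. with the default seed: 42 * 2654435761 % 2**32 == 4112119562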
logger.info(f"configuring optimization for target: {target}")
# target configuration
self.target = target
self.minimization_function = self.target_to_minimization_function(self.target)
# optimizer configuration
self.target_optimizer = GPOptimizer(self.minimization_function.evaluate)
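        # the search space is a hypercube [-limit, limit]^latent_dim in the VAE
        # latent space; these keyword arguments mirror skopt's gp_minimize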
optimization_parameters = dict(
dimensions=[(-self.limit, self.limit)] * self.latent_dim,
acq_func=self.acquisition_function,
n_calls=self.number_of_steps,
n_initial_points=self.number_of_initial_points,
initial_point_generator=self.initial_point_generator,
random_state=opt_seed,
)
log_params = deepcopy(optimization_parameters)
log_params["dimensions"] = np.mean(
optimization_parameters["dimensions"]
) # type:ignore
logger.info(f"running optimization with the following parameters: {log_params}")
smiles_set = set()
        logger.info(
            f"running at most {self.number_of_optimization_rounds} optimization rounds"
        )
for optimization_round in range(self.number_of_optimization_rounds):
logger.info(f"starting round {optimization_round + 1}")
optimization_parameters["random_state"] += optimization_round # type:ignore
res = self.target_optimizer.optimize(optimization_parameters)
latent_point = torch.tensor([[res.x]])
smiles_set_per_round = set()
logger.info(f"starting sampling for {optimization_round + 1}")
for _ in range(self.maximum_number_of_sampling_steps):
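                # decode batch_size points around the optimum: the first latent
                # point is res.x itself (zero noise), the remaining batch_size - 1
                # points get Gaussian noise with variance sampling_variance
                # (hence the square root used as the standard deviation)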
generated_smiles = self.smiles_generator.generate_smiles(
latent_point.repeat(1, self.batch_size, 1)
+ torch.cat(
(
torch.zeros(1, 1, self.latent_dim),
(self.sampling_variance**0.5)
* torch.randn(1, self.batch_size - 1, self.latent_dim),
),
dim=1,
)
)
smiles_set_per_round.update(set(generated_smiles))
smiles_set.update(smiles_set_per_round)
logger.info(f"completing round {optimization_round + 1}")
# Sort the molecules to ensure reproducibility
        mols = sorted([s for s in smiles_set if s], key=len, reverse=True)
logger.info(f"generated {len(mols)} molecules in the current run {mols}")
return mols
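
# Minimal usage sketch (resources path and target values are illustrative
# assumptions, not shipped defaults):
#
#   generator = GPConditionalGenerator(resources_path="/path/to/resources")
#   molecules = generator.generate_batch(
#       target={"qed": {"weight": 1.0}, "sa": {"weight": 1.0}}
#   )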