#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Implementation of PaccMann^GP conditional generator."""
import json
import logging
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from paccmann_chemistry.models.vae import StackGRUDecoder, StackGRUEncoder, TeacherVAE
from paccmann_chemistry.utils.search import SamplingSearch
from paccmann_gp.affinity_minimization import AffinityMinimization
from paccmann_gp.callable_minimization import CallableMinimization
from paccmann_gp.combined_minimization import CombinedMinimization
from paccmann_gp.gp_optimizer import GPOptimizer
from paccmann_gp.mw_minimization import MWMinimization
from paccmann_gp.qed_minimization import QEDMinimization
from paccmann_gp.sa_minimization import SAMinimization
from paccmann_gp.smiles_generator import SmilesGenerator
from paccmann_predictor.models import MODEL_FACTORY
from pytoda.proteins.protein_language import ProteinLanguage
from pytoda.smiles.smiles_language import SMILESLanguage

from ....frameworks.torch import device_claim

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

MINIMIZATION_FUNCTIONS = {
    "qed": QEDMinimization,
    "sa": SAMinimization,
    "molwt": MWMinimization,
    "affinity": AffinityMinimization,
    "callable": CallableMinimization,
}
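
# A target specification maps the keys above to the keyword arguments of the
# corresponding minimization function, plus an optional "weight". A minimal
# sketch (the protein sequence is an illustrative placeholder):
#
#     {
#         "qed": {"weight": 1.0},
#         "affinity": {"protein": "MVLSPADKT", "weight": 2.0},
#     }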


class GPConditionalGenerator:
    """Conditional generator as implemented in https://doi.org/10.1021/acs.jcim.1c00889."""

    def __init__(
        self,
        resources_path: str,
        temperature: float = 1.4,
        generated_length: int = 100,
        batch_size: int = 32,
        limit: float = 5.0,
        acquisition_function: str = "EI",
        number_of_steps: int = 32,
        number_of_initial_points: int = 16,
        initial_point_generator: str = "random",
        seed: int = 42,
        number_of_optimization_rounds: int = 1,
        sampling_variance: float = 0.1,
        samples_for_evaluation: int = 4,
        maximum_number_of_sampling_steps: int = 32,
        device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        """Initialize the conditional generator.
        Args:
            resources_path: directory where to find models and parameters.
            temperature: temperature parameter for the softmax sampling in decoding. Defaults to 1.4.
            generated_length: maximum length in tokens of the generated molcules (relates to the SMILES length). Defaults to 100.
            batch_size: batch size used for the generative model sampling. Defaults to 16.
            limit: hypercube limits in the latent space. Defaults to 5.0.
            acquisition_function: acquisition function used in the Gaussian process. Defaults to "EI". More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html.
            number_of_steps: number of steps for an optmization round. Defaults to 32.
            number_of_initial_points: number of initial points evaluated. Defaults to 16.
            initial_point_generator: scheme to generate initial points. Defaults to "random". More details in https://scikit-optimize.github.io/stable/modules/generated/skopt.gp_minimize.html.
            seed: seed used for random number generation in the optimizer. Defaults to 42.
            number_of_optimization_rounds: maximum number of optimization rounds. Defaults to 1.
            sampling_variance: variance of the Gaussian noise applied during sampling from the optimal point. Defaults to 0.1.
            samples_for_evaluation: number of samples averaged for each minimization function evaluation. Defaults to 4.
            maximum_number_of_sampling_steps: maximum number of sampling steps in an optmization round. Defaults to 32.
            device: . Defaults to None, a.k.a, picking a default one ("gpu" if present, "cpu" otherwise).
        """
        # device
        self.device = device_claim(device)
        # setting sampling parameters
        self.temperature = temperature
        self.generated_length = generated_length
        self.batch_size = batch_size
        # setting VAE parameters
        self.svae_params = dict()
        with open(os.path.join(resources_path, "vae_model_params.json"), "r") as f:
            self.svae_params.update(json.load(f))
        smiles_language = SMILESLanguage.load(
            os.path.join(resources_path, "selfies_language.pkl")
        )
        # initialize the encoder, the decoder and the teacher VAE used for SMILES generation
        self.gru_encoder = StackGRUEncoder(self.svae_params).to(self.device)
        self.gru_decoder = StackGRUDecoder(self.svae_params).to(self.device)
        self.gru_vae = TeacherVAE(self.gru_encoder, self.gru_decoder)
        self.gru_vae.load_state_dict(
            torch.load(
                os.path.join(resources_path, "vae_weights.pt"),
                map_location=self.device,
            )
        )
        self.gru_vae._associate_language(smiles_language)
        self.gru_vae.eval()
        self.smiles_generator = SmilesGenerator(
            self.gru_vae,
            search=SamplingSearch(temperature=self.temperature),
            generated_length=self.generated_length,
        )
        self.latent_dim = self.gru_decoder.latent_dim
        # setting affinity predictor parameters
        with open(os.path.join(resources_path, "mca_model_params.json")) as f:
            self.predictor_params = json.load(f)
        self.affinity_predictor = MODEL_FACTORY["bimodal_mca"](
            self.predictor_params
        ).to(self.device)
        self.affinity_predictor.load(
            os.path.join(resources_path, "mca_weights.pt"),
            map_location=self.device,
        )
        affinity_protein_language = ProteinLanguage.load(
            os.path.join(resources_path, "protein_language.pkl")
        )
        affinity_smiles_language = SMILESLanguage.load(
            os.path.join(resources_path, "smiles_language.pkl")
        )
        self.affinity_predictor._associate_language(affinity_smiles_language)
        self.affinity_predictor._associate_language(affinity_protein_language)
        self.affinity_predictor.eval()
        # setting optimizer parameters
        self.limit = limit
        self.acquisition_function = acquisition_function
        self.number_of_initial_points = number_of_initial_points
        if number_of_steps < self.number_of_initial_points:
            logger.warning(
                "number of initial points is larger than number of steps "
                f"({self.number_of_initial_points}/{number_of_steps}). "
                f"Resetting number of steps to {self.number_of_initial_points}."
            )
            self.number_of_steps = self.number_of_initial_points
        else:
            self.number_of_steps = number_of_steps
        self.initial_point_generator = initial_point_generator
        self.seed = None if seed == -1 else seed
        self.set_seed()
        self.number_of_optimization_rounds = number_of_optimization_rounds
        self.sampling_variance = sampling_variance
        self.samples_for_evaluation = samples_for_evaluation
        self.maximum_number_of_sampling_steps = maximum_number_of_sampling_steps 
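
    # A minimal construction sketch ("/path/to/resources" is a placeholder for
    # a directory containing the weight and language files loaded above):
    #
    #     generator = GPConditionalGenerator(resources_path="/path/to/resources")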

    def target_to_minimization_function(
        self, target: Union[Dict[str, Dict[str, Any]], str]
    ) -> CombinedMinimization:
        """Use the target to configure a minimization function.
        Args:
            target: dictionary or JSON string describing the optimization target.
        Returns:
            a minimization function.
        """
        if isinstance(target, str):
            target_dictionary = json.loads(target)
        elif isinstance(target, dict):
            target_dictionary = deepcopy(target)
        else:
            raise ValueError(
                f"{target} of type {type(target)} is not supported: provide 'str' or 'Dict[str, Dict[str, Any]]'"
            )
        minimization_functions = []
        weights = []
        for minimization_function_name, parameters in target_dictionary.items():
            weight = 1.0
            if "weight" in parameters:
                weight = parameters.pop("weight")
            function_parameters = {
                **parameters,
                **{
                    "batch_size": self.samples_for_evaluation,
                    "smiles_decoder": self.smiles_generator,
                },
            }
            minimization_function = MINIMIZATION_FUNCTIONS[minimization_function_name]
            if minimization_function_name == "affinity":
                function_parameters["affinity_predictor"] = self.affinity_predictor
            minimization_functions.append(minimization_function(**function_parameters))
            weights.append(weight)
        return CombinedMinimization(
            minimization_functions=minimization_functions,
            batch_size=1,
            function_weights=weights,
        ) 
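
    # For example (a sketch), a target weighting QED twice as much as SA:
    #
    #     function = generator.target_to_minimization_function(
    #         {"qed": {"weight": 2.0}, "sa": {"weight": 1.0}}
    #     )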

    def set_seed(self):
        """Set the seed for the random number generators."""
        if self.seed is None:
            return
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
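        # make cuDNN kernels deterministic: reproducible sampling at the cost of speed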
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False 

    def generate_batch(self, target: Any) -> List[str]:
        """Generate molecules given a target.
        Args:
            target: dictionary or JSON string describing the optimization target.
        Returns:
            a list of molecules as SMILES string.
        """
        # decorrelate the optimizer seed from the sampling seed via Knuth
        # multiplicative hashing (falling back to 42 when no seed is set)
        opt_seed = (self.seed or 42) * 2654435761 % 2**32
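        # e.g., with the default seed: 42 * 2654435761 % 2**32 == 4112119562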
        logger.info(f"configuring optimization for target: {target}")
        # target configuration
        self.target = target
        self.minimization_function = self.target_to_minimization_function(self.target)
        # optimizer configuration
        self.target_optimizer = GPOptimizer(self.minimization_function.evaluate)
        optimization_parameters = dict(
            dimensions=[(-self.limit, self.limit)] * self.latent_dim,
            acq_func=self.acquisition_function,
            n_calls=self.number_of_steps,
            n_initial_points=self.number_of_initial_points,
            initial_point_generator=self.initial_point_generator,
            random_state=opt_seed,
        )
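        # for readable logging, collapse the per-dimension bounds to their mean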
        log_params = deepcopy(optimization_parameters)
        log_params["dimensions"] = np.mean(
            optimization_parameters["dimensions"]
        )  # type:ignore
        logger.info(f"running optimization with the following parameters: {log_params}")
        smiles_set = set()
        logger.info(
            f"running at most {self.number_of_optimization_rounds} optimization rounds"
        )
        for optimization_round in range(self.number_of_optimization_rounds):
            logger.info(f"starting round {optimization_round + 1}")
            optimization_parameters["random_state"] += optimization_round  # type:ignore
            res = self.target_optimizer.optimize(optimization_parameters)
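            # res.x is the latent point minimizing the combined objective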
            latent_point = torch.tensor([[res.x]])
            smiles_set_per_round = set()
            logger.info(f"starting sampling for {optimization_round + 1}")
            for _ in range(self.maximum_number_of_sampling_steps):
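                # decode a batch around the optimum: one exact copy plus
                # batch_size - 1 copies perturbed with Gaussian noise of
                # variance sampling_variance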
                generated_smiles = self.smiles_generator.generate_smiles(
                    latent_point.repeat(1, self.batch_size, 1)
                    + torch.cat(
                        (
                            torch.zeros(1, 1, self.latent_dim),
                            (self.sampling_variance**0.5)
                            * torch.randn(1, self.batch_size - 1, self.latent_dim),
                        ),
                        dim=1,
                    )
                )
                smiles_set_per_round.update(set(generated_smiles))
            smiles_set.update(smiles_set_per_round)
            logger.info(f"completing round {optimization_round + 1}")
        # Sort the molecules to ensure reproducibility
        mols = sorted([s for s in smiles_set if s], key=len, reverse=True)
        logger.info(f"generated {len(mols)} molecules in the current run: {mols}")
        return mols
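

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the library API): the resources
    # path is a placeholder for a directory containing the VAE/MCA weights and
    # language files loaded in GPConditionalGenerator.__init__.
    generator = GPConditionalGenerator(
        resources_path="/path/to/resources",
        number_of_steps=8,
        number_of_initial_points=4,
        maximum_number_of_sampling_steps=2,
        seed=0,
    )
    # optimize drug-likeness and synthesizability with equal weights
    molecules = generator.generate_batch(
        {"qed": {"weight": 1.0}, "sa": {"weight": 1.0}}
    )
    print(molecules)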