#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import copy
from itertools import count
from typing import Any, Dict, List, Tuple, Union, Optional
import numpy as np
import torch
import torch.nn as nn
import torch_geometric.data as gd
from torch import Tensor
from torch_scatter import scatter
from ..envs.graph_building_env import (
    Graph,
    GraphAction,
    GraphActionCategorical,
    GraphActionType,
    GraphBuildingEnv,
    GraphBuildingEnvContext,
    generate_forward_trajectory,
)


class TrajectoryBalanceModel(nn.Module):
    """Generic model compatible with trajectory balance."""

    def forward(self, batch: gd.Batch) -> Tuple[GraphActionCategorical, Tensor]:
        """Run forward pass.

        Args:
            batch: batch of data.

        Returns:
            a tuple of a categorical over graph actions and per-graph log-reward predictions.
        """
        raise NotImplementedError()

    def log_z(self, cond_info: Tensor) -> Tensor:
        """Compute log_z.

        Args:
            cond_info: conditional information.

        Returns:
            log partition function.
        """
        raise NotImplementedError()
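
# NOTE: although the abstract `forward` above is declared with a single `batch`
# argument, `TrajectoryBalance` below invokes the model as
# `model(batch, cond_info)` (see `create_training_data_from_own_samples` and
# `compute_batch_losses`) and expects it to return a `GraphActionCategorical`
# plus per-graph log-reward predictions. A concrete subclass could look like
# the following sketch (layer names are illustrative assumptions, not part of
# this module):
#
#     class MyTBModel(TrajectoryBalanceModel):
#         def __init__(self, num_cond_dim: int):
#             super().__init__()
#             self._log_z = nn.Linear(num_cond_dim, 1)
#             ...  # graph network producing logits for each action type
#
#         def log_z(self, cond_info: Tensor) -> Tensor:
#             return self._log_z(cond_info)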


class TrajectoryBalance:
    """A trajectory balance algorithm for GFlowNets."""

    def __init__(
        self,
        configuration: Dict[str, Any],
        environment: GraphBuildingEnv,
        context: GraphBuildingEnvContext,
        max_len: Optional[int] = None,
    ):
        """Initialize trajectory balance algorithm.
        Trajectory balance implementation, see
        "Trajectory Balance: Improved Credit Assignment in GFlowNets
            Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio"
            https://arxiv.org/abs/2201.13259.
        Code adapted from: https://github.com/recursionpharma/gflownet/blob/trunk/src/gflownet/algo/trajectory_balance.py.
        Args
            configuration: hyperparameters.
            environment: a graph environment.
            context: a context.
            rng: rng used to take random actions.
            hps: hyperparameter dictionary, see above for used keys.
                - random_action_prob: float, probability of taking a uniform random action when sampling.
                - illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions.
                - bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data.
                - tb_epsilon: float, if not None, adds this epsilon in the numerator and denominator of the log-ratio.
                - reward_loss_multiplier: float, multiplying constant for the bootstrap loss.
            max_len: if not None, ends trajectories of more than max_len steps.
            max_nodes: if not None, ends trajectories of graphs with more than max_nodes steps (illegal action).
        """
        self.ctx = context
        self.env = environment
        self.hps = configuration
        self.rng = self.hps["rng"]
        self.max_len = max_len
        self.max_nodes = self.hps["max_nodes"]
        self.random_action_prob = self.hps["random_action_prob"]
        self.illegal_action_logreward = self.hps["illegal_action_logreward"]
        self.bootstrap_own_reward = self.hps["bootstrap_own_reward"]
        self.sanitize_samples = True
        self.epsilon = self.hps["tb_epsilon"]
        self.reward_loss_multiplier = self.hps["reward_loss_multiplier"]
        # Experimental flags
        self.reward_loss_is_mae = True
        self.tb_loss_is_mae = False
        self.tb_loss_is_huber = False
        self.mask_invalid_rewards = False
        self.length_normalize_losses = False
        self.sample_temp = 1 
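
    # A minimal `configuration` sketch for the constructor above. The keys match
    # the lookups in `__init__`; the values are illustrative assumptions only:
    #
    #     configuration = {
    #         "rng": np.random.default_rng(0),
    #         "max_nodes": 9,
    #         "random_action_prob": 0.01,
    #         "illegal_action_logreward": -75.0,
    #         "bootstrap_own_reward": False,
    #         "tb_epsilon": None,
    #         "reward_loss_multiplier": 1.0,
    #     }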

    def _corrupt_actions(
        self, actions: List[Tuple[int, int, int]], cat: GraphActionCategorical
    ):
        """Sample from the uniform policy with probability random_action_prob.

        Args:
            actions: list of actions.
            cat: action categorical.
        """
        if self.random_action_prob <= 0:
            return
        (corrupted,) = (
            self.rng.uniform(size=len(actions)) < self.random_action_prob
        ).nonzero()
        for i in corrupted:
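            # Resample this action uniformly over all concrete actions available for
            # graph i: pick an action type with probability proportional to its number
            # of rows in this graph times its number of columns, then a uniform row
            # and a uniform column within that type.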
            n_in_batch = [int((b == i).sum()) for b in cat.batch]
            n_each = np.array(
                [
                    float(logit.shape[1]) * nb
                    for logit, nb in zip(cat.logits, n_in_batch)
                ]
            )
            which = self.rng.choice(len(n_each), p=n_each / n_each.sum())
            row = self.rng.choice(n_in_batch[which])
            col = self.rng.choice(cat.logits[which].shape[1])
            actions[i] = (which, row, col) 

    def create_training_data_from_own_samples(
        self, model: Union[nn.Module, TrajectoryBalanceModel], n: int, cond_info: Tensor
    ) -> List[Dict]:
        """Generate trajectories by sampling a model.

        Args:
            model: model used with a certain algorithm (i.e. trajectory balance). The model being sampled.
            n: number of trajectories to sample.
            cond_info: conditional information, shape (n, n_info).

        Returns:
            data: a list of trajectories. Each trajectory is a dict with keys:
                - traj: List[Tuple[Graph, GraphAction]].
                - reward_pred: exp(illegal_action_logreward) if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise.
                - fwd_logprob: sum of logprobs P_F.
                - bck_logprob: sum of logprobs P_B.
                - log_z: predicted log_z.
                - loss: predicted loss (if bootstrapping).
                - is_valid: is the generated graph valid according to the environment and context.
        """
        ctx = self.ctx
        env = self.env
        dev = self.ctx.device
        cond_info = cond_info.to(dev)
        # how do we compute log_z pred?
        log_z_pred = model.log_z(cond_info)  # type: ignore
        # This will be returned as training data
        data: List[Dict] = []
        for i in range(n):
            data.append({"traj": [], "reward_pred": None, "is_valid": True})
        # Let's also keep track of trajectory statistics according to the model
        zero = torch.tensor([0], device=dev).float()
        fwd_logprob: List[List[Tensor]] = [[] for i in range(n)]
        bck_logprob: List[List[Tensor]] = [
            [zero] for i in range(n)
        ]  # zero in case there is a single invalid action
        graphs = [env.new() for i in range(n)]
        done = [False] * n
        def not_done(lst):
            return [e for i, e in enumerate(lst) if not done[i]]
        # TODO report these stats:
        mol_too_big = 0
        mol_not_sane = 0
        invalid_act = 0
        logprob_of_illegal = []
        illegal_action_logreward = torch.tensor(
            [self.illegal_action_logreward], device=dev
        )
        if self.epsilon is not None:
            epsilon = torch.tensor([self.epsilon], device=dev).float()
        for t in range(self.max_len) if self.max_len is not None else count(0):
            # Construct graphs for the trajectories that aren't yet done
            torch_graphs = [ctx.graph_to_data(i) for i in not_done(graphs)]
            not_done_mask = torch.tensor(done, device=dev).logical_not()
            # Forward pass to get GraphActionCategorical
            fwd_cat, log_reward_preds = model(
                ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]
            )
            if self.sample_temp != 1:
                sample_cat = copy.copy(fwd_cat)
                sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
                actions = sample_cat.sample()
            else:
                actions = fwd_cat.sample()
            self._corrupt_actions(actions, fwd_cat)
            graph_actions = [
                ctx.aidx_to_graph_action(g, a) for g, a in zip(torch_graphs, actions)
            ]
            log_probs = fwd_cat.log_prob(actions)
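            # Log-probabilities are evaluated under the model's own categorical
            # (fwd_cat, not the temperature-scaled copy), so the recorded P_F is the
            # model's probability of the chosen action even when that action was
            # resampled randomly above.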
            for i, j in zip(not_done(range(n)), range(n)):
                # Step each trajectory, and accumulate statistics
                fwd_logprob[i].append(log_probs[j].unsqueeze(0))
                data[i]["traj"].append((graphs[i], graph_actions[j]))
                # Check if we're done
                if graph_actions[j].action is GraphActionType.Stop:
                    done[i] = True
                    if self.sanitize_samples and not ctx.is_sane(graphs[i]):
                        # check if the graph is sane (e.g. RDKit can
                        # construct a molecule from it) otherwise
                        # treat the done action as illegal
                        mol_not_sane += 1
                        data[i]["reward_pred"] = illegal_action_logreward.exp()
                        data[i]["is_valid"] = False
                    elif self.bootstrap_own_reward:
                        # if we're bootstrapping, extract reward prediction
                        data[i]["reward_pred"] = log_reward_preds[j].detach().exp()
                else:  # If not done, try to step the environment
                    gp = graphs[i]
                    try:
                        # env.step can raise AssertionError if the action is illegal
                        gp = env.step(graphs[i], graph_actions[j])
                        if self.max_nodes is not None:
                            assert len(gp.nodes) <= self.max_nodes
                    except AssertionError:
                        if self.max_nodes is not None and len(gp.nodes) > self.max_nodes:
                            mol_too_big += 1
                        else:
                            invalid_act += 1
                        done[i] = True
                        data[i]["reward_pred"] = illegal_action_logreward.exp()
                        data[i]["is_valid"] = False
                        continue
                    # Add to the trajectory
                    # P_B = uniform backward
                    n_back = env.count_backward_transitions(gp)
                    bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log())
                    graphs[i] = gp
            if all(done):
                break
        for i in range(n):
            # If we're not bootstrapping, we could query the reward
            # model here, but this is expensive/impractical.  Instead
            # just report forward and backward flows
            data[i]["log_z"] = log_z_pred[i].item()
            data[i]["fwd_logprob"] = sum(fwd_logprob[i])
            data[i]["bck_logprob"] = sum(bck_logprob[i])
            if self.bootstrap_own_reward:
                if not data[i]["is_valid"]:
                    logprob_of_illegal.append(data[i]["fwd_logprob"].item())
                # If we are bootstrapping, we can report the theoretical loss as well
                numerator = data[i]["fwd_logprob"] + log_z_pred[i]
                denominator = data[i]["bck_logprob"] + data[i]["reward_pred"].log()
                if self.epsilon is not None:
                    numerator = torch.logaddexp(numerator, epsilon)
                    denominator = torch.logaddexp(denominator, epsilon)
                data[i]["loss"] = (numerator - denominator).pow(2)
        return data 

    def create_training_data_from_graphs(
        self, graphs: List[Graph]
    ) -> List[Dict[str, List[Tuple[Graph, GraphAction]]]]:
        """Generate trajectories from known endpoints.

        Args:
            graphs: list of Graph endpoints.

        Returns:
            a list of trajectories.
        """
        return [{"traj": generate_forward_trajectory(i)} for i in graphs]

    def construct_batch(
        self,
        trajs: List[Dict[str, List[Tuple[Graph, GraphAction]]]],
        cond_info: Tensor,
        rewards: Tensor,
    ) -> gd.Batch:
        """Construct a batch from a list of trajectories and their information.

        Args:
            trajs: a list of N trajectories.
            cond_info: the conditional info that is considered for each trajectory. Shape (N, n_info).
            rewards: the transformed reward (e.g. R(x) ** beta) for each trajectory. Shape (N,).

        Returns:
            a (CPU) Batch object with relevant attributes added.
        """
        torch_graphs = [
            self.ctx.graph_to_data(i[0]) for tj in trajs for i in tj["traj"]
        ]
        actions = [
            self.ctx.graph_action_to_aidx(g, a)
            for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
        ]
        num_backward = torch.tensor(
            [
                # Count the number of backward transitions from s_{t+1},
                # unless t+1 = T is the last time step
                self.env.count_backward_transitions(tj["traj"][i + 1][0])
                if i + 1 < len(tj["traj"])
                else 1
                for tj in trajs
                for i in range(len(tj["traj"]))
            ]
        )
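        # The `else 1` branch covers the last step of each trajectory (normally the
        # Stop action): its backward probability is 1, so log(1 / 1) = 0 and it
        # contributes nothing to the summed log P_B.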
        batch = self.ctx.collate(torch_graphs)
        batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
        batch.num_backward = num_backward
        batch.actions = torch.tensor(actions)
        batch.rewards = rewards
        batch.cond_info = cond_info
        batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float()
        return batch 

    def compute_batch_losses(
        self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the losses over trajectories contained in the batch.

        Args:
            model: a GNN taking in a batch of graphs as constructed by self.construct_batch.
                Must have a `log_z` attribute, itself a model, which predicts log of z(cond_info).
            batch: batch of graph inputs as constructed by self.construct_batch.
            num_bootstrap: the number of trajectories for which the reward loss is computed. Ignored if 0.

        Returns:
            a tuple of the scalar batch loss and a dict of logging information.
        """
        dev = batch.x.device
        # A single trajectory is comprised of many graphs
        num_trajs = int(batch.traj_lens.shape[0])
        rewards = batch.rewards
        cond_info = batch.cond_info
        # This index says which trajectory each graph belongs to, so
        # it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is
        # of length 4, trajectory 1 of length 3, and so on.
        batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(
            batch.traj_lens
        )
        # The position of the last graph of each trajectory
        final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1
        # Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions
        fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx])
        # Retrieve the reward predictions for the full graphs,
        # i.e. the final graph of each trajectory
        log_reward_preds = log_reward_preds[final_graph_idx, 0]
        # Compute trajectory balance objective
        logz = model.log_z(cond_info)[:, 0]
        # This is the log prob of each action in the trajectory
        log_prob = fwd_cat.log_prob(batch.actions)
        # The log prob of each backward action
        log_p_B = (1 / batch.num_backward).log()
        # Take log rewards, and clip
        Rp = torch.maximum(rewards.log(), torch.tensor(-100.0, device=dev))
        # This is the log probability of each trajectory
        traj_log_prob = scatter(
            log_prob, batch_idx, dim=0, dim_size=num_trajs, reduce="sum"
        )
        # Compute log numerator and denominator of the TB objective
        numerator = logz + traj_log_prob
        denominator = Rp + scatter(
            log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum"
        )
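        # Per trajectory tau ending in x, the TB objective below is
        #   L(tau) = (log Z(c) + sum_t log P_F(s_{t+1}|s_t, c)
        #             - log R(x|c) - sum_t log P_B(s_t|s_{t+1}))^2
        # i.e. the squared difference between `numerator` and `denominator`
        # (or an absolute/Huber variant, depending on the flags below).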
        if self.epsilon is not None:
            # Numerical stability epsilon
            epsilon = torch.tensor([self.epsilon], device=dev).float()
            numerator = torch.logaddexp(numerator, epsilon)
            denominator = torch.logaddexp(denominator, epsilon)
        invalid_mask = 1 - batch.is_valid
        if self.mask_invalid_rewards:
            # Instead of being rude to the model and giving a
            # logreward of -100, say: whatever the model thinks the
            # log-probability of this trajectory is, it should be smaller
            # (thus the `numerator - 1`). Why 1? Intuition?
            denominator = denominator * (1 - invalid_mask) + invalid_mask * (
                numerator.detach() - 1
            )
        if self.tb_loss_is_mae:
            traj_losses = abs(numerator - denominator)
        elif self.tb_loss_is_huber:
            # Experimental: Huber (smooth L1) variant of the TB loss.
            traj_losses = nn.functional.huber_loss(
                numerator, denominator, reduction="none"
            )
        else:
            traj_losses = (numerator - denominator).pow(2)
        # Normalize losses by trajectory length
        if self.length_normalize_losses:
            traj_losses = traj_losses / batch.traj_lens
        if self.bootstrap_own_reward:
            num_bootstrap = num_bootstrap or len(rewards)
            if self.reward_loss_is_mae:
                reward_losses = abs(
                    rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap].exp()
                )
            else:
                reward_losses = (
                    rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap].exp()
                ).pow(2)
            reward_loss = reward_losses.mean()
        else:
            reward_loss = 0
        loss = traj_losses.mean() + reward_loss * self.reward_loss_multiplier
        info = {
            "offline_loss": traj_losses[: batch.num_offline].mean(),
            "online_loss": traj_losses[batch.num_offline :].mean(),
            "reward_loss": reward_loss,
            "invalid_trajectories": invalid_mask.mean() * 2,
            "invalid_logprob": (invalid_mask * traj_log_prob).sum()
            / (invalid_mask.sum() + 1e-4),
            "invalid_losses": (invalid_mask * traj_losses).sum()
            / (invalid_mask.sum() + 1e-4),
            "log_z": logz.mean(),
        }
        if not torch.isfinite(traj_losses).all():
            raise ValueError("loss is not finite")
        return loss, info
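
# Typical training-step sketch (the trainer wiring and `compute_rewards` are
# assumptions; note that `batch.num_offline`, used in the logging dict above,
# is expected to be set by the caller):
#
#     algo = TrajectoryBalance(configuration, env, ctx, max_len=128)
#     trajs = algo.create_training_data_from_own_samples(model, n, cond_info)
#     rewards = compute_rewards(trajs)  # hypothetical reward function
#     batch = algo.construct_batch(trajs, cond_info, rewards)
#     batch.num_offline = num_offline
#     loss, info = algo.compute_batch_losses(model, batch)
#     loss.backward()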