#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Model module."""
import ast
import copy
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch_geometric.data as gd
from ..dataloader.dataset import GFlowNetDataset, GFlowNetTask
from ..envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext
# imports that have to be loaded before lightning to avoid segfaults;
# referencing the aliases below marks them as used so they are not stripped
_sentencepiece
_tensorflow
_torch
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class GFlowNetAlgorithm:
    """A generic algorithm for GFlowNet."""
    def compute_batch_losses(
        self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0
    ) -> Tuple[float, Dict[str, float]]:
        """Computes the loss for a batch of data, and proves logging informations.
        Args:
            model: the model being trained or evaluated.
            batch: a batch of graphs.
            num_bootstrap: the number of trajectories with reward targets in the batch (if applicable).
        Returns:
            loss: the loss for that batch.
            info: logged information about model predictions.
        """
        raise NotImplementedError()  
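    # A concrete algorithm (e.g. trajectory balance or a TD loss, as referenced
    # in GFlowNetModule below) is expected to override this method and return a
    # scalar loss tensor plus a dict of logged metrics. A minimal sketch of the
    # expected return value, with illustrative names only:
    #
    #     loss = per_trajectory_loss.mean()
    #     return loss, {"loss": loss.item()}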
class GFlowNetModule(pl.LightningModule):
    """Lightning module for GFlowNet."""
    def __init__(
        self,
        configuration: Dict[str, Any],
        dataset: GFlowNetDataset,
        environment: GraphBuildingEnv,
        context: GraphBuildingEnvContext,
        task: GFlowNetTask,
        algorithm: GFlowNetAlgorithm,
        model: nn.Module,
    ) -> None:
        """Construct GFNModule.
        Args:
            configuration: the configuration of the module.
            dataset: the dataset to use.
            environment: the environment to use.
            context: the context to use.
            task: the task to solve.
            algorithm: algorithm (trajectory_balance or td_loss).
            model: architecture (graph_transformer_gfn or graph_transformer).
        """
        super().__init__()
        self.hps = configuration
        # self.save_hyperparameters()
        self.env = environment
        self.ctx = context
        self.dataset = dataset
        self.model = model
        self.algo = algorithm
        self.task = task
        self.test_output_path = self.hps["test_output_path"]
        self.rng = self.hps["rng"]
        self.mb_size = self.hps["global_batch_size"]
        self.clip_grad_param = self.hps["clip_grad_param"]
        self.sampling_tau = self.hps["sampling_tau"] 
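    # The hyperparameters read above correspond to configuration entries of
    # roughly the following shape (a sketch; values are illustrative
    # placeholders, not recommended settings):
    #
    #     configuration = {
    #         "test_output_path": "/tmp/gflownet_test",
    #         "rng": np.random.default_rng(),  # assumed: a random generator instance
    #         "global_batch_size": 64,
    #         "clip_grad_param": 10.0,
    #         "sampling_tau": 0.99,  # EMA coefficient for the sampling model
    #     }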
    def training_step(  # type:ignore
        self,
        batch: gd.Batch,
        batch_idx: int,
        optimizer_idx: int,
        *args: Any,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Training step implementation.
        Args:
            batch: batch representation.
            epoch_idx: epoch index.
            batch_idx: batch index.
        Returns:
            loss and logs.
        """
        logs = dict()
        loss, info = self.algo.compute_batch_losses(
            self.model, batch, num_bootstrap=self.mb_size
        )
        logs.update(
            {
                self.model.name + f"/{k}": v.detach() if hasattr(v, "item") else v  # type: ignore
                for k, v in info.items()
            }
        )
        logs.update({"total_loss": loss.item()})  # type: ignore
        # logs for step
        _logs = {f"train/{k}": v for k, v in logs.items()}
        self.log_dict(_logs, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_loss", loss)
        # logs per epoch
        logs_epoch = {f"train_epoch/{k}": v for k, v in logs.items()}
        logs_epoch["step"] = self.current_epoch
        self.log_dict(
            logs_epoch,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
        )
        return {"loss": loss, "logs": logs} 
    def training_step_end(self, batch_parts):
        """Clip gradients and update the sampling model after each training step."""
        for parameter in self.model.parameters():
            self.clip_grad_callback(parameter)
        if self.sampling_tau > 0:
            for a, b in zip(self.model.parameters(), self.sampling_model.parameters()):
                b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))
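    # The sampling model update above is a Polyak-style exponential moving
    # average; with tau = sampling_tau, each sampling-model parameter b tracks
    # its online counterpart a as:
    #
    #     b <- tau * b + (1 - tau) * a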
    def validation_step(  # type:ignore
        self, batch: gd.Batch, batch_idx: int, *args: Any, **kwargs: Any
    ) -> Dict[str, Any]:
        """Validation step implementation.
        Args:
            batch: batch representation.
        Returns:
            loss and logs.
        """
        logs = dict()
        loss, info = self.algo.compute_batch_losses(
            self.model, batch, num_bootstrap=batch.num_offline
        )
        logs.update({k: v.item() if hasattr(v, "item") else v for k, v in info.items()})
        logs.update({"total_loss": loss})
        self.log_dict(
            {f"val/{k}": v for k, v in logs.items()},
            on_step=True,
            on_epoch=True,
            prog_bar=False,
        )
        self.log("val_loss", loss)
        return {"loss": loss, "logs": logs} 
    def test_step(  # type:ignore
        self, batch: Any, batch_idx: int
    ) -> Dict[str, Any]:
        """Testing step implementation.
        Args:
            batch: batch representation.
            batch_idx: batch index, unused.
        Returns:
            loss and logs.
        """
        logs = dict()
        loss, info = self.algo.compute_batch_losses(
            self.model, batch, num_bootstrap=batch.num_offline
        )
        logs.update({k: v.item() if hasattr(v, "item") else v for k, v in info.items()})
        logs.update({"total_loss": loss})
        self.log_dict(
            {f"test/{k}": v for k, v in logs.items()},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        self.log("test_loss", loss)
        return {"loss": loss, "logs": logs}
    def prediction_step(self, batch) -> torch.Tensor:
        """Inference step.
        Args:
            batch: batch data.
        Returns:
            the model forward output.
        """
        return self(batch)
    def train_epoch_end(self, outputs: List[Dict[str, Any]]):
        """Callback called at the end of a training epoch.
        Args:
            outputs: list of outputs for the training epoch batches.
        """
        pass
    # change the following to new implementation
    def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type:ignore
        """Callback called at the end of an epoch on test outputs.
        Dump encodings and targets for the test set.
        Args:
            outputs: outputs for test batches.
        """
        z = {}
        targets = {}
        z_keys = list(outputs[0]["z"])
        targets_keys = list(outputs[0]["targets"])
        for key in z_keys:
            z[key] = (
                torch.cat(
                    [torch.squeeze(an_output["z"][key]) for an_output in outputs], dim=0
                )
                .detach()
                .cpu()
                .numpy()
            )
        for key in targets_keys:
            targets[key] = (
                torch.cat(
                    [torch.squeeze(an_output["targets"][key]) for an_output in outputs],
                    dim=0,
                )
                .detach()
                .cpu()
                .numpy()
            )
        pd.to_pickle(z, os.path.join(self.test_output_path, "z_build.pkl"))
        pd.to_pickle(targets, os.path.join(self.test_output_path, "targets.pkl"))
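        # The dumped pickles can later be reloaded for analysis, e.g.:
        #     z = pd.read_pickle(os.path.join(self.test_output_path, "z_build.pkl"))
        #     targets = pd.read_pickle(os.path.join(self.test_output_path, "targets.pkl"))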