#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Model combiner module."""
import os
from typing import Any, Callable, Dict, List, Tuple, cast
import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch
from .models import GranularBaseModel, GranularEncoderDecoderModel
from .models.model_builder import building_models, define_latent_models_input_size
# imports that have to be loaded before lightning to avoid segfaults
# Each bare name below is an expression statement that evaluates the alias
# bound by the corresponding import and discards the result; it has no other
# effect. Presumably this keeps the otherwise-unused imports from being flagged
# or removed by linters/formatters (TODO confirm).
_sentencepiece
_tensorflow
_torch
class GranularModule(pl.LightningModule):
    """Module from granular.

    Holds a list of autoencoders and a list of latent models. On every step
    each autoencoder is run on the batch, its latent encoding is stored under
    the autoencoder's ``position``, and each latent model is then run on the
    concatenation of the encodings listed in its ``from_position``. The losses
    of all models are summed into a single total loss.
    """

    def __init__(
        self,
        architecture_autoencoders: List[Dict[str, Any]],
        architecture_latent_models: List[Dict[str, Any]],
        lr: float = 1e-4,
        test_output_path: str = "./test",
        **kwargs,
    ) -> None:
        """Construct GranularModule.

        Args:
            architecture_autoencoders: list of autoencoder architecture configurations.
            architecture_latent_models: list of latent model architecture configurations.
            lr: learning rate for Adam optimizer. Defaults to 1e-4. Only stored on
                ``self.lr`` here.
            test_output_path: path where to save latent encodings and targets for the test set
                when an epoch ends. Defaults to a folder called "test" in the current working directory.
            **kwargs: extra hyperparameters; not used directly by this constructor.
        """
        super().__init__()
        # Must be called from __init__ itself: it collects the constructor arguments.
        self.save_hyperparameters()

        architecture_latent_models = define_latent_models_input_size(
            architecture_autoencoders, architecture_latent_models
        )
        self.architecture_autoencoders = architecture_autoencoders
        self.architecture_latent_models = architecture_latent_models
        self.autoencoders = building_models(self.architecture_autoencoders)
        self.latent_models = building_models(self.architecture_latent_models)
        self.lr = lr
        self.test_output_path = test_output_path

        # Expose every model as an attribute named after the model.
        for model in self.autoencoders + self.latent_models:
            setattr(self, model.name, model)

    def _autoencoder_step(
        self, batch: Any, model: GranularEncoderDecoderModel, model_step_fn: Callable
    ) -> Tuple[Any, Any, Any]:
        """Autoencoder module forward pass.

        Args:
            batch: batch representation.
            model: a module.
            model_step_fn: callable for the step.

        Returns:
            a tuple containing the latent representation, the loss and the logs for the module.
        """
        return model_step_fn(
            input_data=batch[model.input_key],
            target_data=batch[model.target_key],
            device=self.device,
            current_epoch=self.current_epoch,
        )

    @staticmethod
    def _flatten_latent(encoding: torch.Tensor) -> torch.Tensor:
        """Drop singleton axes of a 3D encoding without losing the batch axis.

        Args:
            encoding: a latent encoding.

        Returns:
            the encoding unchanged when it is not 3D, otherwise the encoding with its
            singleton axes removed but with at least two dimensions left.
        """
        if encoding.dim() != 3:
            return encoding
        squeezed = torch.squeeze(encoding)
        if squeezed.dim() < 2:
            # A plain squeeze also removed the batch and/or feature axis (e.g. a
            # single-sample batch), which would break the dim=1 concatenation.
            # Assumes the last axis holds the features -- TODO confirm.
            squeezed = encoding.reshape(-1, encoding.size(-1))
        return squeezed

    def _latent_step(
        self,
        batch: Any,
        model: GranularBaseModel,
        model_step_fn: Callable,
        z: Dict[int, Any],
    ) -> Tuple[Any, Any, Any]:
        """Latent module forward pass.

        Args:
            batch: batch representation.
            model: a module.
            model_step_fn: callable for the step.
            z: latent encodings.

        Returns:
            a tuple containing the latent step output, the loss and the logs for the module.
        """
        z_model_input = torch.cat(
            [self._flatten_latent(z[pos]) for pos in model.from_position],
            dim=1,
        )
        return model_step_fn(
            input_data=z_model_input,
            target_data=batch[model.target_key],
            device=self.device,
            current_epoch=self.current_epoch,
        )

    def _run_models(
        self, batch: Any, step_fn_name: str
    ) -> Tuple[Any, Dict[str, Any], Dict[int, Any]]:
        """Run all autoencoders and then all latent models on a batch.

        Args:
            batch: batch representation.
            step_fn_name: name of the model method used as step function
                ("step" or "val_step").

        Returns:
            a tuple containing the total loss, the logs (including "total_loss")
            and the latent encodings keyed by autoencoder position.
        """
        loss: Any = 0.0
        z: Dict[int, Any] = {}
        logs: Dict[str, Any] = {}
        # Autoencoders go first: latent models consume their encodings.
        for model in self.autoencoders:
            z[model.position], loss_model, logs_model = self._autoencoder_step(
                batch=batch,
                model=cast(GranularEncoderDecoderModel, model),
                model_step_fn=getattr(model, step_fn_name),
            )
            logs.update({model.name + f"/{k}": v for k, v in logs_model.items()})
            loss += loss_model
        for model in self.latent_models:
            _, loss_model, logs_model = self._latent_step(
                batch=batch,
                model=model,
                model_step_fn=getattr(model, step_fn_name),
                z=z,
            )
            logs.update({model.name + f"/{k}": v for k, v in logs_model.items()})
            loss += loss_model
        logs.update({"total_loss": loss})
        return loss, logs, z

    def training_step(  # type:ignore
        self, batch: Any, *args, **kwargs
    ) -> Dict[str, Any]:
        """Training step implementation.

        Args:
            batch: batch representation.

        Returns:
            loss and logs.
        """
        loss, logs, _ = self._run_models(batch, "step")
        # Per-step values.
        self.log_dict(
            {f"train/{k}": v for k, v in logs.items()}, on_epoch=False, prog_bar=False
        )
        # Epoch-aggregated values, logged together with the current epoch as "step".
        logs_epoch = {f"train_epoch/{k}": v for k, v in logs.items()}
        logs_epoch["step"] = self.current_epoch
        self.log_dict(logs_epoch, on_step=False, on_epoch=True, prog_bar=False)
        return {"loss": loss, "logs": logs}

    def validation_step(  # type:ignore
        self, batch: Any, *args, **kwargs
    ) -> Dict[str, Any]:
        """Validation step implementation.

        Args:
            batch: batch representation.

        Returns:
            loss and logs.
        """
        loss, logs, _ = self._run_models(batch, "val_step")
        self.log_dict(
            {f"val/{k}": v for k, v in logs.items()}, on_epoch=True, prog_bar=True
        )
        return {"loss": loss, "logs": logs}

    def test_step(  # type:ignore
        self, batch: Any, batch_idx: int, *args, **kwargs
    ) -> Dict[str, Any]:
        """Testing step implementation.

        Args:
            batch: batch representation.
            batch_idx: batch index, unused.

        Returns:
            loss, logs, latent encodings and targets.
        """
        loss, logs, z = self._run_models(batch, "val_step")
        # test_epoch_end reads a "targets" entry from every output, so it has
        # to be produced here. Targets are keyed by model name.
        targets = {
            model.name: batch[model.target_key]
            for model in self.autoencoders + self.latent_models
        }
        # NOTE(review): test metrics are logged under the "val/" prefix; kept
        # as-is because the metric names may be consumed elsewhere -- confirm
        # whether "test/" was intended.
        self.log_dict(
            {f"val/{k}": v for k, v in logs.items()}, on_epoch=True, prog_bar=True
        )
        return {"loss": loss, "logs": logs, "z": z, "targets": targets}

    @staticmethod
    def _stack_batches(tensors: List[torch.Tensor]) -> Any:
        """Squeeze per-batch tensors and concatenate them along the first axis.

        Args:
            tensors: one tensor per test batch.

        Returns:
            a numpy array with the concatenated values.
        """
        squeezed = [torch.squeeze(tensor) for tensor in tensors]
        # A single-sample batch loses its batch axis when squeezed; put it back
        # so every batch has the same rank (and at least one axis) before
        # concatenating. If every batch has a single sample there is no
        # reference rank and the values are concatenated as squeezed.
        rank = max(1, max(tensor.dim() for tensor in squeezed))
        aligned = [
            tensor.reshape((1,) * (rank - tensor.dim()) + tuple(tensor.shape))
            for tensor in squeezed
        ]
        return torch.cat(aligned, dim=0).detach().cpu().numpy()

    def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type:ignore
        """Callback called at the end of an epoch on test outputs.

        Dump encodings and targets for the test set as "z_build.pkl" and
        "targets.pkl" inside ``test_output_path``.

        Args:
            outputs: outputs for test batches.
        """
        if not outputs:
            # Nothing was tested, so there is nothing to dump.
            return
        z = {
            key: self._stack_batches([an_output["z"][key] for an_output in outputs])
            for key in outputs[0]["z"]
        }
        targets = {
            key: self._stack_batches(
                [an_output["targets"][key] for an_output in outputs]
            )
            for key in outputs[0].get("targets", {})
        }
        if self.test_output_path:
            # The output folder (default "./test") is not guaranteed to exist.
            os.makedirs(self.test_output_path, exist_ok=True)
        pd.to_pickle(z, os.path.join(self.test_output_path, "z_build.pkl"))
        pd.to_pickle(targets, os.path.join(self.test_output_path, "targets.pkl"))