#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Data module for granular."""
import logging
from typing import Callable, List, Optional, cast
import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Sampler, Subset, random_split
from .dataset import CombinedGranularDataset, GranularDataset
from .sampler import StratifiedSampler
# imports that have to be loaded before lightning to avoid segfaults
_sentencepiece
_tensorflow
_torch
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
[docs]class GranularDataModule(pl.LightningDataModule):
    """Data module from granular."""
[docs]    def __init__(
        self,
        dataset_list: List[GranularDataset],
        validation_split: Optional[float] = None,
        validation_indices_file: Optional[str] = None,
        stratified_batch_file: Optional[str] = None,
        stratified_value_name: Optional[str] = None,
        batch_size: int = 64,
        num_workers: int = 1,
    ) -> None:
        """Construct GranularDataModule.
        Args:
            dataset_list: a list of granular datasets.
            validation_split: proportion used for validation. Defaults to None,
                a.k.a., use indices file if provided otherwise uses half of the data for validation.
            validation_indices_file: indices to use for validation. Defaults to None, a.k.a.,
                use validation split proportion, if not provided uses half of the data for validation.
            stratified_batch_file: stratified batch file for sampling. Defaults to None, a.k.a.,
                no stratified sampling.
            stratified_value_name: stratified value name. Defaults to None, a.k.a.,
                no stratified sampling. Needed in case a stratified batch file is provided.
            batch_size: batch size. Defaults to 64.
            num_workers: number of workers. Defaults to 1.
        """
        super().__init__()
        self.batch_size = batch_size
        self.validation_split = validation_split
        self.validation_indices_file = validation_indices_file
        self.dataset_list = dataset_list
        self.num_workers = num_workers
        self.stratified_batch_file = stratified_batch_file
        self.stratified_value_name = stratified_value_name
        self.prepare_train_data() 
[docs]    @staticmethod
    def combine_datasets(
        dataset_list: List[GranularDataset],
    ) -> CombinedGranularDataset:
        """Combine granular datasets.
        Args:
            dataset_list: a list of granular datasets.
        Returns:
            a combined granular dataset.
        """
        return CombinedGranularDataset(
            [a_dataset.dataset for a_dataset in dataset_list]
        ) 
[docs]    def prepare_train_data(self) -> None:
        """Prepare training dataset."""
        self.train_dataset = GranularDataModule.combine_datasets(self.dataset_list) 
[docs]    def prepare_test_data(self, dataset_list: List[GranularDataset]) -> None:
        """Prepare testing dataset.
        Args:
            dataset_list: a list of granular datasets.
        """
        self.test_dataset = GranularDataModule.combine_datasets(dataset_list) 
[docs]    def setup(self, stage: Optional[str] = None) -> None:
        """Setup the data module.
        Args:
            stage: stage considered, unused. Defaults to None.
        """
        if (
            self.stratified_batch_file is not None
            and self.stratified_value_name is None
        ):
            raise ValueError(
                f"stratified_batch_file={self.stratified_batch_file}, need to set stratified_value_name as well"
            )
        if self.validation_indices_file is None and self.validation_split is None:
            self.validation_split = 0.5
        if self.validation_indices_file:
            val_indices = (
                pd.read_csv(self.validation_indices_file).values.flatten().tolist()
            )
            train_indices = [
                i for i in range(len(self.train_dataset)) if i not in val_indices
            ]
            self.train_data = Subset(self.train_dataset, train_indices)
            self.val_data = Subset(self.train_dataset, val_indices)
        else:
            val = int(len(self.train_dataset) * cast(float, (self.validation_split)))
            train = len(self.train_dataset) - val
            self.train_data, self.val_data = random_split(
                self.train_dataset, [train, val]
            )
        logger.info(f"number of data points used for training: {len(self.train_data)}")
        logger.info(f"number of data points used for validation: {len(self.val_data)}")
        logger.info(
            f"validation proportion: {len(self.val_data) / (len(self.val_data) + len(self.train_data))}"
        ) 
[docs]    @staticmethod
    def get_stratified_batch_sampler(
        stratified_batch_file: str,
        stratified_value_name: str,
        batch_size: int,
        selector_fn: Callable[[pd.DataFrame], pd.DataFrame],
    ) -> StratifiedSampler:
        """Get stratified batch sampler.
        Args:
            stratified_batch_file: stratified batch file for sampling.
            stratified_value_name: stratified value name.
            batch_size: batch size.
            selector_fn: selector function for stratified sampling.
        Returns:
            a stratified batch sampler.
        """
        stratified_batch_dataframe = pd.read_csv(stratified_batch_file)
        stratified_data = stratified_batch_dataframe[
            selector_fn(stratified_batch_dataframe)
        ][stratified_value_name].values
        stratified_data_tensor = torch.from_numpy(stratified_data)
        return StratifiedSampler(targets=stratified_data_tensor, batch_size=batch_size) 
[docs]    def train_dataloader(self) -> DataLoader:
        """Get a training data loader.
        Returns:
            a training data loader.
        """
        sampler: Optional[Sampler] = None
        if self.stratified_batch_file:
            sampler = GranularDataModule.get_stratified_batch_sampler(
                stratified_batch_file=self.stratified_batch_file,
                stratified_value_name=str(self.stratified_value_name),
                batch_size=self.batch_size,
                selector_fn=lambda dataframe: ~dataframe["validation"],
            )
        return DataLoader(
            self.train_data,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            pin_memory=False,
            sampler=sampler,
        ) 
[docs]    def val_dataloader(self) -> DataLoader:
        """Get a validation data loader.
        Returns:
            a validation data loader.
        """
        sampler: Optional[Sampler] = None
        if self.stratified_batch_file:
            sampler = GranularDataModule.get_stratified_batch_sampler(
                stratified_batch_file=self.stratified_batch_file,
                stratified_value_name=str(self.stratified_value_name),
                batch_size=self.batch_size,
                selector_fn=lambda dataframe: dataframe["validation"],
            )
        return DataLoader(
            self.val_data,
            num_workers=self.num_workers,
            batch_size=self.batch_size,
            pin_memory=False,
            sampler=sampler,
        ) 
    def test_dataloader(self) -> DataLoader:
        """Get a testing data loader.
        Returns:
            a testing data loader.
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=False,
        )