#!/usr/bin/env python
#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Run training pipelines for the GT4SD."""
import logging
import sys
from dataclasses import dataclass, field
from typing import IO, Iterable, Optional, Tuple, cast

from ..configuration import GT4SDConfiguration
from ..training_pipelines import (
    TRAINING_PIPELINE_ARGUMENTS_MAPPING,
    TRAINING_PIPELINE_MAPPING,
)
from ..training_pipelines.core import TrainingPipelineArguments
from .argument_parser import ArgumentParser, DataClass, DataClassType

logger = logging.getLogger(__name__)

SUPPORTED_TRAINING_PIPELINES = sorted(
    list(set(TRAINING_PIPELINE_ARGUMENTS_MAPPING) & set(TRAINING_PIPELINE_MAPPING))
)

# Disable cuDNN in case of issues with GPU training.
if GT4SDConfiguration.get_instance().gt4sd_disable_cudnn:
    import torch

    torch.backends.cudnn.enabled = False


@dataclass
class TrainerArguments:
    """Trainer arguments."""
    __name__ = "trainer_base_args"
    training_pipeline_name: str = field(
        metadata={
            "help": f"Training pipeline name, supported pipelines: {', '.join(SUPPORTED_TRAINING_PIPELINES)}."
        },
    )
    configuration_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "Configuration file for the training in JSON format. It can be used to completely bypass pipeline-specific arguments."
        },
    )
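
# Illustrative command line usage of these base arguments (the ``gt4sd-trainer``
# entry point name is an assumption and "granular-trainer" is just an example;
# see SUPPORTED_TRAINING_PIPELINES for the actual options):
#   gt4sd-trainer --training_pipeline_name granular-trainer --help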


class TrainerArgumentParser(ArgumentParser):
    """Argument parser using a custom help logic."""

    def print_help(self, file: Optional[IO[str]] = None) -> None:
        """Print help, dynamically checking whether a specific pipeline is passed.

        Args:
            file: an optional I/O stream. Defaults to None, i.e., stdout and stderr.
        """
        try:
            help_args_set = {"-h", "--help"}
            if (
                len(set(sys.argv).union(help_args_set)) < len(help_args_set) + 2
            ):  # only the script name and help flags were passed: print the base help
                super().print_help()
                return
            args = [arg for arg in sys.argv if arg not in help_args_set]
            parsed_arguments = super().parse_args_into_dataclasses(
                args=args, return_remaining_strings=True
            )
            trainer_arguments = None
            for arguments in parsed_arguments:
                if arguments.__name__ == "trainer_base_args":
                    trainer_arguments = arguments
                    break
            if trainer_arguments:
                training_pipeline_arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING.get(
                    trainer_arguments.training_pipeline_name, TrainingPipelineArguments
                )
                parser = ArgumentParser(
                    tuple(
                        [TrainerArguments, *training_pipeline_arguments]  # type:ignore
                    )
                )
                parser.print_help()
        except Exception:
            super().print_help() 
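
    # Example of the two help modes this enables (the ``gt4sd-trainer`` entry point
    # name is an assumption, the pipeline name is an example):
    #   gt4sd-trainer --help
    #       -> prints only the base TrainerArguments help
    #   gt4sd-trainer --training_pipeline_name granular-trainer --help
    #       -> also prints the help of that pipeline's argument dataclasses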

    def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]:  # type: ignore
        """Override the default JSON file parser.

        It bypasses all command line arguments: only the pipeline-specific argument
        dataclasses are parsed from the JSON file, while the base trainer arguments
        are skipped and restored afterwards.

        Args:
            json_file: JSON file containing pipeline configuration parameters.

        Returns:
            parsed arguments in a tuple of dataclasses.
        """
        number_of_dataclass_types = len(self.dataclass_types)  # type:ignore
        self.dataclass_types = [
            dataclass_type
            for dataclass_type in self.dataclass_types  # type:ignore
            if "gt4sd.cli.trainer.TrainerArguments" not in str(dataclass_type)
        ]
        try:
            parsed_arguments = super().parse_json_file(  # type:ignore
                json_file=json_file, allow_extra_keys=True
            )
        except Exception:
            logger.exception(
                f"error parsing configuration file: {json_file}, printing error and exiting"
            )
            sys.exit(1)
        if number_of_dataclass_types > len(self.dataclass_types):
            self.dataclass_types.insert(0, cast(DataClassType, TrainerArguments))
        return parsed_arguments  
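
    # Illustrative content of a ``--configuration_file`` (the keys below are
    # hypothetical; they must match the fields of the selected pipeline's
    # argument dataclasses):
    #   {
    #       "model_path": "/path/to/model",
    #       "epochs": 10,
    #       "batch_size": 32
    #   }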


def main() -> None:
    """Run a training pipeline.

    Raises:
        ValueError: in case the provided training pipeline is not supported.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    base_args = TrainerArgumentParser(
        cast(DataClassType, TrainerArguments)
    ).parse_args_into_dataclasses(return_remaining_strings=True)[0]
    training_pipeline_name = base_args.training_pipeline_name
    if training_pipeline_name not in set(SUPPORTED_TRAINING_PIPELINES):
        raise ValueError(
            f"Training pipeline {training_pipeline_name} is not supported. Supported types: {', '.join(SUPPORTED_TRAINING_PIPELINES)}."
        )
    arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING[training_pipeline_name]
    parser = TrainerArgumentParser(
        cast(Iterable[DataClassType], tuple([TrainerArguments, *arguments]))
    )
    configuration_filepath = base_args.configuration_file
    if configuration_filepath:
        args = parser.parse_json_file(json_file=configuration_filepath)
    else:
        args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    config = {
        arg.__name__: arg.__dict__
        for arg in args
        if isinstance(arg, TrainingPipelineArguments) and isinstance(arg.__name__, str)
    }
    if (
        base_args.training_pipeline_name == "granular-trainer"
        and config["model_args"]["model_list_path"] is None
    ):
        # for the granular trainer, fall back to the configuration file as model list
        # when no explicit model_list_path is given
        config["model_args"]["model_list_path"] = configuration_filepath
    pipeline = TRAINING_PIPELINE_MAPPING[training_pipeline_name]
    pipeline().train(**config) 
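

# Minimal usage sketch (assuming the package exposes this module as the
# ``gt4sd-trainer`` console script; otherwise it can be run with
# ``python -m gt4sd.cli.trainer``). The pipeline name is an example:
#   gt4sd-trainer --training_pipeline_name granular-trainer <pipeline-specific flags>
#   gt4sd-trainer --training_pipeline_name granular-trainer --configuration_file config.json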
if __name__ == "__main__":
    main()