Source code for gt4sd.cli.trainer

#!/usr/bin/env python
#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

"""Run training pipelines for the GT4SD."""


import logging
import sys
from dataclasses import dataclass, field
from typing import IO, Iterable, Optional, Tuple, cast

from ..configuration import GT4SDConfiguration
from ..training_pipelines import (
    TRAINING_PIPELINE_ARGUMENTS_MAPPING,
    TRAINING_PIPELINE_MAPPING,
)
from ..training_pipelines.core import TrainingPipelineArguments
from .argument_parser import ArgumentParser, DataClass, DataClassType

logger = logging.getLogger(__name__)

SUPPORTED_TRAINING_PIPELINES = sorted(
    list(set(TRAINING_PIPELINE_ARGUMENTS_MAPPING) & set(TRAINING_PIPELINE_MAPPING))
)

# disable cudnn if issues with gpu training
if GT4SDConfiguration.get_instance().gt4sd_disable_cudnn:
    import torch

    torch.backends.cudnn.enabled = False


[docs]@dataclass class TrainerArguments: """Trainer arguments.""" __name__ = "trainer_base_args" training_pipeline_name: str = field( metadata={ "help": f"Training pipeline name, supported pipelines: {', '.join(SUPPORTED_TRAINING_PIPELINES)}." }, ) configuration_file: Optional[str] = field( default=None, metadata={ "help": "Configuration file for the training in JSON format. It can be used to completely by-pass pipeline specific arguments." }, )
[docs]class TrainerArgumentParser(ArgumentParser): """Argument parser using a custom help logic."""
[docs] def print_help(self, file: Optional[IO[str]] = None) -> None: """Print help checking dynamically whether a specific pipeline is passed. Args: file: an optional I/O stream. Defaults to None, a.k.a., stdout and stderr. """ try: help_args_set = {"-h", "--help"} if ( len(set(sys.argv).union(help_args_set)) < len(help_args_set) + 2 ): # considering filename super().print_help() return args = [arg for arg in sys.argv if arg not in help_args_set] parsed_arguments = super().parse_args_into_dataclasses( args=args, return_remaining_strings=True ) trainer_arguments = None for arguments in parsed_arguments: if arguments.__name__ == "trainer_base_args": trainer_arguments = arguments break if trainer_arguments: trainer_arguments.training_pipeline_name training_pipeline_arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING.get( trainer_arguments.training_pipeline_name, TrainingPipelineArguments ) parser = ArgumentParser( tuple( [TrainerArguments, *training_pipeline_arguments] # type:ignore ) ) parser.print_help() except Exception: super().print_help()
[docs] def parse_json_file(self, json_file: str) -> Tuple[DataClass, ...]: # type: ignore """Overriding default .json parser. It by-passes all command line arguments and simply add the training pipeline. Args: json_file: JSON file containing pipeline configuration parameters. Returns: parsed arguments in a tuple of dataclasses. """ number_of_dataclass_types = len(self.dataclass_types) # type:ignore self.dataclass_types = [ dataclass_type for dataclass_type in self.dataclass_types # type:ignore if "gt4sd.cli.trainer.TrainerArguments" not in str(dataclass_type) ] try: parsed_arguments = super().parse_json_file( # type:ignore json_file=json_file, allow_extra_keys=True ) except Exception: logger.exception( f"error parsing configuration file: {json_file}, printing error and exiting" ) sys.exit(1) if number_of_dataclass_types > len(self.dataclass_types): self.dataclass_types.insert(0, cast(DataClassType, TrainerArguments)) return parsed_arguments
[docs]def main() -> None: """ Run a training pipeline. Raises: ValueError: in case the provided training pipeline provided is not supported. """ logging.basicConfig(stream=sys.stdout, level=logging.INFO) base_args = TrainerArgumentParser( cast(DataClassType, TrainerArguments) ).parse_args_into_dataclasses(return_remaining_strings=True)[0] training_pipeline_name = base_args.training_pipeline_name if training_pipeline_name not in set(SUPPORTED_TRAINING_PIPELINES): ValueError( f"Training pipeline {training_pipeline_name} is not supported. Supported types: {', '.join(SUPPORTED_TRAINING_PIPELINES)}." ) arguments = TRAINING_PIPELINE_ARGUMENTS_MAPPING[training_pipeline_name] parser = TrainerArgumentParser( cast(Iterable[DataClassType], tuple([TrainerArguments, *arguments])) ) configuration_filepath = base_args.configuration_file if configuration_filepath: args = parser.parse_json_file(json_file=configuration_filepath) else: args = parser.parse_args_into_dataclasses(return_remaining_strings=True) config = { arg.__name__: arg.__dict__ for arg in args if isinstance(arg, TrainingPipelineArguments) and isinstance(arg.__name__, str) } if ( base_args.training_pipeline_name == "granular-trainer" and config["model_args"]["model_list_path"] is None ): config["model_args"]["model_list_path"] = configuration_filepath pipeline = TRAINING_PIPELINE_MAPPING[training_pipeline_name] pipeline().train(**config)
if __name__ == "__main__": main()