#!/usr/bin/env python
#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Run model saving for the GT4SD."""
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import IO, Iterable, Optional, cast
from ..algorithms.registry import ApplicationsRegistry
from ..training_pipelines import TRAINING_PIPELINE_ARGUMENTS_FOR_MODEL_SAVING
from ..training_pipelines.core import TrainingPipelineArguments
from .algorithms import (
AVAILABLE_ALGORITHMS,
AVAILABLE_ALGORITHMS_CATEGORIES,
filter_algorithm_applications,
get_configuration_tuples,
)
from .argument_parser import ArgumentParser, DataClassType
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Sorted keys of TRAINING_PIPELINE_ARGUMENTS_FOR_MODEL_SAVING; used below both
# for validating the requested pipeline name and for rendering help messages.
SUPPORTED_TRAINING_PIPELINES = sorted(TRAINING_PIPELINE_ARGUMENTS_FOR_MODEL_SAVING)
@dataclass
class SavingArguments:
    """Algorithm saving arguments.

    Base arguments of the saving command line. ``training_pipeline_name`` and
    ``target_version`` have no default and are therefore mandatory; the
    remaining fields default to ``None`` and are used to narrow down the
    inference algorithm configuration the model is saved for.
    """

    # Not a dataclass field (no annotation): an identifier readable as
    # ``instance.__name__``, used to recognise parsed argument groups.
    __name__ = "saving_base_args"

    training_pipeline_name: str = field(
        metadata={
            "help": f"Training pipeline name, supported pipelines: {', '.join(SUPPORTED_TRAINING_PIPELINES)}."
        },
    )
    target_version: str = field(
        metadata={"help": "Target algorithm version to save."},
    )
    algorithm_type: Optional[str] = field(
        default=None,
        metadata={
            "help": f"Inference algorithm type, supported types: {', '.join(AVAILABLE_ALGORITHMS_CATEGORIES['algorithm_type'])}."
        },
    )
    domain: Optional[str] = field(
        default=None,
        metadata={
            "help": f"Domain of the inference algorithm, supported types: {', '.join(AVAILABLE_ALGORITHMS_CATEGORIES['domain'])}."
        },
    )
    algorithm_name: Optional[str] = field(
        default=None,
        metadata={"help": "Inference algorithm name."},
    )
    algorithm_application: Optional[str] = field(
        default=None,
        metadata={"help": "Inference algorithm application."},
    )
    source_version: Optional[str] = field(
        default=None,
        metadata={"help": "Source algorithm version to use for missing artifacts."},
    )
class SavingArgumentParser(ArgumentParser):
    """Argument parser using a custom help logic."""

    def print_help(self, file: Optional[IO[str]] = None) -> None:
        """Print help checking dynamically whether a specific pipeline is passed.

        When the command line contains more than the help flag, it is parsed
        (without the help flag) to recover the requested training pipeline and
        the help of a parser built from ``SavingArguments`` plus the saving
        arguments of that pipeline is printed. In every other case, and
        whenever this lookup fails with an exception, the generic help of this
        parser is printed instead.

        Args:
            file: an optional I/O stream. Defaults to None, a.k.a., stdout and stderr.
        """
        try:
            help_args_set = {"-h", "--help"}
            if (
                len(set(sys.argv).union(help_args_set)) < len(help_args_set) + 2
            ):  # considering filename
                super().print_help(file)
                return

            # Re-parse the command line without the help flags so that the
            # requested training pipeline name can be read.
            args = [arg for arg in sys.argv if arg not in help_args_set]
            parsed_arguments = super().parse_args_into_dataclasses(
                args=args, return_remaining_strings=True
            )

            # The parsed tuple ends with the list of remaining strings, which
            # has no ``__name__``; hence the getattr with a default. The base
            # arguments are identified by the name SavingArguments declares.
            saving_arguments = next(
                (
                    arguments
                    for arguments in parsed_arguments
                    if getattr(arguments, "__name__", None) == "saving_base_args"
                ),
                None,
            )
            if saving_arguments is None:
                super().print_help(file)
                return

            # Each mapping value is a single dataclass (main() uses it the same
            # way), so it is paired with SavingArguments rather than unpacked.
            training_pipeline_arguments = (
                TRAINING_PIPELINE_ARGUMENTS_FOR_MODEL_SAVING.get(
                    saving_arguments.training_pipeline_name,
                    TrainingPipelineArguments,
                )
            )
            parser = ArgumentParser(
                cast(
                    Iterable[DataClassType],
                    (SavingArguments, training_pipeline_arguments),
                )
            )
            parser.print_help(file)
        except Exception:
            # Best effort: any failure while building the pipeline-specific
            # help falls back to the generic help of this parser.
            super().print_help(file)
def main() -> None:
    """Run an algorithm saving pipeline.

    The command line is parsed twice: first with the base saving arguments
    only, to find out which training pipeline is requested, then again with
    the saving arguments of that pipeline added. The algorithm properties
    given on the command line are then used to select a configuration, whose
    configuration class is asked to save ``target_version``.

    The function returns without saving anything when the provided properties
    match no configuration or more than one.

    Raises:
        ValueError: in case the provided training pipeline is not supported.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # First pass: base arguments only, anything else is left unparsed.
    base_args = SavingArgumentParser(
        cast(DataClassType, SavingArguments)
    ).parse_args_into_dataclasses(return_remaining_strings=True)[0]
    training_pipeline_name = base_args.training_pipeline_name
    if training_pipeline_name not in SUPPORTED_TRAINING_PIPELINES:
        raise ValueError(
            f"Training pipeline {training_pipeline_name} is not supported. Supported types: {', '.join(SUPPORTED_TRAINING_PIPELINES)}."
        )

    # Second pass: base arguments plus the pipeline-specific saving arguments.
    training_pipeline_saving_arguments = TRAINING_PIPELINE_ARGUMENTS_FOR_MODEL_SAVING[
        training_pipeline_name
    ]
    parser = SavingArgumentParser(
        cast(
            Iterable[DataClassType],
            (SavingArguments, training_pipeline_saving_arguments),
        )
    )
    saving_args, training_pipeline_saving_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    # NOTE(review): source_version is forwarded as a filter next to the
    # algorithm properties; confirm filter_algorithm_applications expects it.
    filters = {
        key: getattr(saving_args, key)
        for key in (
            "algorithm_type",
            "algorithm_application",
            "domain",
            "algorithm_name",
            "source_version",
        )
    }
    configuration_tuples = get_configuration_tuples(
        filter_algorithm_applications(algorithms=AVAILABLE_ALGORITHMS, filters=filters)
    )

    if len(configuration_tuples) < 1:
        provided_filters = {key: value for key, value in filters.items() if value}
        # NOTE(review): configuration_tuples is empty in this branch, so the
        # "supported configurations" part of the message renders as empty.
        logger.error(
            "No configurations matching the provided parameters, "
            f"please review the supported configurations:{os.linesep}"
            f"{os.linesep.join(map(str, configuration_tuples))}{os.linesep}"
            f"Please review the parameters provided:{os.linesep}"
            f"{provided_filters}"
        )
        # Nothing to save: stop here instead of indexing an empty list below.
        return
    if len(configuration_tuples) > 1:
        logger.info(
            f"Multiple configurations matching the parameters:{os.linesep}"
            f"{os.linesep.join(map(str, configuration_tuples))}{os.linesep}"
            f"Select one by specifying additional algorithms parameters: {','.join('--' + key for key, value in filters.items() if not value)}.",
        )
        return

    # Exactly one configuration matched.
    configuration_tuple = configuration_tuples[0]
    logger.info(f"Selected configuration: {configuration_tuple}")
    algorithm_application = ApplicationsRegistry.applications[configuration_tuple]
    configuration_class = algorithm_application.configuration_class
    logger.info(
        f'Saving model version "{saving_args.target_version}" with the following configuration: {configuration_class}'
    )
    configuration_class.save_version_from_training_pipeline_arguments(
        training_pipeline_arguments=training_pipeline_saving_args,
        target_version=saving_args.target_version,
        source_version=saving_args.source_version,
    )
# Run the saving command line only when this module is executed as a script.
if __name__ == "__main__":
    main()