Source code for gt4sd.algorithms.registry

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Collection of available methods."""


import logging
from dataclasses import dataclass as vanilla_dataclass
from dataclasses import field, make_dataclass
from functools import WRAPPER_ASSIGNMENTS, update_wrapper
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Type,
    TypeVar,
)

import pydantic

# pyright (pylance in VSCode) does not support pydantic typechecking
# if typing.TYPE_CHECKING:
#     from dataclasses import dataclass
# else:
#     from pydantic.dataclasses import dataclass
from pydantic.dataclasses import dataclass

from ..exceptions import DuplicateApplicationRegistration
from .core import AlgorithmConfiguration, GeneratorAlgorithm

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class ConfigurationTuple(NamedTuple):
    """Hashable key that uniquely identifies an AlgorithmConfiguration.

    All four entries together are used as the lookup key of the
    applications registry.
    """

    # general type of generative algorithm
    algorithm_type: str
    # general application domain
    domain: str
    # name of the algorithm used with the configuration
    algorithm_name: str
    # name of the application, i.e. the use of a configuration with an algorithm
    algorithm_application: str
class AnnotationTuple(NamedTuple):
    """Pair the annotation of a configuration argument with its default value."""

    # annotation declared for the argument
    annotation: type
    # value the argument takes when it is not given explicitly
    default_value: Any  # TODO serializable type?
@vanilla_dataclass
class AlgorithmApplication:
    """Bundle everything that is required to run an application.

    An instance ties a generator algorithm class to the configuration class
    it is used with, plus an optional mapping of parameter defaults.
    """

    # algorithm to instantiate for the application
    algorithm_class: Type[GeneratorAlgorithm]
    # configuration handed to the algorithm on instantiation
    configuration_class: Type[AlgorithmConfiguration]
    # argument name -> (annotation, default value); starts out empty
    # includes algorithm_version: str
    parameters_dict: Dict[str, AnnotationTuple] = field(default_factory=dict)
class RegistryDict(Dict[ConfigurationTuple, AlgorithmApplication]):
    """Dictionary of applications that refuses to overwrite an existing key."""

    def __setitem__(self, key, value):
        """Insert ``value`` under ``key`` unless ``key`` is already present.

        Args:
            key: identifier of the application to register.
            value: the application to store.

        Raises:
            DuplicateApplicationRegistration: if ``key`` is already in the
                dictionary.
        """
        if key in self:
            # if it's really needed for some reason, delete the item first, then add it.
            raise DuplicateApplicationRegistration(
                title="Applications exists",
                detail=f"key {key} was already registered and would override another application.",
            )
        super().__setitem__(key, value)
class ApplicationsRegistry:
    """Registry to collect "applications" and make them accessible.

    An application denotes the combination of an
    :class:`AlgorithmConfiguration<gt4sd.algorithms.core.AlgorithmConfiguration>` and a
    :class:`GeneratorAlgorithm<gt4sd.algorithms.core.GeneratorAlgorithm>`.
    """

    # NOTE: on import of the registry, also ensure the import of the modules
    # that populate the applications.
    # Class-level attribute: a single mapping shared by every user of the registry.
    applications: RegistryDict = RegistryDict()

    @classmethod
    def _register_application(
        cls,
        algorithm_class: Type[GeneratorAlgorithm],
        algorithm_configuration_class: Type[AlgorithmConfiguration],
    ):
        """Store an algorithm/configuration pair in the registry.

        The configuration class is instantiated once without arguments as a
        sanity check. A ``pydantic.ValidationError`` raised by that probe is
        logged and registration continues regardless.

        Args:
            algorithm_class: the algorithm that uses the configuration.
            algorithm_configuration_class: the completed configuration class.

        Raises:
            DuplicateApplicationRegistration: raised by the registry dict when
                the same identifying tuple was registered before.
        """
        # testing that configuration class is callable without arguments
        try:
            algorithm_configuration_class()
        except pydantic.ValidationError as e:
            # best effort: report the problem but do not block registration
            logger.exception(e)

        config_tuple = cls.configuration_class_as_tuple(algorithm_configuration_class)
        cls.applications[config_tuple] = AlgorithmApplication(
            algorithm_class=algorithm_class,
            configuration_class=algorithm_configuration_class,
        )

    @classmethod
    def register_algorithm_application(
        cls,
        algorithm_class: Type[GeneratorAlgorithm],
        as_algorithm_application: Optional[str] = None,
    ) -> Callable[[Type[AlgorithmConfiguration]], Type[AlgorithmConfiguration]]:
        """Complete and register a configuration via decoration.

        Args:
            algorithm_class: The algorithm that uses the configuration.
            as_algorithm_application: Optional application name to use instead of
                the configurations class name.

        Returns:
            A function to complete the configuration class' attributes to reflect
            the matching GeneratorAlgorithm and application. The final class is
            registered and returned.

        Example:
            as decorator::

                from gt4sd.algorithms.registry import ApplicationsRegistry


                @ApplicationsRegistry.register_algorithm_application(SomeAlgorithm)
                class SomeApplication(AlgorithmConfiguration):
                    algorithm_type: ClassVar[str] = 'conditional_generation'
                    domain: ClassVar[str] = 'materials'
                    algorithm_version: str = 'v0'

                    some_more_serializable_arguments_with_defaults: int = 42

        Example:
            directly, here for an additional algorithm application with the same algorithm::

                AnotherApplication = ApplicationsRegistry.register_algorithm_application(
                    SomeAlgorithm, 'AnotherApplication'
                )(SomeApplication)
        """

        def decorator(
            configuration_class: Type[AlgorithmConfiguration],
        ) -> Type[AlgorithmConfiguration]:
            """Complete the configuration class' attributes and register the class.

            Args:
                configuration_class: class to complete.

            Returns:
                a completed class.
            """
            VanillaConfiguration = make_dataclass(
                cls_name=configuration_class.__name__,
                # call `@dataclass` for users to avoid confusion
                bases=(vanilla_dataclass(configuration_class),),
                fields=[
                    (
                        "algorithm_name",  # type: ignore
                        ClassVar[str],
                        field(default=algorithm_class.__name__),  # type: ignore
                    ),
                    (
                        "algorithm_application",  # type: ignore
                        ClassVar[str],
                        field(
                            default=(
                                as_algorithm_application
                                or configuration_class.__name__  # type: ignore
                            )
                        ),
                    ),
                ],  # type: ignore
            )
            # NOTE: Needed to circumvent a pydantic TypeError: Parameter list to Generic[...] cannot be empty
            VanillaConfiguration.__parameters__ = (TypeVar("T"),)  # type: ignore
            # NOTE: Duplicate call necessary for pydantic >=1.10.* - see https://github.com/pydantic/pydantic/issues/4695
            # The result of the first call is intentionally discarded.
            dataclass(VanillaConfiguration)  # type: ignore
            PydanticConfiguration: Type[AlgorithmConfiguration] = dataclass(  # type: ignore
                VanillaConfiguration
            )
            # get missing entries
            missing_in__dict__ = [
                key
                for key in configuration_class.__dict__
                if key not in PydanticConfiguration.__dict__
            ]

            update_wrapper(
                wrapper=PydanticConfiguration,
                wrapped=configuration_class,
                assigned=missing_in__dict__ + list(WRAPPER_ASSIGNMENTS),
                updated=(),  # default of '__dict__' does not apply here, see missing_in__dict__
            )

            cls._register_application(algorithm_class, PydanticConfiguration)

            return PydanticConfiguration

        return decorator

    @staticmethod
    def configuration_class_as_tuple(
        algorithm_configuration_class: Type[AlgorithmConfiguration],
    ) -> "ConfigurationTuple":
        """Get a hashable identifier per application.

        Args:
            algorithm_configuration_class: configuration class whose
                ``algorithm_type``, ``domain``, ``algorithm_name`` and
                ``algorithm_application`` attributes are read.

        Returns:
            the tuple used as key in the registry.
        """
        return ConfigurationTuple(
            algorithm_type=algorithm_configuration_class.algorithm_type,
            domain=algorithm_configuration_class.domain,
            algorithm_name=algorithm_configuration_class.algorithm_name,
            algorithm_application=algorithm_configuration_class.algorithm_application,
        )

    @classmethod
    def get_application(
        cls,
        algorithm_type: str,
        domain: str,
        algorithm_name: str,
        algorithm_application: str,
    ) -> AlgorithmApplication:
        """Look up a registered application by its identifying attributes.

        Args:
            algorithm_type: general type of generative algorithm.
            domain: general application domain. Hints at input/output types.
            algorithm_name: name of the algorithm to use with this configuration.
            algorithm_application: unique name for the application that is the use of
                this configuration together with a specific algorithm.

        Returns:
            the registered application.

        Raises:
            KeyError: if no application is registered under these attributes.
        """
        return cls.applications[
            ConfigurationTuple(
                algorithm_type=algorithm_type,
                domain=domain,
                algorithm_name=algorithm_name,
                algorithm_application=algorithm_application,
            )
        ]

    @classmethod
    def get_matching_configuration_defaults(
        cls,
        algorithm_type: str,
        domain: str,
        algorithm_name: str,
        algorithm_application: str,
    ) -> Dict[str, AnnotationTuple]:
        """Collect annotation and default value of each configuration argument.

        The defaults are read from an instance of the configuration created
        without arguments.

        Args:
            algorithm_type: general type of generative algorithm.
            domain: general application domain. Hints at input/output types.
            algorithm_name: name of the algorithm to use with this configuration.
            algorithm_application: unique name for the application that is the use of
                this configuration together with a specific algorithm.

        Returns:
            a mapping from argument name to its annotation and default value.
            Instance attributes without any annotation are not included.

        Raises:
            KeyError: if no application is registered under these attributes.
        """
        Configuration = cls.get_application(
            algorithm_type=algorithm_type,
            domain=domain,
            algorithm_name=algorithm_name,
            algorithm_application=algorithm_application,
        ).configuration_class

        # Merge annotations along the MRO, base classes first, so that the most
        # derived declaration wins and inherited fields can be resolved too.
        annotations: Dict[str, Any] = {}
        for klass in reversed(getattr(Configuration, "__mro__", (Configuration,))):
            annotations.update(getattr(klass, "__annotations__", None) or {})

        defaults_dict = {}
        for argument, default_value in Configuration().__dict__.items():
            if argument not in annotations:
                # Not a declared field (e.g. bookkeeping attributes that the
                # dataclass machinery may store on the instance): skip it
                # instead of failing with a KeyError.
                continue
            defaults_dict[argument] = AnnotationTuple(
                annotation=annotations[argument],
                default_value=default_value,
            )
        return defaults_dict

    @classmethod
    def get_matching_configuration_schema(
        cls,
        algorithm_type: str,
        domain: str,
        algorithm_name: str,
        algorithm_application: str,
    ) -> Dict[str, Any]:
        """Get the schema of the matching configuration.

        The schema is produced by the pydantic model attached to the
        registered configuration class.

        Args:
            algorithm_type: general type of generative algorithm.
            domain: general application domain. Hints at input/output types.
            algorithm_name: name of the algorithm to use with this configuration.
            algorithm_application: unique name for the application that is the use of
                this configuration together with a specific algorithm.

        Returns:
            the schema as a dictionary.

        Raises:
            KeyError: if no application is registered under these attributes.
        """
        Configuration = cls.get_application(
            algorithm_type=algorithm_type,
            domain=domain,
            algorithm_name=algorithm_name,
            algorithm_application=algorithm_application,
        ).configuration_class
        return Configuration.__pydantic_model__.schema()  # type: ignore

    @classmethod
    def get_configuration_instance(
        cls,
        algorithm_type: str,
        domain: str,
        algorithm_name: str,
        algorithm_application: str,
        *args,
        **kwargs,
    ) -> AlgorithmConfiguration:
        """Create an instance of the matching AlgorithmConfiguration from the ApplicationsRegistry.

        Args:
            algorithm_type: general type of generative algorithm.
            domain: general application domain. Hints at input/output types.
            algorithm_name: name of the algorithm to use with this configuration.
            algorithm_application: unique name for the application that is the use of
                this configuration together with a specific algorithm.
            *args: additional positional arguments passed to the configuration.
            **kwargs: additional keyword arguments passed to the configuration,
                e.g. ``algorithm_version`` where the configuration declares it.

        Returns:
            an instance of the configuration.

        Raises:
            KeyError: if no application is registered under these attributes.
        """
        Configuration = cls.get_application(
            algorithm_type=algorithm_type,
            domain=domain,
            algorithm_name=algorithm_name,
            algorithm_application=algorithm_application,
        ).configuration_class
        return Configuration(*args, **kwargs)

    @classmethod
    def get_application_instance(
        cls,
        algorithm_type: str,
        domain: str,
        algorithm_name: str,
        algorithm_application: str,
        target: Any = None,
        **kwargs,
    ) -> GeneratorAlgorithm:
        """Instantiate an algorithm via a matching application from the ApplicationsRegistry.

        Additional arguments are passed to the configuration and override any arguments
        in the ApplicationsRegistry.

        Args:
            algorithm_type: general type of generative algorithm.
            domain: general application domain. Hints at input/output types.
            algorithm_name: name of the algorithm to use with this configuration.
            algorithm_application: unique name for the application that is the use of
                this configuration together with a specific algorithm.
            target: optional context or condition for the generation.
            **kwargs: additional keyword arguments passed to the configuration,
                e.g. ``algorithm_version`` where the configuration declares it.

        Returns:
            an instance of a generative algorithm ready to sample from.

        Raises:
            KeyError: if no application is registered under these attributes.
        """
        application_tuple = cls.get_application(
            algorithm_type=algorithm_type,
            domain=domain,
            algorithm_name=algorithm_name,
            algorithm_application=algorithm_application,
        )
        # start from the defaults stored with the application ...
        parameters = {
            key: annotation_tuple.default_value
            for key, annotation_tuple in application_tuple.parameters_dict.items()
        }
        # ... and let explicitly passed arguments take precedence
        parameters.update(kwargs)

        return application_tuple.algorithm_class(
            configuration=application_tuple.configuration_class(**parameters),
            target=target,
        )

    @classmethod
    def list_available(cls) -> List[Dict[str, str]]:
        """List every registered application once per available version.

        Returns:
            one dictionary per application and version, holding the entries of
            the identifying tuple plus ``algorithm_version``.
        """
        available = []
        for config_tuple, application in cls.applications.items():
            available.extend(
                dict(**config_tuple._asdict(), algorithm_version=version)
                for version in application.configuration_class.list_versions()
            )
        return available