#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Model builder module."""
import logging
from collections import OrderedDict
from typing import Any, Dict, List
from typing import OrderedDict as OrderedDictType
import torch
from ....torch import device_claim
from . import ARCHITECTURE_FACTORY
from .base_model import GranularBaseModel
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def build_model(architecture: Dict[str, Any]) -> GranularBaseModel:
    """Build a model from an architecture configuration.

    Args:
        architecture: architecture configuration.

    Returns:
        the built model.
    """
    model_name = architecture["name"]
    model_type = architecture["type"].lower()
    if model_type not in ARCHITECTURE_FACTORY:
        raise ValueError(
            f"model_type={model_type} not supported. Pick a valid one: {sorted(ARCHITECTURE_FACTORY.keys())}"
        )
    # instantiate the model via the factory registered for the requested type.
    model = ARCHITECTURE_FACTORY[model_type](
        data=architecture["data"], **architecture["hparams"]
    )
    if architecture["start_from_checkpoint"]:
        # load the checkpoint together with the architecture configurations stored in it.
        loaded_params = torch.load(
            architecture["hparams"]["checkpoint_path"], map_location=device_claim(None)
        )
        loaded_architecture_latent = loaded_params["hyper_parameters"][
            "architecture_latent_models"
        ]
        loaded_architecture_autoencoder = loaded_params["hyper_parameters"][
            "architecture_autoencoders"
        ]
        # replace the local configuration with the checkpoint one matching the model name.
        for architecture_autoencoder in loaded_architecture_autoencoder:
            if model_name == architecture_autoencoder["name"]:
                architecture = architecture_autoencoder
        for architecture_latent in loaded_architecture_latent:
            if model_name == architecture_latent["name"]:
                architecture = architecture_latent
        # an explicit checkpoint model name, when provided, is accepted alongside the model name.
        checkpoint_model_name = architecture.get("hparams", {}).get(
            "checkpoint_model_name"
        )
        # keep only the state entries belonging to this model, stripping the model-name prefix.
        loaded_state_dict: OrderedDictType[str, torch.Tensor] = OrderedDict()
        for layer_name in loaded_params["state_dict"]:
            state_model_name, *layer_name_elements = layer_name.split(".")
            state_name = ".".join(layer_name_elements)
            if state_model_name in {model_name, checkpoint_model_name}:
                loaded_state_dict[state_name] = loaded_params["state_dict"][layer_name]
        model.load_state_dict(loaded_state_dict)
    model.name = model_name
    model.data = architecture["data"]
    model.target_key = model_name + "_" + architecture["data"]["target"]
    # optionally freeze the weights and switch to evaluation mode.
    if architecture.get("freeze_weights"):
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
    # predictors read from latent positions; all other models consume data inputs.
    if model_type == "mlp_predictor":
        model.from_position = architecture["from_position"]
    else:
        model.position = architecture["position"]
        model.input_key = model_name + "_" + architecture["data"]["input"]
    return model


def building_models(architectures: List[Dict[str, Any]]) -> List[GranularBaseModel]:
    """Build models from a list of architecture configurations.

    Args:
        architectures: list of architecture configurations.

    Returns:
        a list of built models.
    """
    return [build_model(architecture) for architecture in architectures]
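

# Example usage (illustrative sketch only): the "vae" type, the hyperparameter
# names and the data keys below are assumptions for illustration and must match
# an entry registered in ARCHITECTURE_FACTORY and the signature of the
# corresponding model class.
#
# architecture = {
#     "name": "molecule_autoencoder",
#     "type": "vae",
#     "start_from_checkpoint": False,
#     "freeze_weights": False,
#     "position": 0,
#     "data": {"input": "smiles", "target": "smiles"},
#     "hparams": {"latent_size": 64},
# }
# model = build_model(architecture)
# models = building_models([architecture])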