Source code for gt4sd.frameworks.granular.ml.models.base_model

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Base model for granular."""

from abc import abstractmethod
from argparse import ArgumentParser
from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.distributions import Distribution


[docs]class GranularBaseModel(nn.Module): """Base model class.""" position: int from_position: List[int]
[docs] def __init__(self, name: str, data: Dict[str, str], *args, **kwargs) -> None: """Construct GranularBaseModel. Args: name: model name. data: data name mappings. """ super().__init__() self.name = name self.data = data
[docs] def forward(self, x: Any, *args, **kwargs) -> Any: """Forward pass in the model. Args: x: model input. Returns: model output. """ return self._run_step(x)
[docs] @abstractmethod def _run_step(self, x: Any, *args, **kwargs) -> Any: """Run a step in the model. Args: x: model input. Returns: model step output. """ pass
[docs] @abstractmethod def step( self, input_data: Any, target_data: Any, device: str = "cpu", current_epoch: int = 0, *args, **kwargs, ) -> Tuple[Any, Any, Any]: """Training step for the model. Args: input_data: input for the step. target_data: target for the step. device: string representing the device to use. Defaults to "cpu". current_epoch: current epoch. Defaults to 0. Returns: a tuple containing the step output, the loss and the logs for the module. """ pass
[docs] @abstractmethod def val_step( self, input_data: Any, target_data: Any, device: str = "cpu", current_epoch: int = 0, *args, **kwargs, ) -> Any: """Validation step for the model. Args: input_data: input for the step. target_data: target for the step. device: string representing the device to use. Defaults to "cpu". current_epoch: current epoch. Defaults to 0. Returns: a tuple containing the step output, the loss and the logs for the module. """ pass
[docs] @staticmethod def add_model_specific_args( parent_parser: ArgumentParser, name: str, *args, **kwargs ) -> ArgumentParser: """Adding to a parser model specific arguments. Args: parent_parser: patent parser. name: model name. Returns: updated parser. """ return parent_parser
[docs]class GranularEncoderDecoderModel(GranularBaseModel): """Autoencoder model class.""" latent_size: int
[docs] @abstractmethod def decode(self, z: Any, *args, **kwargs) -> Any: """Decode a latent space point. Args: z: latent point. Returns: decoded sample. """ pass
[docs] @abstractmethod def encode(self, x: Any, *args, **kwargs) -> Any: """Encode a sample. Args: x: input sample. Returns: latent encoding. """ pass
[docs] def encode_decode(self, x: Any, *args, **kwargs) -> Any: """Encode and decode a sample. Args: x: input sample. Returns: decoded sample. """ z = self.encode(x) return self.decode(z)
[docs] def inference(self, z: Any, *args, **kwargs) -> Any: """Run the model in inference mode. Args: z: sample. Returns: generated output. """ return self.decode(z)
[docs] def sample( self, mu: torch.Tensor, log_var: torch.Tensor ) -> Tuple[Distribution, Distribution, torch.Tensor]: """Sample a point from a given mean and average following a normal log-likelihood. Args: mu: mean tensor. log_var: log varian tensor. Returns: a tuple containing standard normal, localized normal and the sampled point. """ std = torch.exp(log_var / 2.0) p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std)) q = torch.distributions.Normal(mu, std) z = q.rsample() return p, q, z