#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Model module."""
import ast
import copy
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch_geometric.data as gd
from ..dataloader.dataset import GFlowNetDataset, GFlowNetTask
from ..envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext
# imports that have to be loaded before lightning to avoid segfaults
# (referencing the modules below keeps linters from flagging them as unused)
_sentencepiece
_tensorflow
_torch
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class GFlowNetAlgorithm:
"""A generic algorithm for gflownet."""
    def compute_batch_losses(
        self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0
    ) -> Tuple[float, Dict[str, float]]:
        """Compute the loss for a batch of data and provide logging information.

        Args:
            model: the model being trained or evaluated.
            batch: a batch of graphs.
            num_bootstrap: the number of trajectories with reward targets in the batch (if applicable).

        Returns:
            loss: the loss for that batch.
            info: logged information about model predictions.
        """
raise NotImplementedError()
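

# A minimal sketch of a concrete algorithm, for illustration only: it assumes
# the batch exposes a per-graph reward tensor under ``batch.rewards`` (a
# hypothetical attribute) and that ``model(batch)`` returns matching per-graph
# predictions. The actual algorithms (e.g. trajectory balance) live elsewhere
# in the gflownet package.
class _ExampleSquaredErrorAlgorithm(GFlowNetAlgorithm):
    """Toy algorithm that regresses model outputs onto reward targets."""

    def compute_batch_losses(
        self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        predictions = model(batch)
        # mean squared error against the (assumed) reward targets; the loss is
        # returned as a tensor so the training step can backpropagate through it.
        loss = torch.nn.functional.mse_loss(predictions, batch.rewards)
        return loss, {"mse": loss.detach()}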


class GFlowNetModule(pl.LightningModule):
    """Lightning module for gflownet."""
    def __init__(
self,
configuration: Dict[str, Any],
dataset: GFlowNetDataset,
environment: GraphBuildingEnv,
context: GraphBuildingEnvContext,
task: GFlowNetTask,
algorithm: GFlowNetAlgorithm,
model: nn.Module,
) -> None:
"""Construct GFNModule.
Args:
configuration: the configuration of the module.
dataset: the dataset to use.
environment: the environment to use.
context: the context to use.
task: the task to solve.
algorithm: algorithm (trajectory_balance or td_loss).
model: architecture (graph_transformer_gfn or graph_transformer).
"""
super().__init__()
self.hps = configuration
# self.save_hyperparameters()
self.env = environment
self.ctx = context
self.dataset = dataset
self.model = model
self.algo = algorithm
self.task = task
self.test_output_path = self.hps["test_output_path"]
self.rng = self.hps["rng"]
self.mb_size = self.hps["global_batch_size"]
self.clip_grad_param = self.hps["clip_grad_param"]
self.sampling_tau = self.hps["sampling_tau"]
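
    # Example configuration (illustrative values only; the keys mirror those
    # read in ``__init__`` above, and ``rng`` is assumed to be a numpy random
    # generator):
    #
    #     configuration = {
    #         "test_output_path": "/tmp/gflownet",
    #         "rng": np.random.default_rng(142857),
    #         "global_batch_size": 64,
    #         "clip_grad_param": 10.0,
    #         "sampling_tau": 0.95,
    #     }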
    def training_step(  # type:ignore
self,
batch: gd.Batch,
batch_idx: int,
optimizer_idx: int,
*args: Any,
**kwargs: Any,
) -> Dict[str, Any]:
"""Training step implementation.
Args:
batch: batch representation.
epoch_idx: epoch index.
batch_idx: batch index.
Returns:
loss and logs.
"""
logs = dict()
loss, info = self.algo.compute_batch_losses(
self.model, batch, num_bootstrap=self.mb_size
)
logs.update(
{
self.model.name + f"/{k}": v.detach() if hasattr(v, "item") else v # type: ignore
for k, v in info.items()
}
)
logs.update({"total_loss": loss.item()}) # type: ignore
# logs for step
_logs = {f"train/{k}": v for k, v in logs.items()}
self.log_dict(_logs, on_step=True, on_epoch=True, prog_bar=True)
self.log("train_loss", loss)
# logs per epoch
logs_epoch = {f"train_epoch/{k}": v for k, v in logs.items()}
logs_epoch["step"] = self.current_epoch
self.log_dict(
logs_epoch,
on_step=False,
on_epoch=True,
prog_bar=False,
)
return {"loss": loss, "logs": logs}
    def training_step_end(self, batch_parts):
        """Clip gradients and update the sampling model after each training step.

        Args:
            batch_parts: outputs of the training step.
        """
        for i in self.model.parameters():
            self.clip_grad_callback(i)
        if self.sampling_tau > 0:
            # EMA update of the sampling model: b <- tau * b + (1 - tau) * a.
            for a, b in zip(self.model.parameters(), self.sampling_model.parameters()):
                b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))
    def validation_step(  # type:ignore
self, batch: gd.Batch, batch_idx: int, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
"""Validation step implementation.
Args:
batch: batch representation.
Returns:
loss and logs.
"""
loss = 0.0
logs = dict()
loss, info = self.algo.compute_batch_losses(
self.model, batch, num_bootstrap=batch.num_offline
)
        logs.update({k: v.item() if hasattr(v, "item") else v for k, v in info.items()})
logs.update({"total_loss": loss})
self.log_dict(
{f"val/{k}": v for k, v in logs.items()},
on_step=True,
on_epoch=True,
prog_bar=False,
)
self.log("val_loss", loss)
return {"loss": loss, "logs": logs}
def test_step( # type:ignore
self, batch: Any, batch_idx: int
) -> Dict[str, Any]:
"""Testing step implementation.
Args:
batch: batch representation.
batch_idx: batch index, unused.
Returns:
loss, logs, and latent encodings.
"""
loss = 0.0
logs = dict()
loss, info = self.algo.compute_batch_losses(
self.model, batch, num_bootstrap=batch.num_offline
)
        logs.update({k: v.item() if hasattr(v, "item") else v for k, v in info.items()})
logs.update({"total_loss": loss})
self.log_dict(
{f"test/{k}": v for k, v in logs.items()},
on_step=True,
on_epoch=True,
prog_bar=True,
)
self.log("test_loss", loss)
return {"loss": loss, "logs": logs}
    def prediction_step(self, batch) -> torch.Tensor:
        """Inference step.

        Args:
            batch: batch data.

        Returns:
            output of the forward pass.
        """
return self(batch)
    def train_epoch_end(self, outputs: List[Dict[str, Any]]):
        """Train epoch end.

        Args:
            outputs: list of outputs for the epoch.
        """
        pass
    # TODO: replace the following with a new implementation.
    def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type:ignore
        """Callback called at the end of an epoch on test outputs.

        Dump encodings and targets for the test set.

        Args:
            outputs: outputs for test batches.
        """
        z = {}
        targets = {}
        z_keys = list(outputs[0]["z"])
        targets_keys = list(outputs[0]["targets"])
for key in z_keys:
z[key] = (
torch.cat(
[torch.squeeze(an_output["z"][key]) for an_output in outputs], dim=0
)
.detach()
.cpu()
.numpy()
)
for key in targets_keys:
targets[key] = (
torch.cat(
[torch.squeeze(an_output["targets"][key]) for an_output in outputs],
dim=0,
)
.detach()
.cpu()
.numpy()
)
        pd.to_pickle(z, os.path.join(self.test_output_path, "z_build.pkl"))
        pd.to_pickle(targets, os.path.join(self.test_output_path, "targets.pkl"))
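

# A minimal usage sketch, illustrative only: ``dataset``, ``context``,
# ``task``, ``algorithm``, ``model`` and ``data_module`` are assumed to be
# built elsewhere in the gflownet package, and the names here are hypothetical.
#
#     module = GFlowNetModule(
#         configuration=configuration,
#         dataset=dataset,
#         environment=GraphBuildingEnv(),
#         context=context,
#         task=task,
#         algorithm=algorithm,
#         model=model,
#     )
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(module, datamodule=data_module)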