# Source code for gt4sd.frameworks.gflownet.dataloader.data_module

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Data module for gflownet."""

import logging
from typing import Any, Dict, Optional

import sentencepiece as _sentencepiece
import torch as _torch
import tensorflow as _tensorflow
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from torch.utils.data import DataLoader  # , Subset, random_split

from ..dataloader.dataset import GFlowNetDataset, GFlowNetTask
from ..envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext
from ..loss.trajectory_balance import TrajectoryBalance
from ..ml.models import MODEL_FACTORY
from .sampler import SamplingIterator

# imports that have to be loaded before lightning to avoid segfaults
# (sentencepiece, tensorflow and torch are imported above, ahead of
# pytorch_lightning). The bare name expressions below only reference the
# aliased modules; presumably this also keeps linters from flagging the
# imports as unused -- confirm before removing them.
_sentencepiece
_tensorflow
_torch

# Module-level logger. The NullHandler is a no-op handler, so this module
# does not emit output on its own; the application configures real handlers.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class GFlowNetDataModule(pl.LightningDataModule):
    """Data module for gflownet.

    Wraps a single dataset and serves training, validation, test and
    prediction data loaders over shuffled index subsets of it. Training and
    validation can optionally be driven by a ``SamplingIterator`` that uses a
    sampling model instead of iterating the dataset directly.
    """

    def __init__(
        self,
        configuration: Dict[str, Any],
        dataset: GFlowNetDataset,
        environment: GraphBuildingEnv,
        context: GraphBuildingEnvContext,
        task: GFlowNetTask,
        algorithm: TrajectoryBalance,
        model: Optional[nn.Module] = None,
    ) -> None:
        """Construct GFlowNetDataModule.

        The module assumes a model and algorithm factory/registry.
        The user should provide a dataset, environment, context for the
        environment, and task.

        Args:
            configuration: configuration dictionary. The keys read here are
                "sampling_model" (only when no model is given),
                "sampling_iterator", "batch_size", "num_workers", "device",
                "rng", "ratio" and "global_batch_size".
            dataset: dataset.
            environment: environment for graph building.
            context: context environment.
            task: generic task.
            algorithm: loss function.
            model: model used to generate data with the sampling iterator.
                It can be a custom model or the same as the one used in the
                algorithm. When None, a model is built from MODEL_FACTORY.

        Raises:
            KeyError: if a required configuration key is missing.
        """
        super().__init__()
        self.hps = configuration

        # Compare against None explicitly: container modules (for example
        # nn.Sequential) define __len__, so an empty one is falsy and would
        # otherwise be silently replaced by the factory model.
        if model is not None:
            self.sampling_model = model
        else:
            self.sampling_model = MODEL_FACTORY[self.hps["sampling_model"]](
                self.hps, context
            )

        self.algo = algorithm
        self.env = environment
        self.ctx = context
        self.dataset = dataset
        self.task = task

        self.sampling_iterator = self.hps["sampling_iterator"]
        self.batch_size = self.hps["batch_size"]
        self.num_workers = self.hps["num_workers"]
        self.device = self.hps["device"]
        self.rng = self.hps["rng"]
        self.ratio = self.hps["ratio"]
        self.mb_size = self.hps["global_batch_size"]

        # Bind the split aliases here as well as in prepare_data: Lightning
        # runs prepare_data on a single process only, so state assigned
        # there is not guaranteed to exist on every rank when setup runs.
        self.train_dataset = self.dataset
        self.val_dataset = self.dataset
        self.test_dataset = self.dataset

    def prepare_data(self) -> None:
        """Prepare training and test dataset.

        Points the train, validation and test aliases at the wrapped
        dataset. Kept for backward compatibility; the same aliases are
        already bound in the constructor.
        """
        self.train_dataset = self.dataset
        self.val_dataset = self.dataset
        self.test_dataset = self.dataset

    def setup(self, stage: Optional[str] = None) -> None:
        """Setup the data module.

        Shuffles the dataset indexes with the configured random generator
        and splits them: the first ``floor(ratio * n)`` indexes are divided
        90/10 into training and validation, the remainder is the test split.

        Args:
            stage: stage considered ("fit", "test" or "predict"). Defaults
                to None, in which case every split is applied.
        """
        n_points = self.dataset.get_len()
        shuffled = np.arange(n_points)
        self.rng.shuffle(shuffled)

        thresh = int(np.floor(self.ratio * n_points))
        train_end = int(0.9 * thresh)
        self.ix_train = shuffled[:train_end]
        self.ix_val = shuffled[train_end:thresh]
        self.ix_test = shuffled[thresh:]

        # NOTE(review): the three dataset aliases reference the same dataset
        # object, so each set_indexes call presumably replaces the indexes
        # installed by the previous one (e.g. when stage is None the test
        # indexes win) -- confirm this is intended.
        if stage == "fit" or stage is None:
            self.train_dataset.set_indexes(self.ix_train)  # type: ignore
            self.val_dataset.set_indexes(self.ix_val)  # type: ignore
        # Testing and prediction both run on the test split.
        if stage in ("test", "predict") or stage is None:
            self.test_dataset.set_indexes(self.ix_test)  # type: ignore

        n_train = len(self.train_dataset)
        n_test = len(self.test_dataset)
        n_logged = n_train + n_test
        # Guard the ratio so an empty dataset does not raise
        # ZeroDivisionError from a purely informational log line.
        proportion = n_test / n_logged if n_logged else float("nan")

        logger.info("number of data points used for training: %s", n_train)
        logger.info("number of data points used for testing: %s", n_test)
        logger.info("testing proportion: %s", proportion)

    def train_dataloader(self) -> DataLoader:
        """Get a data loader for training.

        When the sampling iterator is enabled, batches come from a
        ``SamplingIterator`` that is asked for twice the configured global
        batch size, and automatic batching in the DataLoader is disabled
        (``batch_size=None``). Otherwise the training dataset is batched
        with the configured batch size.

        Returns:
            a training data loader.
        """
        if self.sampling_iterator:
            iterator = SamplingIterator(
                self.train_dataset,
                self.sampling_model,
                self.mb_size * 2,
                self.ctx,
                self.algo,
                self.task,
                device=self.device,
            )
            batch_size = None
        else:
            iterator = self.train_dataset  # type: ignore
            batch_size = self.batch_size

        return DataLoader(
            iterator,
            batch_size=batch_size,
            num_workers=self.num_workers,
            # persistent workers are only valid with at least one worker.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Get a data loader for validation.

        When the sampling iterator is enabled, batches come from a
        ``SamplingIterator`` built with the configured global batch size,
        ``ratio=1`` and ``stream=False``, and automatic batching in the
        DataLoader is disabled. Otherwise the validation dataset is batched
        with the configured batch size.

        Returns:
            a validation data loader.
        """
        if self.sampling_iterator:
            iterator = SamplingIterator(
                self.val_dataset,
                self.sampling_model,
                self.mb_size,
                self.ctx,
                self.algo,
                self.task,
                ratio=1,
                stream=False,
                device=self.device,
            )
            batch_size = None
        else:
            iterator = self.val_dataset  # type: ignore
            batch_size = self.batch_size

        return DataLoader(
            iterator,
            batch_size=batch_size,
            num_workers=self.num_workers,
            # persistent workers are only valid with at least one worker.
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        """Get a data loader for testing.

        Returns:
            a testing data loader.
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=False,
        )

    def predict_dataloader(self) -> DataLoader:
        """Get a data loader for prediction.

        Prediction reuses the test split.

        Returns:
            a prediction data loader.
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=False,
        )