Source code for gt4sd.frameworks.gflownet.dataloader.sampler

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import logging
from typing import Generator, Union

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset

from ..envs.graph_building_env import GraphBuildingEnvContext
from ..loss.trajectory_balance import TrajectoryBalance, TrajectoryBalanceModel
from .dataset import GFlowNetDataset, GFlowNetTask

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class SamplingIterator(IterableDataset):
    """Sampler to speed up training based on IterableDataset."""

    def __init__(
        self,
        dataset: GFlowNetDataset,
        model: Union[nn.Module, TrajectoryBalanceModel],
        batch_size: int,
        context: GraphBuildingEnvContext,
        algorithm: TrajectoryBalance,
        task: GFlowNetTask,
        ratio: float = 0.5,
        stream: bool = True,
        device: str = "cuda",
    ) -> None:
        """
        This class allows us to parallelise and train faster. By separating data sampling and
        torch geometric graph construction from model training, the former can run in separate
        worker processes, which is much faster since graph construction is largely CPU-bound.
        This sampler can handle both offline and online data.
        Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/data.

        Args:
            dataset: a dataset instance.
            model: the model we sample from (must be on CUDA already, or share_memory() must be
                called so that parameters are synchronized between workers).
            batch_size: the number of trajectories; each trajectory is composed of many graphs,
                so this is _not_ the batch size in terms of the number of graphs (that will depend on the task).
            context: the graph environment.
            algorithm: the training algorithm, e.g. a TrajectoryBalance instance.
            task: the conditional task that specifies the reward structure.
            ratio: the ratio of offline trajectories in the batch.
            stream: if True, data is sampled iid for every batch. Otherwise, this is a normal in-order dataset iterator.
            device: the device (cpu or cuda).
        """
        self.data = dataset
        self.model = model
        self.batch_size = batch_size
        self.offline_batch_size = int(np.ceil(batch_size * ratio))
        self.online_batch_size = int(np.floor(batch_size * (1 - ratio)))
        self.ratio = ratio
        self.ctx = context
        self.algo = algorithm
        self.task = task
        self.device = device
        self.stream = stream
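
    # Illustrative note (not part of the original module): with e.g. batch_size = 9 and
    # ratio = 0.5, the split above gives offline_batch_size = ceil(4.5) = 5 and
    # online_batch_size = floor(4.5) = 4; the two sizes always sum to batch_size.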

    def _idx_iterator(self) -> Generator:
        """Return an iterator over the indices of the dataset. The batches can be offline or online.

        Yields:
            Batches of indices.
        """
        bs = self.offline_batch_size
        n = len(self.data)
        if self.stream:
            # if we're streaming data, just sample `offline_batch_size` indices
            while True:
                yield self.rng.integers(0, n, bs)  # type: ignore
        else:
            # figure out which indices correspond to this worker
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is None:
                start = 0
                end = n
                wid = -1
            else:
                nw = worker_info.num_workers
                wid = worker_info.id
                start = int(np.floor(n / nw * wid))
                end = int(np.ceil(n / nw * (wid + 1)))
            # `<=` also covers the case end - start == bs, which would otherwise
            # fall through to the (then empty) loop below and never be yielded
            if end - start <= bs:
                yield np.arange(start, end)
                return
            for i in range(start, end - bs, bs):
                yield np.arange(i, i + bs)
            if i + bs < end:
                yield np.arange(i + bs, end)
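
    # Worked example (illustrative only): with n = 20 dataset entries, 2 workers and
    # offline_batch_size = 4, the slicing above assigns indices [0, 10) to worker 0 and
    # [10, 20) to worker 1; worker 0 then yields the index batches [0..3], [4..7] and
    # the remainder [8, 9].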

    def __len__(self):
        """Length of the offline or online dataset.

        Returns:
            The length of the dataset.
        """
        # if online
        if self.stream:
            return int(1e6)
        # if offline
        return len(self.data)  # type: ignore

    def sample_offline(self, idcs):
        """Sample offline data.

        Args:
            idcs: the indices of the data to sample.

        Returns:
            trajs: the trajectories.
            flat_rewards: the rewards.
        """
        # sample offline batch (mols, rewards)
        mols, _flat_rewards = map(list, zip(*[self.data[i] for i in idcs]))
        # transform the raw rewards
        flat_rewards = list(self.task.flat_reward_transform(_flat_rewards))  # type: ignore
        # build graphs
        graphs = [self.ctx.mol_to_graph(m) for m in mols]
        # use trajectory balance to sample trajectories
        trajs = self.algo.create_training_data_from_graphs(graphs)
        return trajs, flat_rewards

    def predict_reward_model(self, trajs, flat_rewards, num_offline):
        """Predict rewards using the model.

        Args:
            trajs: the trajectories.
            flat_rewards: the rewards.
            num_offline: the number of offline trajectories.

        Returns:
            flat_rewards: the updated rewards.
        """
        # the model can be trained to predict its own reward,
        # i.e. predict the output of cond_info_to_reward
        pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]]
        flat_rewards += list(pred_reward)
        return flat_rewards

    def predict_reward_task(self, trajs, flat_rewards, num_offline, is_valid):
        """Predict rewards by querying the task.

        Args:
            trajs: the trajectories.
            flat_rewards: the rewards.
            num_offline: the number of offline trajectories.
            is_valid: whether the trajectories are valid.

        Returns:
            trajs: the updated trajectories.
            flat_rewards: the updated rewards.
        """
        # query the task for the flat rewards of the on-policy trajectories
        valid_idcs = torch.tensor(
            [
                i + num_offline
                for i in range(self.online_batch_size)
                if trajs[i + num_offline]["is_valid"]
            ]
        ).long()
        pred_reward = torch.zeros((self.online_batch_size))
        # fetch the endpoints of the valid trajectories
        mols = [self.ctx.graph_to_mol(trajs[i]["traj"][-1][0]) for i in valid_idcs]
        # ask the task to compute their reward
        preds, m_is_valid = self.task.compute_flat_rewards(mols)
        # the task may decide some of the mols are invalid, so we have to filter those again
        valid_idcs = valid_idcs[m_is_valid]
        _preds = torch.tensor(preds, dtype=torch.float32)
        pred_reward[valid_idcs - num_offline] = _preds
        is_valid[num_offline:] = False
        is_valid[valid_idcs] = True
        flat_rewards += list(pred_reward)
        # override the is_valid key in case the task made some mols invalid
        for i in range(self.online_batch_size):
            trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item()
        return trajs, flat_rewards
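
    # Worked example (illustrative only): with num_offline = 4 and online_batch_size = 3,
    # the on-policy trajectories occupy slots 4..6. If only slots 4 and 6 end up valid
    # (both in the trajectory flags and after the task's own validity check), then
    # valid_idcs = [4, 6] and their predictions are written to pred_reward[[0, 2]] via
    # the `valid_idcs - num_offline` offset; slot 5 keeps a reward of zero and is
    # flagged invalid both in is_valid and in its trajectory dict.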

    def sample_online(self, trajs, flat_rewards, cond_info, num_offline):
        """Sample on-policy data.

        Args:
            trajs: the trajectories.
            flat_rewards: the rewards.
            cond_info: the conditional information.
            num_offline: the number of offline trajectories.

        Returns:
            trajs: the updated trajectories.
            flat_rewards: the updated rewards.
        """
        is_valid = torch.ones(cond_info["beta"].shape[0]).bool()
        with torch.no_grad():
            trajs += self.algo.create_training_data_from_own_samples(
                self.model,  # TODO: double-check the model here
                self.online_batch_size,
                cond_info["encoding"][num_offline:],
            )
        if self.algo.bootstrap_own_reward:
            # predict reward with the model
            flat_rewards = self.predict_reward_model(trajs, flat_rewards, num_offline)
        else:
            # predict reward with the task
            trajs, flat_rewards = self.predict_reward_task(
                trajs, flat_rewards, num_offline, is_valid
            )
        return trajs, flat_rewards

    def set_seed(self):
        """Set a per-worker random number generator (seeded with the worker id)."""
        worker_info = torch.utils.data.get_worker_info()
        wid = worker_info.id if worker_info is not None else 0
        # seed a dedicated rng for each worker
        rng = np.random.default_rng(142857 + wid)
        self.rng = rng
        self.algo.rng = rng
        self.task.rng = rng
        self.ctx.device = self.device

    def __iter__(self):
        """Build batches using online and offline data across multiple workers.

        Yields:
            Batches of data built from trajectories and rewards.
        """
        # set a different seed for each worker so that they sample different batches;
        # with the same seed, every worker would hold an identical copy of the data
        # and yield the same batch.
        self.set_seed()
        # iterate over batches of dataset indices
        for idcs in self._idx_iterator():
            num_offline = idcs.shape[0]  # this is in [1, self.offline_batch_size]
            # sample conditional info such as temperature, trade-off weights, etc.
            cond_info = self.task.sample_conditional_information(
                num_offline + self.online_batch_size
            )
            is_valid = torch.ones(cond_info["beta"].shape[0]).bool()
            # sample offline data
            trajs, flat_rewards = self.sample_offline(idcs)
            # sample some on-policy data (rewards computed by the model or the task)
            if self.online_batch_size > 0:
                # update trajectories and rewards with on-policy data
                trajs, flat_rewards = self.sample_online(
                    trajs, flat_rewards, cond_info, num_offline
                )
            # compute scalar rewards from conditional information & flat rewards
            rewards = self.task.cond_info_to_reward(cond_info, flat_rewards)
            # account for illegal actions
            rewards[torch.logical_not(is_valid)] = np.exp(
                self.algo.illegal_action_logreward
            )
            # construct a batch from trajectories and rewards
            batch = self.algo.construct_batch(trajs, cond_info["encoding"], rewards)
            batch.num_offline = num_offline
            yield batch
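

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's public API).
# It assumes `dataset`, `model`, `context`, `algorithm` and `task` objects have been
# built elsewhere; `build_sampling_data_loader` is a hypothetical helper name.
def build_sampling_data_loader(
    dataset: GFlowNetDataset,
    model: Union[nn.Module, TrajectoryBalanceModel],
    context: GraphBuildingEnvContext,
    algorithm: TrajectoryBalance,
    task: GFlowNetTask,
    num_workers: int = 4,
) -> torch.utils.data.DataLoader:
    """Wrap a SamplingIterator in a DataLoader with multiple worker processes."""
    # share parameters across worker processes (see the `model` argument docstring)
    model.share_memory()
    iterator = SamplingIterator(
        dataset=dataset,
        model=model,
        batch_size=64,
        context=context,
        algorithm=algorithm,
        task=task,
        ratio=0.5,
        stream=True,
        device="cuda",
    )
    # batch_size=None: the iterator already yields fully-formed batches,
    # so the DataLoader only handles worker management.
    return torch.utils.data.DataLoader(iterator, batch_size=None, num_workers=num_workers)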