gt4sd.frameworks.gflownet.dataloader.sampler module

Summary

Classes:

SamplingIterator

Sampler to speed up training based on IterableDataset.

Reference

class SamplingIterator(dataset, model, batch_size, context, algorithm, task, ratio=0.5, stream=True, device='cuda')[source]

Bases: IterableDataset

Sampler to speed up training based on IterableDataset.

__init__(dataset, model, batch_size, context, algorithm, task, ratio=0.5, stream=True, device='cuda')[source]

This class speeds up training by parallelising data preparation. Sampling data (from the dataset or from the model) and building torch-geometric graphs are separated from training the model, so the former can run in different processes, which is much faster since much of graph construction is CPU-bound. The sampler can handle both offline and online data; a usage sketch follows the parameter list below.

Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/data.

Parameters
  • dataset (GFlowNetDataset) – a dataset instance.

  • model (Union[Module, TrajectoryBalanceModel]) – the model we sample from (it must already be on CUDA, or share_memory() must be called so that parameters are shared across workers).

  • batch_size (int) – the number of trajectories per batch. Each trajectory comprises many graphs, so this is _not_ the batch size in terms of the number of graphs (that depends on the task).

  • context (GraphBuildingEnvContext) – the graph environment.

  • algorithm (TrajectoryBalance) – the training algorithm, e.g. a TrajectoryBalance instance.

  • task (GFlowNetTask) – a ConditionalTask that specifies the reward structure.

  • ratio (float) – the ratio of offline trajectories in the batch.

  • stream (bool) – if True, data is sampled i.i.d. for every batch; otherwise, this is a normal in-order dataset iterator.

  • device (str) – the device to use (cpu or cuda).
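
A minimal usage sketch, assuming the dataset, model, environment context, algorithm and task have already been built elsewhere (the variable names and the training-loop body are placeholders, not part of the documented API):

from torch.utils.data import DataLoader

from gt4sd.frameworks.gflownet.dataloader.sampler import SamplingIterator

# dataset, model, context, algorithm and task are assumed to exist already.
iterator = SamplingIterator(
    dataset=dataset,
    model=model,
    batch_size=64,   # number of trajectories per batch
    context=context,
    algorithm=algorithm,
    task=task,
    ratio=0.5,       # half offline, half on-policy trajectories
    stream=True,     # sample i.i.d. batches instead of iterating in order
    device="cuda",
)

# The iterator yields fully formed batches, so batch_size=None avoids
# re-batching; num_workers moves graph construction into worker processes.
loader = DataLoader(iterator, batch_size=None, num_workers=4)

for batch in loader:
    ...  # training step using the algorithm and model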

_idx_iterator()[source]

Returns an iterator over batches of dataset indices. The batches can be used for offline or online sampling.

Yields

A batch of dataset indices.

Return type

Generator
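
The distinction between the two modes set by stream can be illustrated with a simplified sketch (an assumption about the behaviour, not the actual implementation; the random generator is a placeholder):

import numpy as np

def idx_iterator_sketch(dataset_size: int, batch_size: int, stream: bool):
    rng = np.random.default_rng(0)
    if stream:
        # stream=True: every batch is an independent i.i.d. draw of indices.
        while True:
            yield rng.integers(0, dataset_size, batch_size)
    else:
        # stream=False: iterate over the dataset in order, one chunk at a time.
        for start in range(0, dataset_size, batch_size):
            yield np.arange(start, min(start + batch_size, dataset_size))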

__len__()[source]

Length of the offline or online dataset.

Returns

The length of the dataset.

sample_offline(idcs)[source]

Samples offline data.

Parameters

idcs – the indices of the data to sample.

Returns

trajs – the sampled trajectories. rewards – the corresponding rewards.

predict_reward_model(trajs, flat_rewards, num_offline)[source]

Predict rewards using the model.

Parameters
  • trajs – the trajectories.

  • flat_rewards – the rewards.

  • num_offline – the number of offline trajectories.

Returns

flat_rewards – the updated rewards.

predict_reward_task(trajs, flat_rewards, num_offline, is_valid)[source]

Predict rewards using the task.

Parameters
  • trajs – the trajectories.

  • flat_rewards – the rewards.

  • num_offline – the number of offline trajectories.

  • is_valid – whether the trajectories are valid.

Returns

flat_rewards – the updated rewards.

sample_online(trajs, flat_rewards, cond_info, num_offline)[source]

Sample on-policy data.

Parameters
  • trajs – the trajectories.

  • flat_rewards – the rewards.

  • cond_info – the conditional information.

  • num_offline – the number of offline trajectories.

Returns

trajs – the updated trajectories. flat_rewards – the updated rewards.
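
The offline/online split controlled by ratio can be sketched with simple arithmetic (an assumed splitting rule for illustration, not the exact implementation):

batch_size = 64
ratio = 0.5
num_offline = int(batch_size * ratio)   # trajectories drawn from the dataset
num_online = batch_size - num_offline   # trajectories sampled on-policy from the model
assert num_offline + num_online == batch_size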

set_seed()[source]

Set the seed for the workers.
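
Inside an IterableDataset, per-worker seeding typically derives a distinct seed from the worker id; the sketch below shows that pattern as an assumption about what set_seed does, not a copy of its implementation:

import numpy as np
import torch
from torch.utils.data import get_worker_info

def set_seed_sketch(base_seed: int = 0) -> None:
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info is not None else 0
    # Offset the base seed by the worker id so each worker draws
    # different indices and different on-policy samples.
    seed = base_seed + worker_id
    np.random.seed(seed)
    torch.manual_seed(seed)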

__iter__()[source]

Build batches using online and offline data across multiple workers.

Yields

A batch of data built from trajectories and rewards.
