gt4sd.frameworks.gflownet.dataloader.sampler module¶
Summary¶
Classes:
- SamplingIterator: Sampler to speed up training based on IterableDataset.
Reference¶
- class SamplingIterator(dataset, model, batch_size, context, algorithm, task, ratio=0.5, stream=True, device='cuda')[source]¶
Bases:
IterableDataset
Sampler to speed up training based on IterableDataset.
- __init__(dataset, model, batch_size, context, algorithm, task, ratio=0.5, stream=True, device='cuda')[source]¶
This class allows us to parallelize sampling and train faster. By separating data sampling (from the dataset or the model) and torch-geometric graph construction from model training, the former can run in separate processes, which is much faster since graph construction is largely CPU-bound. This sampler can handle both offline and online data.
Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/data.
- Parameters
  - dataset (GFlowNetDataset) – a dataset instance.
  - model (Union[Module, TrajectoryBalanceModel]) – the model to sample from (it must already be on CUDA, or share_memory() must have been called so that parameters are synchronized between workers).
  - batch_size (int) – the number of trajectories; each trajectory comprises many graphs, so this is _not_ the batch size in terms of the number of graphs (that depends on the task).
  - context (GraphBuildingEnvContext) – the graph environment.
  - algorithm (TrajectoryBalance) – the training algorithm, e.g. a TrajectoryBalance instance.
  - task (GFlowNetTask) – a ConditionalTask that specifies the reward structure.
  - ratio (float) – the ratio of offline trajectories in the batch.
  - stream (bool) – if true, data is sampled i.i.d. for every batch; otherwise, this is a normal in-order dataset iterator.
  - device (str) – the device (cpu or cuda).
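To illustrate how the `ratio` parameter shapes a batch, here is a minimal sketch (not the gt4sd implementation; `split_batch` is a hypothetical helper) of splitting `batch_size` trajectories into offline and online counts:

```python
# Hypothetical sketch: how `ratio` could determine how many trajectories
# in a batch come from the offline dataset versus on-policy sampling.
def split_batch(batch_size: int, ratio: float) -> tuple:
    """Return (num_offline, num_online) trajectory counts for one batch."""
    num_offline = int(batch_size * ratio)  # offline share, rounded down
    num_online = batch_size - num_offline  # remainder is sampled on-policy
    return num_offline, num_online
```

For example, `split_batch(64, 0.5)` yields 32 offline and 32 online trajectories, matching the default `ratio=0.5`.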
- _idx_iterator()[source]¶
Returns an iterator over the indices of the dataset. The batch can be offline or online.
- Yields
Batch of indexes.
- Return type
Generator
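The streaming/in-order distinction described above can be sketched as follows (a simplified stand-in, not the actual `_idx_iterator` code; the function name and signature are assumptions):

```python
import random

# Hypothetical sketch of the index-iterator pattern: in streaming mode,
# indices are drawn i.i.d. for every batch; otherwise the dataset is
# walked once, in order, in batch-sized chunks.
def idx_iterator(dataset_size: int, batch_size: int, stream: bool, rng=None):
    rng = rng or random.Random(0)
    if stream:
        while True:  # endless i.i.d. batches
            yield [rng.randrange(dataset_size) for _ in range(batch_size)]
    else:
        for start in range(0, dataset_size, batch_size):
            yield list(range(start, min(start + batch_size, dataset_size)))
```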
- sample_offline(idcs)[source]¶
Sample offline data.
- Parameters
idcs – the indices of the data to sample.
- Returns
  trajs – the trajectories; rewards – the corresponding rewards.
- Return type
  trajs, rewards
- predict_reward_model(trajs, flat_rewards, num_offline)[source]¶
Predict rewards using the model.
- Parameters
trajs – the trajectories.
flat_rewards – the rewards.
num_offline – the number of offline trajectories.
- Returns
the updated rewards.
- Return type
flat_rewards
- predict_reward_task(trajs, flat_rewards, num_offline, is_valid)[source]¶
Predict rewards using the task.
- Parameters
trajs – the trajectories.
flat_rewards – the rewards.
num_offline – the number of offline trajectories.
is_valid – whether the trajectories are valid.
- Returns
the updated rewards.
- Return type
flat_rewards
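A common pattern behind the `is_valid` argument is to replace the reward of invalid trajectories with a fixed fallback value. The sketch below illustrates that idea only; `mask_invalid` and its fallback of `0.0` are assumptions, not the gt4sd code:

```python
# Hypothetical sketch: trajectories flagged as invalid receive a fixed
# fallback reward instead of the model/task prediction.
def mask_invalid(flat_rewards, is_valid, invalid_reward=0.0):
    # keep the predicted reward only where the trajectory is valid
    return [r if v else invalid_reward for r, v in zip(flat_rewards, is_valid)]
```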
- sample_online(trajs, flat_rewards, cond_info, num_offline)[source]¶
Sample on-policy data.
- Parameters
trajs – the trajectories.
flat_rewards – the rewards.
cond_info – the conditional information.
num_offline – the number of offline trajectories.
- Returns
  trajs – the updated trajectories; flat_rewards – the updated rewards.
- Return type
  trajs, flat_rewards
- __iter__()[source]¶
Build batches from online and offline data, using multiple workers.
- Yields
Batches of data built from trajectories and rewards.
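When an IterableDataset is consumed by several DataLoader workers, each worker must keep a disjoint slice of the batches so none is produced twice. A minimal sketch of that sharding pattern (the helper name is hypothetical; the real implementation lives inside `__iter__`):

```python
# Hypothetical sketch of multi-worker sharding for an IterableDataset:
# worker w keeps every num_workers-th batch, offset by its worker id.
def shard_for_worker(batches, worker_id: int, num_workers: int):
    for i, batch in enumerate(batches):
        if i % num_workers == worker_id:
            yield batch
```

With 3 workers, worker 0 yields batches 0, 3, 6, …; together the workers cover every batch exactly once.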