gt4sd.frameworks.gflownet.dataloader.data_module module

Data module for gflownet.

Summary

Classes:

GFlowNetDataModule

Data module for gflownet.

Reference

class GFlowNetDataModule(configuration, dataset, environment, context, task, algorithm, model=None)

Bases: LightningDataModule

Data module for gflownet.

__init__(configuration, dataset, environment, context, task, algorithm, model=None)

Construct GFlowNetDataModule.

The module assumes that the model and the algorithm are obtained from a factory/registry. The user should provide a dataset, an environment, a context for that environment, and a task.

Parameters
  • configuration (Dict[str, Any]) – configuration dictionary.

  • dataset (GFlowNetDataset) – dataset.

  • environment (GraphBuildingEnv) – environment for graph building.

  • context (GraphBuildingEnvContext) – context for the graph-building environment.

  • task (GFlowNetTask) – generic task.

  • algorithm (TrajectoryBalance) – training algorithm that provides the loss function (trajectory balance).

  • model (Optional[Module]) – model used to generate data with the sampling iterator. It can be a custom model or the same model used by the algorithm. Defaults to None.
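A minimal pure-Python sketch of how the constructor arguments fit together, assuming the module simply keeps a reference to each component. The class name DataModuleSketch and the "batch_size" configuration key (with its default of 64) are hypothetical, not taken from gt4sd, and the real class additionally inherits from LightningDataModule.

```python
from typing import Any, Dict, Optional


class DataModuleSketch:
    """Hypothetical stand-in for the GFlowNetDataModule constructor (not gt4sd code)."""

    def __init__(
        self,
        configuration: Dict[str, Any],
        dataset: Any,
        environment: Any,
        context: Any,
        task: Any,
        algorithm: Any,
        model: Optional[Any] = None,
    ) -> None:
        # Keep references to the user-provided components.
        self.configuration = configuration
        self.dataset = dataset
        self.environment = environment
        self.context = context
        self.task = task
        self.algorithm = algorithm
        # Optional model for the sampling iterator; stays None if not given.
        self.model = model
        # Hypothetical configuration key with a fallback default.
        self.batch_size = int(configuration.get("batch_size", 64))


module = DataModuleSketch(
    configuration={"batch_size": 16},
    dataset=[0, 1, 2],
    environment="env",
    context="ctx",
    task="task",
    algorithm="tb",
)
print(module.batch_size, module.model)  # 16 None
```

Passing the components in explicitly, rather than building them inside the module, keeps the data module independent of any particular environment or task.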

prepare_data()

Prepare the training and test datasets.

Return type

None
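Preparing training and test data usually involves splitting one dataset into disjoint index sets. A sketch of a reproducible split, assuming a random shuffle with a fixed seed; the helper split_indices and its test_fraction and seed parameters are hypothetical and not part of gt4sd.

```python
import random
from typing import List, Tuple


def split_indices(
    num_items: int, test_fraction: float = 0.2, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: deterministically split item indices into train/test."""
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError("test_fraction must be in [0, 1]")
    indices = list(range(num_items))
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)
    num_test = int(round(num_items * test_fraction))
    return indices[num_test:], indices[:num_test]


train_idx, test_idx = split_indices(10, test_fraction=0.3)
print(len(train_idx), len(test_idx))  # 7 3
```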

setup(stage)

Set up the data module.

Parameters

stage (Optional[str]) – stage being set up (Lightning passes "fit", "validate", "test" or "predict"). Defaults to None.

Return type

None
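In Lightning, setup(stage) is called with the name of the current stage, so a data module can build only the splits that stage needs. A hypothetical sketch of such stage dispatch, assuming that None means "prepare everything"; the class StageAwareSetupSketch and its datasets dictionary are illustrative and not taken from gt4sd.

```python
from typing import Dict, List, Optional


class StageAwareSetupSketch:
    """Hypothetical illustration of stage-dependent setup (not gt4sd code)."""

    def __init__(self, train_items: List[int], val_items: List[int]) -> None:
        self._train_items = train_items
        self._val_items = val_items
        self.datasets: Dict[str, List[int]] = {}

    def setup(self, stage: Optional[str] = None) -> None:
        # "fit" needs both splits; None is treated here as "prepare everything".
        if stage in (None, "fit"):
            self.datasets["train"] = self._train_items
            self.datasets["val"] = self._val_items
        # "validate" only needs the validation split.
        if stage == "validate":
            self.datasets["val"] = self._val_items


sketch = StageAwareSetupSketch([1, 2, 3], [4])
sketch.setup("validate")
print(sorted(sketch.datasets))  # ['val']
sketch.setup()
print(sorted(sketch.datasets))  # ['train', 'val']
```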

train_dataloader()

Get a data loader for training.

Return type

DataLoader

Returns

a training data loader.

val_dataloader()

Get a data loader for validation.

Return type

DataLoader

Returns

a validation data loader.
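The training and validation loaders both serve mini-batches. A common convention, assumed here rather than confirmed for gt4sd, is to shuffle training batches and keep validation batches in a fixed order. The generator iterate_batches is a hypothetical pure-Python stand-in for torch.utils.data.DataLoader with batch_size and shuffle set.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(
    items: Sequence[T], batch_size: int, shuffle: bool, seed: int = 0
) -> Iterator[List[T]]:
    """Yield mini-batches, optionally shuffled (a stand-in for DataLoader)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    order = list(range(len(items)))
    if shuffle:
        random.Random(seed).shuffle(order)
    # The last batch may be smaller than batch_size.
    for start in range(0, len(order), batch_size):
        yield [items[i] for i in order[start:start + batch_size]]


# Training batches are shuffled; validation batches keep a fixed order.
train_batches = list(iterate_batches(list(range(10)), batch_size=4, shuffle=True))
val_batches = list(iterate_batches([10, 11, 12], batch_size=2, shuffle=False))
print([len(b) for b in train_batches])  # [4, 4, 2]
print(val_batches)  # [[10, 11], [12]]
```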

predict_dataloader()

Get a data loader for prediction.

Return type

DataLoader

Returns

a prediction data loader.