GT4SD - training GFlowNets on a generic task


This notebook shows the basic usage of the GFlowNet (GFN) framework on a generic task. We provide an example of how to set up GFN to train on QM9 in examples/gflownet/. The implementation is adapted from:

The user has to define at least two main components:

  • a dataset compatible with GFlowNetDataset (see gt4sd/frameworks/gflownet/tests/);

  • a task compatible with GFlowNetTask that defines the reward function (see gt4sd/frameworks/gflownet/tests/).
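
To illustrate how the two components divide responsibilities, here is a minimal plain-Python sketch. ToyDataset, ToyTask, and reward are hypothetical names used only for illustration; they do not reproduce the GFlowNetDataset or GFlowNetTask interfaces, which should be taken from the framework source. The sketch assumes that the dataset yields (object, target) pairs and that the task maps a target property to a strictly positive reward, since a GFlowNet samples objects with probability proportional to their reward. The property values are made up.

```python
import math
from typing import List, Sequence, Tuple


class ToyDataset:
    """Hypothetical stand-in for a dataset: stores (name, property) pairs."""

    def __init__(self, samples: Sequence[Tuple[str, float]]) -> None:
        self.samples: List[Tuple[str, float]] = list(samples)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[str, float]:
        return self.samples[index]


class ToyTask:
    """Hypothetical stand-in for a task: turns a property into a reward."""

    def __init__(self, dataset: ToyDataset, temperature: float = 1.0) -> None:
        values = [value for _, value in dataset.samples]
        self.low, self.high = min(values), max(values)
        self.temperature = temperature

    def reward(self, value: float) -> float:
        # Rescale to [0, 1] with the dataset range, then exponentiate so
        # that the reward is strictly positive.
        span = (self.high - self.low) or 1.0
        scaled = (value - self.low) / span
        return math.exp(scaled / self.temperature)


# Made-up property values, for illustration only.
dataset = ToyDataset([("obj_0", 0.40), ("obj_1", 0.35), ("obj_2", 0.25)])
task = ToyTask(dataset)
rewards = [task.reward(value) for _, value in dataset]
```

Keeping the reward inside the task means the same dataset can be reused with a different objective by swapping only the task.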

Here we are assuming that:

  • an environment compatible with GraphBuildingEnvironment for graph-based problems is implemented in envs/;

  • a context compatible with GraphBuildingEnvContext to specify how to use the basic building blocks in the environment is implemented in envs/;

  • actions in the environment are discrete and prescribed by GraphActionCategorical for graph-based problems in envs/.
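
The idea of building a graph through discrete actions can be sketched in plain Python as follows. ActionType, ToyGraph, ToyAction, and step are hypothetical names invented for this illustration; they are not the GraphBuildingEnv, GraphBuildingEnvContext, or GraphActionCategorical interfaces. The sketch assumes a small action set (add a node, add an edge, stop) and a state that is simply a list of nodes and edges.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Tuple


class ActionType(Enum):
    ADD_NODE = 0
    ADD_EDGE = 1
    STOP = 2


@dataclass
class ToyGraph:
    nodes: List[str] = field(default_factory=list)
    edges: List[Tuple[int, int]] = field(default_factory=list)


@dataclass(frozen=True)
class ToyAction:
    kind: ActionType
    label: Optional[str] = None  # node label, used by ADD_NODE
    endpoints: Optional[Tuple[int, int]] = None  # node indices, used by ADD_EDGE


def step(graph: ToyGraph, action: ToyAction) -> Tuple[ToyGraph, bool]:
    """Apply one discrete action; return the new graph and a done flag."""
    nodes, edges = list(graph.nodes), list(graph.edges)
    if action.kind is ActionType.ADD_NODE:
        if action.label is None:
            raise ValueError("ADD_NODE needs a node label")
        nodes.append(action.label)
    elif action.kind is ActionType.ADD_EDGE:
        if action.endpoints is None:
            raise ValueError("ADD_EDGE needs two endpoints")
        source, target = action.endpoints
        if not (0 <= source < len(nodes) and 0 <= target < len(nodes)):
            raise ValueError("edge endpoints must be existing nodes")
        edges.append((source, target))
    # STOP leaves the graph unchanged and ends the trajectory.
    return ToyGraph(nodes, edges), action.kind is ActionType.STOP


trajectory = [
    ToyAction(ActionType.ADD_NODE, label="C"),
    ToyAction(ActionType.ADD_NODE, label="O"),
    ToyAction(ActionType.ADD_EDGE, endpoints=(0, 1)),
    ToyAction(ActionType.STOP),
]
graph, done = ToyGraph(), False
for action in trajectory:
    graph, done = step(graph, action)
```

In this toy split, the state transition plays the role of the environment, while the meaning of labels such as "C" and "O" is what a context would define.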

A note on the requirements

GFN relies on pytorch_lightning and pytorch_geometric. We recommend training GFN on GPU and checking the pytorch_geometric requirements for your environment.
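
A quick way to check the environment before training is sketched below. check_requirements is a hypothetical helper, not part of GT4SD. It assumes the usual import names: pytorch_geometric is imported as torch_geometric, and GPU availability is queried through torch.

```python
import importlib.util
from typing import Dict


def check_requirements() -> Dict[str, bool]:
    """Report which training dependencies are importable and whether CUDA is usable."""
    modules = ("torch", "pytorch_lightning", "torch_geometric")
    status = {name: importlib.util.find_spec(name) is not None for name in modules}
    status["cuda"] = False
    if status["torch"]:
        try:
            import torch

            status["cuda"] = bool(torch.cuda.is_available())
        except Exception:
            # A broken torch installation is reported as "no CUDA".
            status["cuda"] = False
    return status


status = check_requirements()
for name, available in status.items():
    print(f"{name}: {'ok' if available else 'missing'}")
```

If cuda is reported as missing, training can still run with "device": "cpu", only more slowly.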


Training GFN can be a long process. To debug your training pipeline, set development=True. This will activate fast_dev_run functionality in the pytorch_lightning trainer. If training gets stuck and the dataloader does not yield data, set num_workers=0.
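
The two debugging switches can be collected in a small helper, sketched below. with_debug_overrides is a hypothetical function, and the sketch assumes that both options are read from the configuration dictionary under the names development and num_workers; check how your version of the framework consumes them.

```python
from typing import Any, Dict


def with_debug_overrides(configuration: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the configuration with the debugging switches set."""
    debug_configuration = dict(configuration)
    # Assumed key: activates fast_dev_run in the pytorch_lightning trainer.
    debug_configuration["development"] = True
    # Assumed key: load data in the main process to avoid a stuck dataloader.
    debug_configuration["num_workers"] = 0
    return debug_configuration


base = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}
debug = with_debug_overrides(base)
```

Returning a copy keeps the original configuration untouched, so the full training run can reuse it once the pipeline works.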

Minimal training example

Here we provide a minimal training script. We implemented a dataset and a task in the examples folder and rely on the environment, context, and training routines in frameworks.

from gt4sd.frameworks.gflownet.arg_parser.parser import parse_arguments_from_config
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext
from gt4sd.frameworks.gflownet.tests.qm9 import QM9Dataset, QM9GapTask
from gt4sd.frameworks.gflownet.train.core import train_gflownet

def main():
    """Run basic GFN training on QM9."""

    configuration = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}
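    # add user configuration
    # (assumes parse_arguments_from_config returns an argparse-style namespace)
    configuration.update(vars(parse_arguments_from_config()))
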
    # build the environment and context
    environment = GraphBuildingEnv()
    context = MolBuildingEnvContext()
    # build the dataset
    dataset = QM9Dataset(configuration["dataset_path"], target="gap")
    # build the task
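    task = QM9GapTask(
        # keyword arguments assumed from the objects defined above;
        # check the QM9GapTask signature in your installed version
        configuration=configuration,
        dataset=dataset,
    )
    # train gflownet
    train_gflownet(
        # keyword arguments assumed; check the train_gflownet signature
        configuration=configuration,
        dataset=dataset,
        environment=environment,
        context=context,
        task=task,
    )

if __name__ == "__main__":
    main()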