GT4SD - training GFlowNets on a generic task¶
Overview¶
This notebook shows the basic usage of the GFlowNet (GFN) framework on a generic task.
We provide an example of how to set up GFN to train on QM9 in examples/gflownet/main_qm9.py.
The implementation is adapted from: https://github.com/recursionpharma/gflownet.
The user has to define (at least) two main components:
- a dataset compatible with GFlowNetDataset (see gt4sd/frameworks/gflownet/tests/qm9.py);
- a task compatible with GFlowNetTask, which defines the reward function (see gt4sd/frameworks/gflownet/tests/qm9.py).
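As a rough illustration of what "defining the reward function" can mean, the sketch below maps a scalar molecular property to a strictly positive reward, since a GFlowNet samples objects with probability proportional to their reward. This is not the reward used by QM9GapTask; the function name gap_reward and the target and width values are hypothetical and chosen only for the example.

```python
import math


def gap_reward(gap: float, target: float = 0.3, width: float = 0.1) -> float:
    """Toy reward: 1.0 when gap == target, decaying smoothly towards 0 away from it.

    `target` and `width` are hypothetical values, not taken from GT4SD.
    """
    if width <= 0:
        raise ValueError("width must be positive")
    return math.exp(-((gap - target) ** 2) / (2 * width ** 2))


# hypothetical property values for three candidate molecules
for gap in (0.3, 0.4, 0.6):
    print(f"gap={gap:.1f} reward={gap_reward(gap):.3f}")
```

In a real task, a function of this kind would be called from the GFlowNetTask-compatible class, with the property value coming from the dataset or a predictive model.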
Here we are assuming that:
- an environment compatible with GraphBuildingEnvironment for graph-based problems is implemented in envs/graph_building_env.py;
- a context compatible with GraphBuildingEnvContext to specify how to use the basic building blocks in the environment is implemented in envs/mol_building_env.py;
- action in the environment is discrete and prescribed by GraphActionCategorical for graph-based problems in envs/graph_building_env.py.
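To make the idea of a discrete, categorical action more concrete, here is a toy sketch in plain Python. It is not the GraphActionCategorical implementation: the action types, the logits layout and the helper names are all assumptions made for illustration. The sketch flattens per-type logits into one categorical distribution and samples a single (action type, index) pair from it.

```python
import math
import random

# hypothetical action types for a toy graph-building environment
ACTION_TYPES = ("stop", "add_node", "add_edge")


def softmax(logits):
    """Convert unnormalized scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits_by_type, rng):
    """Sample one (action_type, index) pair from a single categorical over all choices."""
    choices = [(t, i) for t in ACTION_TYPES for i in range(len(logits_by_type[t]))]
    flat_logits = [logits_by_type[t][i] for t, i in choices]
    probs = softmax(flat_logits)
    return rng.choices(choices, weights=probs, k=1)[0]


# one logit per available choice of each action type (made-up numbers)
logits = {"stop": [0.0], "add_node": [1.0, 0.5], "add_edge": [-1.0]}
action = sample_action(logits, random.Random(0))
print(action)
```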
A note on the requirements¶
GFN relies on pytorch_lightning and pytorch_geometric.
We recommend training GFN on a GPU and checking the pytorch_geometric requirements for your environment.
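A quick way to check the environment before training is sketched below. The helper missing_modules is hypothetical; the module names are the import names of the packages (pytorch_geometric is imported as torch_geometric), and torch is included on the assumption that both packages build on it.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# import names of the required packages (torch is an assumed dependency)
REQUIRED = ["torch", "pytorch_lightning", "torch_geometric"]

missing = missing_modules(REQUIRED)
if missing:
    print("missing packages:", ", ".join(missing))
else:
    import torch

    # report whether a GPU is visible to PyTorch
    print("CUDA available:", torch.cuda.is_available())
```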
Debugging¶
Training GFN can be a long process. To debug your training pipeline, set development=True. This will activate fast_dev_run functionality in the pytorch_lightning trainer.
If training gets stuck and the dataloader does not yield data, set num_workers=0.
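One possible way to apply both debugging settings is sketched here, assuming they are passed as keys of the same configuration dictionary used in the training script. The helper with_debug_settings is hypothetical and not part of GT4SD.

```python
def with_debug_settings(configuration):
    """Return a copy of the configuration with the debugging options switched on."""
    debug_configuration = dict(configuration)
    debug_configuration["development"] = True  # assumed key; activates fast_dev_run
    debug_configuration["num_workers"] = 0  # assumed key; load data in the main process
    return debug_configuration


base = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}
debug = with_debug_settings(base)
print(debug)
```

Returning a copy keeps the original configuration untouched, so the same dictionary can be reused for a full training run once the pipeline works.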
Minimal training example¶
Here we provide a minimal training script. We implemented a dataset and a task in the examples folder and rely on the environment, context and training routines in frameworks.
from gt4sd.frameworks.gflownet.arg_parser.parser import parse_arguments_from_config
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext
from gt4sd.frameworks.gflownet.tests.qm9 import QM9Dataset, QM9GapTask
from gt4sd.frameworks.gflownet.train.core import train_gflownet


def main():
    """Run basic GFN training on QM9."""
    configuration = {"dataset": "qm9", "dataset_path": "/GFN/qm9.h5", "device": "cpu"}

    # add user configuration
    configuration.update(vars(parse_arguments_from_config()))

    # build the environment and context
    environment = GraphBuildingEnv()
    context = MolBuildingEnvContext()

    # build the dataset
    dataset = QM9Dataset(configuration["dataset_path"], target="gap")

    # build the task
    task = QM9GapTask(
        configuration=configuration,
        dataset=dataset,
    )

    # train gflownet
    train_gflownet(
        configuration=configuration,
        dataset=dataset,
        environment=environment,
        context=context,
        task=task,
    )


if __name__ == "__main__":
    main()