gt4sd.training_pipelines.pytorch_lightning.gflownet.core module

GFlowNet training utilities.

Summary

Classes:

GFlowNetDataArguments

Arguments related to data.

GFlowNetModelArguments

Arguments related to model.

GFlowNetPytorchLightningTrainingArguments

Arguments related to pytorch lightning trainer.

GFlowNetSavingArguments

Saving arguments related to GFlowNet trainer.

GFlowNetTrainingPipeline

GFlowNet training pipelines.

Reference

class GFlowNetTrainingPipeline[source]

Bases: PyTorchLightningTrainingPipeline

GFlowNet training pipelines.

train(pl_trainer_args, model_args, dataset_args, dataset, environment, context, task)[source]

Generic training function for PyTorch Lightning-based training.

Parameters
  • pl_trainer_args (Dict[str, Any]) – pytorch lightning trainer arguments passed to the configuration.

  • model_args (Dict[str, Union[float, str, int]]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Union[float, str, int]]) – dataset arguments passed to the configuration.

  • dataset (GFlowNetDataset) – dataset to be used for training.

  • environment (GraphBuildingEnv) – environment to be used for training.

  • context (GraphBuildingEnvContext) – context to be used for training.

  • task (GFlowNetTask) – task to be used for training.

Return type

None

get_data_and_model_modules(model_args, dataset_args, pl_trainer_args, dataset, environment, context, task)[source]

Get data and model modules for training.

Parameters
  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

  • pl_trainer_args (Dict[str, Any]) – pytorch lightning trainer arguments passed to the configuration.

  • dataset (GFlowNetDataset) – dataset to be used for training.

  • environment (GraphBuildingEnv) – environment to be used for training.

  • context (GraphBuildingEnvContext) – context to be used for training.

  • task (GFlowNetTask) – task to be used for training.

Return type

Tuple[LightningDataModule, LightningModule]

Returns

the data and model modules.

__doc__ = 'gflownet training pipelines.'
__module__ = 'gt4sd.training_pipelines.pytorch_lightning.gflownet.core'
class GFlowNetPytorchLightningTrainingArguments(strategy='ddp', accumulate_grad_batches=1, val_check_interval=5000, save_dir='logs', basename='lightning_logs', gradient_clip_val=0.0, limit_val_batches=500, log_every_n_steps=500, max_epochs=3, resume_from_checkpoint=None, gpus=-1, monitor=None, save_last=None, save_top_k=1, mode='min', every_n_train_steps=None, every_n_epochs=None, trainer_log_every_n_steps=50, epochs=3, check_val_every_n_epoch=5, auto_lr_find=True, profiler='simple', learning_rate=0.0001, test_output_path='./test', num_workers=0, log_dir='./log/', num_training_steps=1000, validate_every=1000, seed=142857, device='cpu', development_mode=False)[source]

Bases: PytorchLightningTrainingArguments

Arguments related to pytorch lightning trainer.

__name__ = 'GFlowNetPytorchLightningTrainingArguments'
strategy: Optional[str] = 'ddp'
accumulate_grad_batches: int = 1
trainer_log_every_n_steps: int = 50
val_check_interval: int = 5000
save_dir: Optional[str] = 'logs'
basename: Optional[str] = 'lightning_logs'
gradient_clip_val: float = 0.0
limit_val_batches: int = 500
log_every_n_steps: int = 500
max_epochs: int = 3
epochs: int = 3
resume_from_checkpoint: Optional[str] = None
gpus: Optional[int] = -1
monitor: Optional[str] = None
save_last: Optional[bool] = None
save_top_k: int = 1
mode: str = 'min'
every_n_train_steps: Optional[int] = None
every_n_epochs: Optional[int] = None
check_val_every_n_epoch: Optional[int] = 5
auto_lr_find: bool = True
profiler: Optional[str] = 'simple'
learning_rate: float = 0.0001
test_output_path: Optional[str] = './test'
num_workers: int = 0
log_dir: str = './log/'
num_training_steps: int = 1000
validate_every: int = 1000
seed: int = 142857
device: str = 'cpu'
development_mode: bool = False
__annotations__ = {'accumulate_grad_batches': <class 'int'>, 'auto_lr_find': <class 'bool'>, 'basename': typing.Optional[str], 'check_val_every_n_epoch': typing.Optional[int], 'development_mode': <class 'bool'>, 'device': <class 'str'>, 'epochs': <class 'int'>, 'every_n_epochs': 'Optional[int]', 'every_n_train_steps': typing.Optional[int], 'gpus': typing.Optional[int], 'gradient_clip_val': <class 'float'>, 'learning_rate': <class 'float'>, 'limit_val_batches': <class 'int'>, 'log_dir': <class 'str'>, 'log_every_n_steps': <class 'int'>, 'max_epochs': <class 'int'>, 'mode': <class 'str'>, 'monitor': typing.Optional[str], 'num_training_steps': <class 'int'>, 'num_workers': <class 'int'>, 'profiler': typing.Optional[str], 'resume_from_checkpoint': typing.Optional[str], 'save_dir': typing.Optional[str], 'save_last': typing.Optional[bool], 'save_top_k': <class 'int'>, 'seed': <class 'int'>, 'strategy': typing.Optional[str], 'test_output_path': typing.Optional[str], 'trainer_log_every_n_steps': <class 'int'>, 'val_check_interval': <class 'int'>, 'validate_every': <class 'int'>}
__dataclass_fields__ = {'accumulate_grad_batches': Field(name='accumulate_grad_batches',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Accumulates grads every k batches or as set up in the dict.'}),kw_only=False,_field_type=_FIELD), 'auto_lr_find': Field(name='auto_lr_find',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Select whether to run a learning rate finder to try to optimize initial learning for faster convergence.'}),kw_only=False,_field_type=_FIELD), 'basename': Field(name='basename',type=typing.Optional[str],default='lightning_logs',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Experiment name.'}),kw_only=False,_field_type=_FIELD), 'check_val_every_n_epoch': Field(name='check_val_every_n_epoch',type=typing.Optional[int],default=5,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of validation epochs between checkpoints.'}),kw_only=False,_field_type=_FIELD), 'development_mode': Field(name='development_mode',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to run in development mode. 
'}),kw_only=False,_field_type=_FIELD), 'device': Field(name='device',type=<class 'str'>,default='cpu',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The device to use.'}),kw_only=False,_field_type=_FIELD), 'epochs': Field(name='epochs',type=<class 'int'>,default=3,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Stop training once this number of epochs is reached.'}),kw_only=False,_field_type=_FIELD), 'every_n_epochs': Field(name='every_n_epochs',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of epochs between checkpoints.'}),kw_only=False,_field_type=_FIELD), 'every_n_train_steps': Field(name='every_n_train_steps',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of training steps between checkpoints.'}),kw_only=False,_field_type=_FIELD), 'gpus': Field(name='gpus',type=typing.Optional[int],default=-1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of gpus to train on.'}),kw_only=False,_field_type=_FIELD), 'gradient_clip_val': Field(name='gradient_clip_val',type=<class 'float'>,default=0.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Gradient clipping value.'}),kw_only=False,_field_type=_FIELD), 'learning_rate': Field(name='learning_rate',type=<class 'float'>,default=0.0001,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The learning rate.'}),kw_only=False,_field_type=_FIELD), 'limit_val_batches': 
Field(name='limit_val_batches',type=<class 'int'>,default=500,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'How much of validation dataset to check.'}),kw_only=False,_field_type=_FIELD), 'log_dir': Field(name='log_dir',type=<class 'str'>,default='./log/',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The directory to save logs.'}),kw_only=False,_field_type=_FIELD), 'log_every_n_steps': Field(name='log_every_n_steps',type=<class 'int'>,default=500,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'How often to log within steps.'}),kw_only=False,_field_type=_FIELD), 'max_epochs': Field(name='max_epochs',type=<class 'int'>,default=3,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Stop training once this number of epochs is reached.'}),kw_only=False,_field_type=_FIELD), 'mode': Field(name='mode',type=<class 'str'>,default='min',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Quantity to monitor in order to store a checkpoint.'}),kw_only=False,_field_type=_FIELD), 'monitor': Field(name='monitor',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Quantity to monitor in order to store a checkpoint.'}),kw_only=False,_field_type=_FIELD), 'num_training_steps': Field(name='num_training_steps',type=<class 'int'>,default=1000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of training steps.'}),kw_only=False,_field_type=_FIELD), 'num_workers': Field(name='num_workers',type=<class 
'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'number of workers. Defaults to 1.'}),kw_only=False,_field_type=_FIELD), 'profiler': Field(name='profiler',type=typing.Optional[str],default='simple',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'To profile individual steps during training and assist in identifying bottlenecks.'}),kw_only=False,_field_type=_FIELD), 'resume_from_checkpoint': Field(name='resume_from_checkpoint',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path/URL of the checkpoint from which training is resumed.'}),kw_only=False,_field_type=_FIELD), 'save_dir': Field(name='save_dir',type=typing.Optional[str],default='logs',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Save directory for logs and output.'}),kw_only=False,_field_type=_FIELD), 'save_last': Field(name='save_last',type=typing.Optional[bool],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'When True, always saves the model at the end of the epoch to a file last.ckpt'}),kw_only=False,_field_type=_FIELD), 'save_top_k': Field(name='save_top_k',type=<class 'int'>,default=1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The best k models according to the quantity monitored will be saved.'}),kw_only=False,_field_type=_FIELD), 'seed': Field(name='seed',type=<class 'int'>,default=142857,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The random seed.'}),kw_only=False,_field_type=_FIELD), 
'strategy': Field(name='strategy',type=typing.Optional[str],default='ddp',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Training strategy.'}),kw_only=False,_field_type=_FIELD), 'test_output_path': Field(name='test_output_path',type=typing.Optional[str],default='./test',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path where to save latent encodings and predictions for the test set when an epoch ends.'}),kw_only=False,_field_type=_FIELD), 'trainer_log_every_n_steps': Field(name='trainer_log_every_n_steps',type=<class 'int'>,default=50,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'log every k steps.'}),kw_only=False,_field_type=_FIELD), 'val_check_interval': Field(name='val_check_interval',type=<class 'int'>,default=5000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': ' How often to check the validation set.'}),kw_only=False,_field_type=_FIELD), 'validate_every': Field(name='validate_every',type=<class 'int'>,default=1000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of training steps between validation.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = '\n    Arguments related to pytorch lightning trainer.\n    '
__eq__(other)

Return self==value.

__hash__ = None
__init__(strategy='ddp', accumulate_grad_batches=1, val_check_interval=5000, save_dir='logs', basename='lightning_logs', gradient_clip_val=0.0, limit_val_batches=500, log_every_n_steps=500, max_epochs=3, resume_from_checkpoint=None, gpus=-1, monitor=None, save_last=None, save_top_k=1, mode='min', every_n_train_steps=None, every_n_epochs=None, trainer_log_every_n_steps=50, epochs=3, check_val_every_n_epoch=5, auto_lr_find=True, profiler='simple', learning_rate=0.0001, test_output_path='./test', num_workers=0, log_dir='./log/', num_training_steps=1000, validate_every=1000, seed=142857, device='cpu', development_mode=False)
__match_args__ = ('strategy', 'accumulate_grad_batches', 'val_check_interval', 'save_dir', 'basename', 'gradient_clip_val', 'limit_val_batches', 'log_every_n_steps', 'max_epochs', 'resume_from_checkpoint', 'gpus', 'monitor', 'save_last', 'save_top_k', 'mode', 'every_n_train_steps', 'every_n_epochs', 'trainer_log_every_n_steps', 'epochs', 'check_val_every_n_epoch', 'auto_lr_find', 'profiler', 'learning_rate', 'test_output_path', 'num_workers', 'log_dir', 'num_training_steps', 'validate_every', 'seed', 'device', 'development_mode')
__module__ = 'gt4sd.training_pipelines.pytorch_lightning.gflownet.core'
__repr__()

Return repr(self).

class GFlowNetModelArguments(algorithm='trajectory_balance', context=None, environment=None, model='graph_transformer_gfn', sampling_model='graph_transformer_gfn', task='qm9', bootstrap_own_reward=False, num_emb=128, num_layers=4, tb_epsilon=1e-10, illegal_action_logreward=-50.0, reward_loss_multiplier=1.0, temperature_sample_dist='uniform', temperature_dist_params='(.5, 32)', weight_decay=1e-08, momentum=0.9, adam_eps=1e-08, lr_decay=20000, z_lr_decay=20000, clip_grad_type='norm', clip_grad_param=10.0, random_action_prob=0.001, sampling_tau=0.0, max_nodes=9, num_offline=10)[source]

Bases: TrainingPipelineArguments

Arguments related to model.

__name__ = 'GFlowNetModelArguments'
algorithm: str = 'trajectory_balance'
context: str = None
environment: str = None
model: str = 'graph_transformer_gfn'
sampling_model: str = 'graph_transformer_gfn'
task: str = 'qm9'
bootstrap_own_reward: bool = False
num_emb: int = 128
num_layers: int = 4
tb_epsilon: float = 1e-10
illegal_action_logreward: float = -50.0
reward_loss_multiplier: float = 1.0
temperature_sample_dist: str = 'uniform'
temperature_dist_params: str = '(.5, 32)'
weight_decay: float = 1e-08
momentum: float = 0.9
adam_eps: float = 1e-08
lr_decay: float = 20000
z_lr_decay: float = 20000
clip_grad_type: str = 'norm'
clip_grad_param: float = 10.0
random_action_prob: float = 0.001
sampling_tau: float = 0.0
max_nodes: int = 9
num_offline: int = 10
__annotations__ = {'adam_eps': <class 'float'>, 'algorithm': <class 'str'>, 'bootstrap_own_reward': <class 'bool'>, 'clip_grad_param': <class 'float'>, 'clip_grad_type': <class 'str'>, 'context': <class 'str'>, 'environment': <class 'str'>, 'illegal_action_logreward': <class 'float'>, 'lr_decay': <class 'float'>, 'max_nodes': <class 'int'>, 'model': <class 'str'>, 'momentum': <class 'float'>, 'num_emb': <class 'int'>, 'num_layers': <class 'int'>, 'num_offline': <class 'int'>, 'random_action_prob': <class 'float'>, 'reward_loss_multiplier': <class 'float'>, 'sampling_model': <class 'str'>, 'sampling_tau': <class 'float'>, 'task': <class 'str'>, 'tb_epsilon': <class 'float'>, 'temperature_dist_params': <class 'str'>, 'temperature_sample_dist': <class 'str'>, 'weight_decay': <class 'float'>, 'z_lr_decay': <class 'float'>}
__dataclass_fields__ = {'adam_eps': Field(name='adam_eps',type=<class 'float'>,default=1e-08,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The adam epsilon. '}),kw_only=False,_field_type=_FIELD), 'algorithm': Field(name='algorithm',type=<class 'str'>,default='trajectory_balance',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The algorithm to use for training the model. '}),kw_only=False,_field_type=_FIELD), 'bootstrap_own_reward': Field(name='bootstrap_own_reward',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to bootstrap the own reward. '}),kw_only=False,_field_type=_FIELD), 'clip_grad_param': Field(name='clip_grad_param',type=<class 'float'>,default=10.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The clip grad param. '}),kw_only=False,_field_type=_FIELD), 'clip_grad_type': Field(name='clip_grad_type',type=<class 'str'>,default='norm',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The clip grad type. '}),kw_only=False,_field_type=_FIELD), 'context': Field(name='context',type=<class 'str'>,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The environment context to use for training the model. '}),kw_only=False,_field_type=_FIELD), 'environment': Field(name='environment',type=<class 'str'>,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The environment to use for training the model. 
'}),kw_only=False,_field_type=_FIELD), 'illegal_action_logreward': Field(name='illegal_action_logreward',type=<class 'float'>,default=-50.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The illegal action log reward. '}),kw_only=False,_field_type=_FIELD), 'lr_decay': Field(name='lr_decay',type=<class 'float'>,default=20000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The learning rate decay steps. '}),kw_only=False,_field_type=_FIELD), 'max_nodes': Field(name='max_nodes',type=<class 'int'>,default=9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The maximum number of nodes. '}),kw_only=False,_field_type=_FIELD), 'model': Field(name='model',type=<class 'str'>,default='graph_transformer_gfn',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The model to use for training the model. '}),kw_only=False,_field_type=_FIELD), 'momentum': Field(name='momentum',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The momentum. '}),kw_only=False,_field_type=_FIELD), 'num_emb': Field(name='num_emb',type=<class 'int'>,default=128,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of embeddings. '}),kw_only=False,_field_type=_FIELD), 'num_layers': Field(name='num_layers',type=<class 'int'>,default=4,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of layers. 
'}),kw_only=False,_field_type=_FIELD), 'num_offline': Field(name='num_offline',type=<class 'int'>,default=10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of offline samples. '}),kw_only=False,_field_type=_FIELD), 'random_action_prob': Field(name='random_action_prob',type=<class 'float'>,default=0.001,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The random action probability. '}),kw_only=False,_field_type=_FIELD), 'reward_loss_multiplier': Field(name='reward_loss_multiplier',type=<class 'float'>,default=1.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The reward loss multiplier. '}),kw_only=False,_field_type=_FIELD), 'sampling_model': Field(name='sampling_model',type=<class 'str'>,default='graph_transformer_gfn',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The model used to generate samples. '}),kw_only=False,_field_type=_FIELD), 'sampling_tau': Field(name='sampling_tau',type=<class 'float'>,default=0.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The sampling temperature. '}),kw_only=False,_field_type=_FIELD), 'task': Field(name='task',type=<class 'str'>,default='qm9',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The task to use for training the model. '}),kw_only=False,_field_type=_FIELD), 'tb_epsilon': Field(name='tb_epsilon',type=<class 'float'>,default=1e-10,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The epsilon. 
'}),kw_only=False,_field_type=_FIELD), 'temperature_dist_params': Field(name='temperature_dist_params',type=<class 'str'>,default='(.5, 32)',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The temperature distribution parameters. '}),kw_only=False,_field_type=_FIELD), 'temperature_sample_dist': Field(name='temperature_sample_dist',type=<class 'str'>,default='uniform',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The temperature sample distribution. '}),kw_only=False,_field_type=_FIELD), 'weight_decay': Field(name='weight_decay',type=<class 'float'>,default=1e-08,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The weight decay. '}),kw_only=False,_field_type=_FIELD), 'z_lr_decay': Field(name='z_lr_decay',type=<class 'float'>,default=20000,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The learning rate decay steps for z.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = '\n    Arguments related to model.\n    '
__eq__(other)

Return self==value.

__hash__ = None
__init__(algorithm='trajectory_balance', context=None, environment=None, model='graph_transformer_gfn', sampling_model='graph_transformer_gfn', task='qm9', bootstrap_own_reward=False, num_emb=128, num_layers=4, tb_epsilon=1e-10, illegal_action_logreward=-50.0, reward_loss_multiplier=1.0, temperature_sample_dist='uniform', temperature_dist_params='(.5, 32)', weight_decay=1e-08, momentum=0.9, adam_eps=1e-08, lr_decay=20000, z_lr_decay=20000, clip_grad_type='norm', clip_grad_param=10.0, random_action_prob=0.001, sampling_tau=0.0, max_nodes=9, num_offline=10)
__match_args__ = ('algorithm', 'context', 'environment', 'model', 'sampling_model', 'task', 'bootstrap_own_reward', 'num_emb', 'num_layers', 'tb_epsilon', 'illegal_action_logreward', 'reward_loss_multiplier', 'temperature_sample_dist', 'temperature_dist_params', 'weight_decay', 'momentum', 'adam_eps', 'lr_decay', 'z_lr_decay', 'clip_grad_type', 'clip_grad_param', 'random_action_prob', 'sampling_tau', 'max_nodes', 'num_offline')
__module__ = 'gt4sd.training_pipelines.pytorch_lightning.gflownet.core'
__repr__()

Return repr(self).

class GFlowNetDataArguments(dataset='qm9', dataset_path='./data/qm9', batch_size=64, global_batch_size=16, validation_split=None, validation_indices_file=None, stratified_batch_file=None, stratified_value_name=None, num_data_loader_workers=8, sampling_iterator=True, ratio=0.9)[source]

Bases: TrainingPipelineArguments

Arguments related to data.

__name__ = 'GFlowNetDataArguments'
dataset: str = 'qm9'
dataset_path: str = './data/qm9'
batch_size: int = 64
global_batch_size: int = 16
validation_split: Optional[float] = None
validation_indices_file: Optional[str] = None
stratified_batch_file: Optional[str] = None
stratified_value_name: Optional[str] = None
num_data_loader_workers: int = 8
sampling_iterator: bool = True
ratio: float = 0.9
__annotations__ = {'batch_size': <class 'int'>, 'dataset': <class 'str'>, 'dataset_path': <class 'str'>, 'global_batch_size': <class 'int'>, 'num_data_loader_workers': <class 'int'>, 'ratio': <class 'float'>, 'sampling_iterator': <class 'bool'>, 'stratified_batch_file': typing.Optional[str], 'stratified_value_name': typing.Optional[str], 'validation_indices_file': typing.Optional[str], 'validation_split': typing.Optional[float]}
__dataclass_fields__ = {'batch_size': Field(name='batch_size',type=<class 'int'>,default=64,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Batch size of the training. Defaults to 64.'}),kw_only=False,_field_type=_FIELD), 'dataset': Field(name='dataset',type=<class 'str'>,default='qm9',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The dataset to use for training the model. '}),kw_only=False,_field_type=_FIELD), 'dataset_path': Field(name='dataset_path',type=<class 'str'>,default='./data/qm9',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The path to the dataset to use for training the model. '}),kw_only=False,_field_type=_FIELD), 'global_batch_size': Field(name='global_batch_size',type=<class 'int'>,default=16,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Global batch size of the training. Defaults to 16.'}),kw_only=False,_field_type=_FIELD), 'num_data_loader_workers': Field(name='num_data_loader_workers',type=<class 'int'>,default=8,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The number of data loader workers. '}),kw_only=False,_field_type=_FIELD), 'ratio': Field(name='ratio',type=<class 'float'>,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The ratio. '}),kw_only=False,_field_type=_FIELD), 'sampling_iterator': Field(name='sampling_iterator',type=<class 'bool'>,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to use a sampling iterator. 
'}),kw_only=False,_field_type=_FIELD), 'stratified_batch_file': Field(name='stratified_batch_file',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Stratified batch file for sampling. Defaults to None, a.k.a., no stratified sampling.'}),kw_only=False,_field_type=_FIELD), 'stratified_value_name': Field(name='stratified_value_name',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Stratified value name. Defaults to None, a.k.a., no stratified sampling. Needed in case a stratified batch file is provided.'}),kw_only=False,_field_type=_FIELD), 'validation_indices_file': Field(name='validation_indices_file',type=typing.Optional[str],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Indices to use for validation. Defaults to None, a.k.a., use validation split proportion, if not provided uses half of the data for validation.'}),kw_only=False,_field_type=_FIELD), 'validation_split': Field(name='validation_split',type=typing.Optional[float],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Proportion used for validation. Defaults to None, a.k.a., use indices file if provided otherwise uses half of the data for validation.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = '\n    Arguments related to data.\n    '
__eq__(other)

Return self==value.

__hash__ = None
__init__(dataset='qm9', dataset_path='./data/qm9', batch_size=64, global_batch_size=16, validation_split=None, validation_indices_file=None, stratified_batch_file=None, stratified_value_name=None, num_data_loader_workers=8, sampling_iterator=True, ratio=0.9)
__match_args__ = ('dataset', 'dataset_path', 'batch_size', 'global_batch_size', 'validation_split', 'validation_indices_file', 'stratified_batch_file', 'stratified_value_name', 'num_data_loader_workers', 'sampling_iterator', 'ratio')
__module__ = 'gt4sd.training_pipelines.pytorch_lightning.gflownet.core'
__repr__()

Return repr(self).

class GFlowNetSavingArguments(model_path)[source]

Bases: TrainingPipelineArguments

Saving arguments related to GFlowNet trainer.

__name__ = 'GFlowNetSavingArguments'
model_path: str
__annotations__ = {'model_path': <class 'str'>}
__dataclass_fields__ = {'model_path': Field(name='model_path',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to the checkpoint file to be used.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Saving arguments related to Granular trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(model_path)
__match_args__ = ('model_path',)
__module__ = 'gt4sd.training_pipelines.pytorch_lightning.gflownet.core'
__repr__()

Return repr(self).