gt4sd.training_pipelines.pytorch_lightning.gflownet.core module¶
GFlowNet training utilities.
Summary¶
Classes:
GFlowNetDataArguments – Arguments related to data.
GFlowNetModelArguments – Arguments related to model.
GFlowNetPytorchLightningTrainingArguments – Arguments related to the PyTorch Lightning trainer.
GFlowNetSavingArguments – Saving arguments related to the GFlowNet trainer.
GFlowNetTrainingPipeline – GFlowNet training pipeline.
Reference¶
- class GFlowNetTrainingPipeline[source]¶
Bases:
PyTorchLightningTrainingPipeline
GFlowNet training pipeline.
- train(pl_trainer_args, model_args, dataset_args, dataset, environment, context, task)[source]¶
Generic training function for PyTorch Lightning-based training.
- Parameters
pl_trainer_args (Dict[str, Any]) – PyTorch Lightning trainer arguments passed to the configuration.
model_args (Dict[str, Union[float, str, int]]) – model arguments passed to the configuration.
dataset_args (Dict[str, Union[float, str, int]]) – dataset arguments passed to the configuration.
dataset (GFlowNetDataset) – dataset to be used for training.
environment (GraphBuildingEnv) – environment to be used for training.
context (GraphBuildingEnvContext) – context to be used for training.
task (GFlowNetTask) – task to be used for training.
- Return type
None
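The three configuration arguments are plain dictionaries. A minimal sketch of assembling them before launching training (key names follow the argument dataclasses documented below; the pipeline call itself is shown only as a comment, since it requires a configured dataset, environment, context, and task):

```python
from typing import Any, Dict, Union

# Plain dictionaries matching the documented parameter types; keys mirror
# fields of the argument dataclasses documented in this module.
pl_trainer_args: Dict[str, Any] = {"max_epochs": 3, "gpus": -1, "strategy": "ddp"}
model_args: Dict[str, Union[float, str, int]] = {
    "algorithm": "trajectory_balance",
    "num_layers": 4,
}
dataset_args: Dict[str, Union[float, str, int]] = {"dataset": "qm9", "batch_size": 64}

# With a pipeline instance and the dataset/environment/context/task objects
# in hand, training would be launched as (not run here):
# pipeline = GFlowNetTrainingPipeline()
# pipeline.train(pl_trainer_args, model_args, dataset_args,
#                dataset, environment, context, task)
```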
- get_data_and_model_modules(model_args, dataset_args, pl_trainer_args, dataset, environment, context, task)[source]¶
Get data and model modules for training.
- Parameters
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
Tuple[LightningDataModule, LightningModule]
- Returns
the data and model modules.
- class GFlowNetPytorchLightningTrainingArguments(strategy='ddp', accumulate_grad_batches=1, val_check_interval=5000, save_dir='logs', basename='lightning_logs', gradient_clip_val=0.0, limit_val_batches=500, log_every_n_steps=500, max_epochs=3, resume_from_checkpoint=None, gpus=-1, monitor=None, save_last=None, save_top_k=1, mode='min', every_n_train_steps=None, every_n_epochs=None, trainer_log_every_n_steps=50, epochs=3, check_val_every_n_epoch=5, auto_lr_find=True, profiler='simple', learning_rate=0.0001, test_output_path='./test', num_workers=0, log_dir='./log/', num_training_steps=1000, validate_every=1000, seed=142857, device='cpu', development_mode=False)[source]¶
Bases:
PytorchLightningTrainingArguments
Arguments related to the PyTorch Lightning trainer.
- strategy: Optional[str] = 'ddp'¶
Training strategy.
- accumulate_grad_batches: int = 1¶
Accumulates gradients every k batches or as set up in the dict.
- trainer_log_every_n_steps: int = 50¶
Log every k steps.
- val_check_interval: int = 5000¶
How often to check the validation set.
- save_dir: Optional[str] = 'logs'¶
Save directory for logs and output.
- basename: Optional[str] = 'lightning_logs'¶
Experiment name.
- gradient_clip_val: float = 0.0¶
Gradient clipping value.
- limit_val_batches: int = 500¶
How much of the validation dataset to check.
- log_every_n_steps: int = 500¶
How often to log within steps.
- max_epochs: int = 3¶
Stop training once this number of epochs is reached.
- epochs: int = 3¶
Stop training once this number of epochs is reached.
- resume_from_checkpoint: Optional[str] = None¶
Path/URL of the checkpoint from which training is resumed.
- gpus: Optional[int] = -1¶
Number of GPUs to train on.
- monitor: Optional[str] = None¶
Quantity to monitor in order to store a checkpoint.
- save_last: Optional[bool] = None¶
When True, always saves the model at the end of the epoch to a file last.ckpt.
- save_top_k: int = 1¶
The best k models according to the quantity monitored will be saved.
- mode: str = 'min'¶
Whether the monitored quantity should be minimized ('min') or maximized ('max').
- every_n_train_steps: Optional[int] = None¶
Number of training steps between checkpoints.
- every_n_epochs: Optional[int] = None¶
Number of epochs between checkpoints.
- check_val_every_n_epoch: Optional[int] = 5¶
Number of epochs between validation runs.
- auto_lr_find: bool = True¶
Whether to run a learning rate finder to try to optimize the initial learning rate for faster convergence.
- profiler: Optional[str] = 'simple'¶
Profile individual steps during training to assist in identifying bottlenecks.
- learning_rate: float = 0.0001¶
The learning rate.
- test_output_path: Optional[str] = './test'¶
Path where to save latent encodings and predictions for the test set when an epoch ends.
- num_workers: int = 0¶
Number of workers.
- log_dir: str = './log/'¶
The directory to save logs.
- num_training_steps: int = 1000¶
The number of training steps.
- validate_every: int = 1000¶
The number of training steps between validations.
- seed: int = 142857¶
The random seed.
- device: str = 'cpu'¶
The device to use.
- development_mode: bool = False¶
Whether to run in development mode.
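Since the argument classes are dataclasses, a configured instance can be turned into the plain dictionary that train() expects with dataclasses.asdict. A sketch using a hypothetical stand-in that holds only a few of the documented fields (the real class inherits from PytorchLightningTrainingArguments and has many more):

```python
from dataclasses import dataclass, asdict
from typing import Optional

# Hypothetical stand-in with a subset of the documented fields,
# using the same names and defaults as above.
@dataclass
class TrainerArgsSketch:
    strategy: Optional[str] = "ddp"
    max_epochs: int = 3
    learning_rate: float = 0.0001
    seed: int = 142857

args = TrainerArgsSketch(max_epochs=10)  # override one default
pl_trainer_args = asdict(args)           # plain dict, as train() expects
```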
- class GFlowNetModelArguments(algorithm='trajectory_balance', context=None, environment=None, model='graph_transformer_gfn', sampling_model='graph_transformer_gfn', task='qm9', bootstrap_own_reward=False, num_emb=128, num_layers=4, tb_epsilon=1e-10, illegal_action_logreward=-50.0, reward_loss_multiplier=1.0, temperature_sample_dist='uniform', temperature_dist_params='(.5, 32)', weight_decay=1e-08, momentum=0.9, adam_eps=1e-08, lr_decay=20000, z_lr_decay=20000, clip_grad_type='norm', clip_grad_param=10.0, random_action_prob=0.001, sampling_tau=0.0, max_nodes=9, num_offline=10)[source]¶
Bases:
TrainingPipelineArguments
Arguments related to model.
- algorithm: str = 'trajectory_balance'¶
The algorithm to use for training the model.
- context: str = None¶
The environment context to use for training the model.
- environment: str = None¶
The environment to use for training the model.
- model: str = 'graph_transformer_gfn'¶
The model to train.
- sampling_model: str = 'graph_transformer_gfn'¶
The model used to generate samples.
- task: str = 'qm9'¶
The task to use for training the model.
- bootstrap_own_reward: bool = False¶
Whether to bootstrap the model's own reward.
- num_emb: int = 128¶
The number of embeddings.
- num_layers: int = 4¶
The number of layers.
- tb_epsilon: float = 1e-10¶
The trajectory-balance epsilon.
- illegal_action_logreward: float = -50.0¶
The log-reward assigned to illegal actions.
- reward_loss_multiplier: float = 1.0¶
The reward loss multiplier.
- temperature_sample_dist: str = 'uniform'¶
The temperature sample distribution.
- temperature_dist_params: str = '(.5, 32)'¶
The temperature distribution parameters.
- weight_decay: float = 1e-08¶
The weight decay.
- momentum: float = 0.9¶
The momentum.
- adam_eps: float = 1e-08¶
The Adam epsilon.
- lr_decay: float = 20000¶
The learning rate decay steps.
- z_lr_decay: float = 20000¶
The learning rate decay steps for Z.
- clip_grad_type: str = 'norm'¶
The gradient clipping type.
- clip_grad_param: float = 10.0¶
The gradient clipping parameter.
- random_action_prob: float = 0.001¶
The random action probability.
- sampling_tau: float = 0.0¶
The sampling temperature.
- max_nodes: int = 9¶
The maximum number of nodes.
- num_offline: int = 10¶
The number of offline samples.
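Note that temperature_dist_params is stored as a string literal, not a tuple. One plausible way to recover the numeric values (an assumption for illustration; the exact parsing the pipeline performs is not shown here) is ast.literal_eval:

```python
import ast

# Recover the (low, high) pair from the default string value '(.5, 32)':
params = ast.literal_eval("(.5, 32)")
low, high = params  # 0.5 and 32
```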
- class GFlowNetDataArguments(dataset='qm9', dataset_path='./data/qm9', batch_size=64, global_batch_size=16, validation_split=None, validation_indices_file=None, stratified_batch_file=None, stratified_value_name=None, num_data_loader_workers=8, sampling_iterator=True, ratio=0.9)[source]¶
Bases:
TrainingPipelineArguments
Arguments related to data.
- dataset: str = 'qm9'¶
The dataset to use for training the model.
- dataset_path: str = './data/qm9'¶
The path to the dataset to use for training the model.
- batch_size: int = 64¶
Batch size of the training. Defaults to 64.
- global_batch_size: int = 16¶
Global batch size of the training. Defaults to 16.
- validation_split: Optional[float] = None¶
Proportion used for validation. Defaults to None, i.e., use the indices file if provided, otherwise use half of the data for validation.
- validation_indices_file: Optional[str] = None¶
Indices to use for validation. Defaults to None, i.e., use the validation split proportion; if that is not provided either, half of the data is used for validation.
- stratified_batch_file: Optional[str] = None¶
Stratified batch file for sampling. Defaults to None, i.e., no stratified sampling.
- stratified_value_name: Optional[str] = None¶
Stratified value name. Defaults to None, i.e., no stratified sampling. Needed in case a stratified batch file is provided.
- num_data_loader_workers: int = 8¶
The number of data loader workers.
- sampling_iterator: bool = True¶
Whether to use a sampling iterator.
- ratio: float = 0.9¶
The split ratio.
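A sketch of how ratio could drive a train/validation split (assuming ratio is the fraction of samples assigned to training; split_indices is a hypothetical helper, not part of this API):

```python
from typing import List, Tuple

def split_indices(n_samples: int, ratio: float) -> Tuple[List[int], List[int]]:
    """Split sample indices into a training and a validation partition."""
    cut = int(n_samples * ratio)
    indices = list(range(n_samples))
    return indices[:cut], indices[cut:]

train_idx, val_idx = split_indices(100, 0.9)  # 90 training, 10 validation indices
```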
- class GFlowNetSavingArguments(model_path)[source]¶
Bases:
TrainingPipelineArguments
Saving arguments related to the GFlowNet trainer.
- model_path: str¶
Path to the checkpoint file to be used.