gt4sd.frameworks.gflownet.loss.trajectory_balance module¶
Summary¶
Classes:
TrajectoryBalance: A trajectory balance algorithm for gflownet.
TrajectoryBalanceModel: Generic model compatible with trajectory balance.
Reference¶
- class TrajectoryBalanceModel[source]¶
Bases:
Module
Generic model compatible with trajectory balance.
- forward(batch)[source]¶
Run forward pass.
- Parameters
batch (Batch) – batch of data.
- Returns
action to take.
- Return type
action
- log_z(cond_info)[source]¶
Compute log_z.
- Parameters
cond_info (Tensor) – conditional information.
- Return type
Tensor
- Returns
log partition function.
- __annotations__ = {}¶
- __doc__ = 'Generic model compatible with trajectory balance.'¶
- __module__ = 'gt4sd.frameworks.gflownet.loss.trajectory_balance'¶
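The abstract interface above only fixes the forward and log_z entry points. Below is a minimal sketch of a concrete subclass; the MLP heads, the feature and conditioning dimensions, and the use of a torch_geometric Batch are illustrative assumptions, not the model shipped with gt4sd.

```python
from torch import nn, Tensor
from torch_geometric.data import Batch

from gt4sd.frameworks.gflownet.loss.trajectory_balance import TrajectoryBalanceModel


class ToyTBModel(TrajectoryBalanceModel):
    """Hypothetical model: a per-node action head plus a conditioned log_z head."""

    def __init__(self, num_node_features: int = 16, num_actions: int = 8, cond_dim: int = 4):
        super().__init__()
        # Per-node action logits (a real model would return a GraphActionCategorical).
        self.action_head = nn.Sequential(
            nn.Linear(num_node_features, 64), nn.ReLU(), nn.Linear(64, num_actions)
        )
        # Head predicting the log partition function from the conditional information.
        self.log_z_head = nn.Sequential(nn.Linear(cond_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, batch: Batch):
        # Action scores for every node in the batched graphs.
        return self.action_head(batch.x)

    def log_z(self, cond_info: Tensor) -> Tensor:
        # Predicted log Z, one value per conditioning vector, shape (N, 1).
        return self.log_z_head(cond_info)
```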
- class TrajectoryBalance(configuration, environment, context, max_len=None)[source]¶
Bases:
object
A trajectory balance algorithm for gflownet.
- __init__(configuration, environment, context, max_len=None)[source]¶
Initialize trajectory balance algorithm.
Trajectory balance implementation, see “Trajectory Balance: Improved Credit Assignment in GFlowNets” by Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio, https://arxiv.org/abs/2201.13259.
Code adapted from: https://github.com/recursionpharma/gflownet/blob/trunk/src/gflownet/algo/trajectory_balance.py.
- Args
configuration: hyperparameters.
environment: a graph environment.
context: a context.
rng: rng used to take random actions.
hps: hyperparameter dictionary, see above for used keys.
random_action_prob: float, probability of taking a uniform random action when sampling.
illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions.
bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data.
tb_epsilon: float, if not None, adds this epsilon in the numerator and denominator of the log-ratio.
reward_loss_multiplier: float, multiplying constant for the bootstrap loss.
max_len: if not None, ends trajectories of more than max_len steps.
max_nodes: if not None, ends trajectories of graphs with more than max_nodes nodes (illegal action).
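A construction sketch follows. The environment and context imports mirror the upstream recursionpharma/gflownet layout, and the configuration is shown as a plain dictionary of the hyperparameters listed above; all of these are assumptions about how a given gt4sd version wires the pieces together.

```python
from gt4sd.frameworks.gflownet.loss.trajectory_balance import TrajectoryBalance

# Assumed import locations, following the upstream recursionpharma/gflownet layout.
from gt4sd.frameworks.gflownet.envs.graph_building_env import GraphBuildingEnv
from gt4sd.frameworks.gflownet.envs.mol_building_env import MolBuildingEnvContext

# Hypothetical hyperparameters; the concrete configuration object may be a
# dataclass or namespace rather than a dict in your installation.
configuration = {
    "random_action_prob": 0.01,       # chance of a uniform random action while sampling
    "illegal_action_logreward": -75,  # log(R) assigned to invalid or non-sane end states
    "bootstrap_own_reward": False,    # do not bootstrap rewards from the model
    "tb_epsilon": None,               # no epsilon smoothing of the log-ratio
    "reward_loss_multiplier": 1.0,    # weight of the bootstrap reward loss
}

environment = GraphBuildingEnv()
context = MolBuildingEnvContext()

algo = TrajectoryBalance(configuration, environment, context, max_len=128)
```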
- _corrupt_actions(actions, cat)[source]¶
Sample from the uniform policy with probability random_action_prob.
- Parameters
actions (List[Tuple[int, int, int]]) – list of actions.
cat (GraphActionCategorical) – action categorical.
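The actual method mutates the sampled actions in place using the GraphActionCategorical; the standalone snippet below only illustrates the underlying idea of epsilon-uniform exploration and does not use the gt4sd types.

```python
import torch


def epsilon_uniform(sampled_actions: torch.Tensor, num_actions: int, random_action_prob: float) -> torch.Tensor:
    """Replace each sampled action with a uniformly random one with probability random_action_prob."""
    uniform_actions = torch.randint(num_actions, sampled_actions.shape)
    use_uniform = torch.rand(sampled_actions.shape) < random_action_prob
    return torch.where(use_uniform, uniform_actions, sampled_actions)
```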
- create_training_data_from_own_samples(model, n, cond_info)[source]¶
Generate trajectories by sampling a model.
- Parameters
model (Union[Module, TrajectoryBalanceModel]) – model used with a certain algorithm (i.e. trajectory balance). The model being sampled.
graphs – list of N Graph endpoints.
cond_info (Tensor) – conditional information, shape (N, n_info).
- Returns
a list of trajectories, where each trajectory is a dict with the following keys:
trajs: List[Tuple[Graph, GraphAction]].
reward_pred: float, -100 if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise.
fwd_logprob: log_z + sum logprobs P_F.
bck_logprob: sum logprobs P_B.
log_z: predicted log_z.
loss: predicted loss (if bootstrapping).
is_valid: is the generated graph valid according to the environment and context.
- Return type
data
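Continuing the construction sketch above, sampling on-policy training data could look as follows; the batch size and the conditional-information dimensionality are placeholder values, and the key names follow the Returns description above.

```python
import torch

n = 16                         # number of trajectories to sample
cond_info = torch.zeros(n, 4)  # placeholder conditional information, shape (N, n_info)

# `algo` is the TrajectoryBalance instance from the earlier sketch;
# ToyTBModel is the hypothetical subclass shown above.
model = ToyTBModel(cond_dim=4)

trajectories = algo.create_training_data_from_own_samples(model, n, cond_info)
print(trajectories[0]["fwd_logprob"], trajectories[0]["is_valid"])
```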
- create_training_data_from_graphs(graphs)[source]¶
Generate trajectories from known endpoints.
- Parameters
graphs (List[Graph]) – list of Graph endpoints.
- Return type
List[Dict[str, List[Tuple[Graph, GraphAction]]]]
- Returns
a list of trajectories.
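For offline training, the same trajectory dictionaries can be produced from known terminal graphs instead of model samples; where the Graph endpoints come from (for example, dataset molecules encoded through the context) is an assumption left open here.

```python
# `graphs` is assumed to be a list of Graph endpoints, e.g. obtained by
# encoding dataset molecules through the environment context.
offline_trajectories = algo.create_training_data_from_graphs(graphs)
print(len(offline_trajectories), "offline trajectories")
```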
- construct_batch(trajs, cond_info, rewards)[source]¶
Construct a batch from a list of trajectories and their information.
- Parameters
trajs (List[Dict[str, List[Tuple[Graph, GraphAction]]]]) – a list of N trajectories.
cond_info (float) – the conditional info that is considered for each trajectory. Shape (N, n_info).
rewards (float) – the transformed reward (e.g. R(x) ** beta) for each trajectory. Shape (N,).
- Return type
Batch
- Returns
a (CPU) Batch object with relevant attributes added.
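The sampled trajectories, their conditional information, and their already transformed rewards are then collated into a single Batch. A sketch continuing from the sampling example above; the reward values are placeholders.

```python
import torch

# Placeholder transformed rewards, one per trajectory (e.g. R(x) ** beta).
rewards = torch.ones(len(trajectories))

# `trajectories` and `cond_info` come from the sampling sketch above.
batch = algo.construct_batch(trajectories, cond_info, rewards)
```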
- compute_batch_losses(model, batch, num_bootstrap=0)[source]¶
Compute the losses over trajectories contained in the batch.
- Parameters
model (TrajectoryBalanceModel) – a GNN taking a batch of graphs as input, as constructed by self.construct_batch. Must have a log_z attribute (itself a model) that predicts the log of z(cond_info).
batch (Batch) – batch of graph inputs, as constructed by self.construct_batch.
num_bootstrap (int) – the number of trajectories for which the reward loss is computed. Ignored if 0.
- Return type
Tuple[Tensor, Dict]
- Returns
a tuple containing the loss for each trajectory and relevant info.
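Putting the pieces together, one hypothetical training step computes the trajectory balance loss on the collated batch and backpropagates it through the sampling model; the optimizer choice and learning rate below are assumptions.

```python
import torch

# `model`, `algo`, and `batch` come from the earlier sketches.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

loss, info = algo.compute_batch_losses(model, batch, num_bootstrap=0)

optimizer.zero_grad()
loss.mean().backward()  # reduce in case the loss is reported per trajectory
optimizer.step()
```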
- __dict__ = mappingproxy({'__module__': 'gt4sd.frameworks.gflownet.loss.trajectory_balance', '__doc__': 'A trajectory balance algorithm for gflownet.', '__init__': <function TrajectoryBalance.__init__>, '_corrupt_actions': <function TrajectoryBalance._corrupt_actions>, 'create_training_data_from_own_samples': <function TrajectoryBalance.create_training_data_from_own_samples>, 'create_training_data_from_graphs': <function TrajectoryBalance.create_training_data_from_graphs>, 'construct_batch': <function TrajectoryBalance.construct_batch>, 'compute_batch_losses': <function TrajectoryBalance.compute_batch_losses>, '__dict__': <attribute '__dict__' of 'TrajectoryBalance' objects>, '__weakref__': <attribute '__weakref__' of 'TrajectoryBalance' objects>, '__annotations__': {}})¶
- __doc__ = 'A trajectory balance algorithm for gflownet.'¶
- __module__ = 'gt4sd.frameworks.gflownet.loss.trajectory_balance'¶
- __weakref__¶
list of weak references to the object (if defined)