gt4sd.frameworks.gflownet.loss.trajectory_balance module

Summary

Classes:

TrajectoryBalance

A trajectory balance algorithm for gflownet.

TrajectoryBalanceModel

Generic model compatible with trajectory balance.

Reference

class TrajectoryBalanceModel[source]

Bases: Module

Generic model compatible with trajectory balance.

forward(batch)[source]

Run forward pass.

Parameters

batch (Batch) – batch of data

Returns

action to take

Return type

action

log_z(cond_info)[source]

Compute log_z.

Parameters

cond_info (Tensor) – conditional information

Return type

Tensor

Returns

log partition function.
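In trajectory balance, log_z estimates the log of the partition function Z, the total reward flow out of the initial state. In a trained GFlowNet this quantity is predicted by the model from cond_info; for a toy, fully enumerable state space it can be computed exactly. A minimal pure-Python sketch (function name is hypothetical, not part of the library):

```python
import math

def log_partition(terminal_rewards):
    """Log of the partition function Z = sum of terminal rewards R(x).

    The real log_z(cond_info) is a learned prediction; here we compute
    the exact value for a toy set of enumerable terminal states.
    """
    return math.log(sum(terminal_rewards))

# Three terminal states with rewards R(x); Z = 1.0 + 2.0 + 1.0 = 4.0
log_z = log_partition([1.0, 2.0, 1.0])
```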

class TrajectoryBalance(configuration, environment, context, max_len=None)[source]

Bases: object

A trajectory balance algorithm for gflownet.

__init__(configuration, environment, context, max_len=None)[source]

Initialize trajectory balance algorithm.

Trajectory balance implementation, see “Trajectory Balance: Improved Credit Assignment in GFlowNets” by Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio (https://arxiv.org/abs/2201.13259).

Code adapted from: https://github.com/recursionpharma/gflownet/blob/trunk/src/gflownet/algo/trajectory_balance.py.

Parameters
  • configuration – hyperparameter configuration. The following keys are used:

    • random_action_prob: float, probability of taking a uniform random action when sampling.

    • illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions.

    • bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data.

    • tb_epsilon: float, if not None, adds this epsilon to the numerator and denominator of the log-ratio.

    • reward_loss_multiplier: float, multiplying constant for the bootstrap loss.

  • environment – a graph environment.

  • context – a context.

  • max_len – if not None, ends trajectories with more than max_len steps.

  • max_nodes – if not None, ends trajectories of graphs with more than max_nodes nodes (treated as an illegal action).
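The hyperparameter keys listed above could be collected in a plain dictionary; the values below are illustrative placeholders, not defaults taken from the library:

```python
# Hypothetical hyperparameter dictionary for TrajectoryBalance;
# the keys mirror the ones documented above, the values are illustrative.
hps = {
    "random_action_prob": 0.01,         # chance of a uniform random action
    "illegal_action_logreward": -75.0,  # log(R) for illegal/non-sane end states
    "bootstrap_own_reward": False,      # predict rewards for sampled data
    "tb_epsilon": None,                 # epsilon added to the log-ratio
    "reward_loss_multiplier": 1.0,      # weight of the bootstrap reward loss
}
```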

_corrupt_actions(actions, cat)[source]

Sample from the uniform policy with probability random_action_prob.

Parameters
  • actions (List[Tuple[int, int, int]]) – list of actions.

  • cat (GraphActionCategorical) – action categorical.
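The corruption step mixes the policy with a uniform distribution: each sampled action is replaced by a uniform random one with probability random_action_prob. A pure-Python sketch over integer action indices (the real method operates on (type, row, col) triples drawn from a GraphActionCategorical; the function below is hypothetical):

```python
import random

def corrupt_actions(actions, n_actions, random_action_prob, rng):
    """Replace each action with a uniform random one w.p. random_action_prob.

    actions: list of integer action indices.
    n_actions: size of the (flattened) action space.
    rng: random.Random instance used to take random actions.
    """
    return [
        rng.randrange(n_actions) if rng.random() < random_action_prob else a
        for a in actions
    ]

rng = random.Random(0)
# With probability 0.0 the actions pass through unchanged ...
same = corrupt_actions([3, 1, 2], n_actions=5, random_action_prob=0.0, rng=rng)
# ... with probability 1.0 every action is resampled uniformly.
mixed = corrupt_actions([3, 1, 2], n_actions=5, random_action_prob=1.0, rng=rng)
```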

create_training_data_from_own_samples(model, n, cond_info)[source]

Generate trajectories by sampling a model.

Parameters
  • model (Union[Module, TrajectoryBalanceModel]) – the model being sampled, used with a given algorithm (i.e. trajectory balance).

  • n (int) – number of trajectories to sample.

  • cond_info (Tensor) – conditional information, shape (N, n_info).

Returns

a list of trajectories. Each trajectory is a dict with the following keys:
  • trajs: List[Tuple[Graph, GraphAction]].

  • reward_pred: float, -100 if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise.

  • fwd_logprob: log_z + sum logprobs P_F.

  • bck_logprob: sum logprobs P_B.

  • log_z: predicted log_z.

  • loss: predicted loss (if bootstrapping).

  • is_valid: is the generated graph valid according to the environment and context.

Return type

data
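The fwd_logprob and bck_logprob entries above are plain sums of per-step log-probabilities along one trajectory, with log_z added on the forward side. A toy sketch with made-up step probabilities (the helper name is hypothetical; real values come from the model's forward and backward policies):

```python
import math

def trajectory_logprobs(log_z, fwd_steps, bck_steps):
    """Accumulate log-probabilities along a single trajectory.

    fwd_steps / bck_steps: per-step log-probabilities under the forward
    policy P_F and backward policy P_B (toy numbers, not model outputs).
    """
    fwd_logprob = log_z + sum(fwd_steps)   # log_z + sum log P_F
    bck_logprob = sum(bck_steps)           # sum log P_B
    return fwd_logprob, bck_logprob

fwd, bck = trajectory_logprobs(
    log_z=1.5,
    fwd_steps=[math.log(0.5), math.log(0.25)],
    bck_steps=[math.log(1.0), math.log(0.5)],
)
```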

create_training_data_from_graphs(graphs)[source]

Generate trajectories from known endpoints.

Parameters

graphs (List[Graph]) – list of Graph endpoints.

Return type

List[Dict[str, List[Tuple[Graph, GraphAction]]]]

Returns

a list of trajectories.

construct_batch(trajs, cond_info, rewards)[source]

Construct a batch from a list of trajectories and their information.

Parameters
  • trajs (List[Dict[str, List[Tuple[Graph, GraphAction]]]]) – a list of N trajectories.

  • cond_info (Tensor) – the conditional info considered for each trajectory. Shape (N, n_info).

  • rewards (Tensor) – the transformed reward (e.g. R(x) ** beta) for each trajectory. Shape (N,).

Return type

Batch

Returns

a (CPU) Batch object with relevant attributes added.
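The rewards passed to this method are already transformed, e.g. R(x) ** beta, where beta acts as an inverse temperature sharpening the reward landscape. A sketch of that transform on plain floats (the function name is hypothetical):

```python
def transform_rewards(raw_rewards, beta):
    """Apply the exponential temperature transform R(x) ** beta.

    Larger beta concentrates probability mass on high-reward objects;
    beta is typically carried in cond_info for conditional sampling.
    """
    return [r ** beta for r in raw_rewards]

transformed = transform_rewards([1.0, 2.0, 4.0], beta=2.0)
```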

compute_batch_losses(model, batch, num_bootstrap=0)[source]

Compute the losses over trajectories contained in the batch.

Parameters
  • model (TrajectoryBalanceModel) – a GNN taking as input a batch of graphs as constructed by self.construct_batch. Must have a log_z attribute, itself a model, which predicts the log of Z(cond_info).

  • batch (Batch) – batch of graph inputs as constructed by self.construct_batch.

  • num_bootstrap (int) – the number of trajectories for which the reward loss is computed. Ignored if 0.

Return type

Tuple[Tensor, Dict]

Returns

a tuple containing the loss for each trajectory and relevant info.
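At its core, the trajectory balance loss for one trajectory is the squared difference between the forward and backward log-flows. A pure-Python sketch of the per-trajectory objective from Malkin et al. (the actual implementation operates on batched tensors and additionally supports tb_epsilon smoothing and the bootstrap reward loss):

```python
def trajectory_balance_loss(log_z, sum_log_pf, sum_log_pb, log_reward):
    """Squared log-ratio loss:

        L = (log Z + sum log P_F - sum log P_B - log R(x)) ** 2
    """
    return (log_z + sum_log_pf - sum_log_pb - log_reward) ** 2

# A trajectory whose forward and backward flows already balance has zero loss.
loss = trajectory_balance_loss(
    log_z=2.0, sum_log_pf=-1.0, sum_log_pb=-0.5, log_reward=1.5
)
```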
