gt4sd.frameworks.gflownet.loss.trajectory_balance module¶
Summary¶
Classes:
TrajectoryBalance – A trajectory balance algorithm for gflownet.
TrajectoryBalanceModel – Generic model compatible with trajectory balance.
Reference¶
- class TrajectoryBalanceModel[source]¶
 Bases: Module
 Generic model compatible with trajectory balance.
- forward(batch)[source]¶
 Run forward pass.
- Parameters
 batch (Batch) – batch of data.
- Returns
 action to take.
- Return type
 action
- log_z(cond_info)[source]¶
 Compute log_z.
- Parameters
 cond_info (Tensor) – conditional information.
- Return type
 Tensor
- Returns
 log partition function.
- class TrajectoryBalance(configuration, environment, context, max_len=None)[source]¶
 Bases: object
 A trajectory balance algorithm for gflownet.
- __init__(configuration, environment, context, max_len=None)[source]¶
 Initialize trajectory balance algorithm.
Trajectory balance implementation, see "Trajectory Balance: Improved Credit Assignment in GFlowNets"
by Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio, https://arxiv.org/abs/2201.13259.
Code adapted from: https://github.com/recursionpharma/gflownet/blob/trunk/src/gflownet/algo/trajectory_balance.py.
- Args
 configuration: hyperparameter dictionary; the keys listed below are used.
 environment: a graph environment.
 context: a context.
 rng: random number generator used to take random actions.
 random_action_prob: float, probability of taking a uniform random action when sampling.
 illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions.
 bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data.
 tb_epsilon: float, if not None, adds this epsilon to the numerator and denominator of the log-ratio.
 reward_loss_multiplier: float, multiplying constant for the bootstrap loss.
 max_len: if not None, ends trajectories with more than max_len steps.
 max_nodes: if not None, ends trajectories of graphs with more than max_nodes nodes (illegal action).
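Example (illustrative): instantiating the algorithm. The configuration, environment and context objects below are placeholders for instances obtained elsewhere in the gt4sd gflownet framework; only the constructor signature comes from this module.

from gt4sd.frameworks.gflownet.loss.trajectory_balance import TrajectoryBalance

# `configuration`, `environment` and `context` are assumed to exist already
# (hyperparameter dictionary, graph environment and context, respectively).
algorithm = TrajectoryBalance(
    configuration=configuration,
    environment=environment,
    context=context,
    max_len=128,  # optional: end trajectories longer than 128 steps
)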
- _corrupt_actions(actions, cat)[source]¶
 Sample from the uniform policy with probability random_action_prob.
- Parameters
 actions (List[Tuple[int, int, int]]) – list of actions.
 cat (GraphActionCategorical) – action categorical.
- create_training_data_from_own_samples(model, n, cond_info)[source]¶
 Generate trajectories by sampling a model.
- Parameters
 model (Union[Module, TrajectoryBalanceModel]) – the model being sampled, used with a certain algorithm (i.e. trajectory balance).
 n – number of trajectories to sample.
 cond_info (Tensor) – conditional information, shape (N, n_info).
- Returns
 a list of trajectories; each trajectory is a dict with the following keys:
 trajs: List[Tuple[Graph, GraphAction]].
 reward_pred: float, -100 if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise.
 fwd_logprob: log_z + sum of logprobs P_F.
 bck_logprob: sum of logprobs P_B.
 log_z: predicted log_z.
 loss: predicted loss (if bootstrapping).
 is_valid: whether the generated graph is valid according to the environment and context.
- Return type
 data
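Example (illustrative): sampling trajectories with create_training_data_from_own_samples and reading the documented keys. The algorithm and model objects are reused from the sketches above; shapes and values are arbitrary.

import torch

# `algorithm` is a TrajectoryBalance instance, `model` follows the
# TrajectoryBalanceModel interface (see the sketches above).
cond_info = torch.zeros(8, 4)  # (N, n_info); shapes are arbitrary here
data = algorithm.create_training_data_from_own_samples(model, 8, cond_info)

# Each entry is a dict with the keys documented above.
for trajectory in data:
    if trajectory["is_valid"]:
        print(len(trajectory["trajs"]), trajectory["fwd_logprob"], trajectory["bck_logprob"])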
- create_training_data_from_graphs(graphs)[source]¶
 Generate trajectories from known endpoints.
- Parameters
 graphs (List[Graph]) – list of Graph endpoints.
- Return type
 List[Dict[str, List[Tuple[Graph, GraphAction]]]]
- Returns
 a list of trajectories.
- construct_batch(trajs, cond_info, rewards)[source]¶
 Construct a batch from a list of trajectories and their information.
- Parameters
 trajs (List[Dict[str, List[Tuple[Graph, GraphAction]]]]) – a list of N trajectories.
 cond_info (float) – the conditional info that is considered for each trajectory. Shape (N, n_info).
 rewards (float) – the transformed reward (e.g. R(x) ** beta) for each trajectory. Shape (N,).
- Return type
 Batch
- Returns
 a (CPU) Batch object with relevant attributes added.
- compute_batch_losses(model, batch, num_bootstrap=0)[source]¶
 Compute the losses over trajectories contained in the batch.
- Parameters
 model (TrajectoryBalanceModel) – a GNN taking a batch of graphs as input, as constructed by self.construct_batch. Must have a log_z attribute (itself a model) that predicts the log of Z(cond_info).
 batch (Batch) – batch of graph inputs, as constructed by self.construct_batch.
 num_bootstrap (int) – the number of trajectories for which the reward loss is computed. Ignored if 0.
- Return type
 Tuple[Tensor, Dict]
- Returns
 a tuple containing the loss for each trajectory and relevant info.
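Example (illustrative): one training step tying the pieces together, i.e. sample trajectories, build a Batch with construct_batch, and compute the trajectory balance loss with compute_batch_losses. The reward placeholder and optimizer are assumptions; in practice the transformed reward R(x) ** beta comes from your task's reward model.

import torch

# Continuing the illustrative objects from the sketches above
# (`algorithm`, `model`, `cond_info`).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

data = algorithm.create_training_data_from_own_samples(model, 8, cond_info)
rewards = torch.ones(len(data))  # placeholder for the transformed reward R(x) ** beta
batch = algorithm.construct_batch(data, cond_info, rewards)

loss, info = algorithm.compute_batch_losses(model, batch, num_bootstrap=0)
loss.mean().backward()  # reduce in case per-trajectory losses are returned
optimizer.step()
optimizer.zero_grad()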