#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import copy
from itertools import count
from typing import Any, Dict, List, Tuple, Union, Optional
import numpy as np
import torch
import torch.nn as nn
import torch_geometric.data as gd
from torch import Tensor
from torch_scatter import scatter
from ..envs.graph_building_env import (
Graph,
GraphAction,
GraphActionCategorical,
GraphActionType,
GraphBuildingEnv,
GraphBuildingEnvContext,
generate_forward_trajectory,
)
class TrajectoryBalanceModel(nn.Module):
    """Generic model compatible with trajectory balance."""

    def forward(
        self, batch: gd.Batch, cond_info: Tensor
    ) -> Tuple[GraphActionCategorical, Tensor]:
        """Run forward pass.

        Args:
            batch: batch of graph data.
            cond_info: conditional information for each graph in the batch.

        Returns:
            a tuple of a forward-action categorical and per-graph predictions
            (e.g. log-reward predictions).
        """
        raise NotImplementedError()
    def log_z(self, cond_info: Tensor) -> Tensor:
        """Compute log_z.

        Args:
            cond_info: conditional information.

        Returns:
            log partition function.
        """
        raise NotImplementedError()
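
# A minimal sketch of a conforming subclass (illustrative only: the constant
# log_z parameter and the hypothetical `graph_net` module are assumptions for
# demonstration, not part of this module). Note that both call sites below
# invoke the model as model(batch, cond_info):
#
#     class _ToyTBModel(TrajectoryBalanceModel):
#         def __init__(self, graph_net: nn.Module):
#             super().__init__()
#             # assumed to map (gd.Batch, Tensor) to
#             # (GraphActionCategorical, Tensor)
#             self.graph_net = graph_net
#             self._log_z = nn.Parameter(torch.zeros(1))
#
#         def forward(self, batch, cond_info):
#             return self.graph_net(batch, cond_info)
#
#         def log_z(self, cond_info):
#             # one log-partition estimate per conditioning vector
#             return self._log_z.expand(cond_info.shape[0], 1)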
class TrajectoryBalance:
    """A trajectory balance algorithm for GFlowNets."""
    def __init__(
        self,
        configuration: Dict[str, Any],
        environment: GraphBuildingEnv,
        context: GraphBuildingEnvContext,
        max_len: Optional[int] = None,
    ):
        """Initialize trajectory balance algorithm.

        Trajectory balance implementation, see
        "Trajectory Balance: Improved Credit Assignment in GFlowNets,
        Nikolay Malkin, Moksh Jain, Emmanuel Bengio, Chen Sun, Yoshua Bengio"
        https://arxiv.org/abs/2201.13259.
        Code adapted from: https://github.com/recursionpharma/gflownet/blob/trunk/src/gflownet/algo/trajectory_balance.py.

        Args:
            configuration: hyperparameter dictionary. The following keys are used:
                - rng: rng used to take random actions.
                - max_nodes: int, if not None, ends trajectories of graphs with more than max_nodes nodes (illegal action).
                - random_action_prob: float, probability of taking a uniform random action when sampling.
                - illegal_action_logreward: float, log(R) given to the model for non-sane end states or illegal actions.
                - bootstrap_own_reward: bool, if True, uses the .reward batch data to predict rewards for sampled data.
                - tb_epsilon: float, if not None, adds this epsilon in the numerator and denominator of the log-ratio.
                - reward_loss_multiplier: float, multiplying constant for the bootstrap loss.
            environment: a graph environment.
            context: a context.
            max_len: if not None, ends trajectories of more than max_len steps.
        """
self.ctx = context
self.env = environment
self.hps = configuration
self.rng = self.hps["rng"]
self.max_len = max_len
self.max_nodes = self.hps["max_nodes"]
self.random_action_prob = self.hps["random_action_prob"]
self.illegal_action_logreward = self.hps["illegal_action_logreward"]
self.bootstrap_own_reward = self.hps["bootstrap_own_reward"]
self.sanitize_samples = True
self.epsilon = self.hps["tb_epsilon"]
self.reward_loss_multiplier = self.hps["reward_loss_multiplier"]
# Experimental flags
self.reward_loss_is_mae = True
self.tb_loss_is_mae = False
self.tb_loss_is_huber = False
self.mask_invalid_rewards = False
self.length_normalize_losses = False
self.sample_temp = 1
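
    # A minimal configuration sketch covering every key read in __init__
    # (the values are illustrative assumptions, not recommended settings):
    #
    #     hps = {
    #         "rng": np.random.default_rng(),
    #         "max_nodes": 9,
    #         "random_action_prob": 0.01,
    #         "illegal_action_logreward": -75.0,
    #         "bootstrap_own_reward": False,
    #         "tb_epsilon": None,
    #         "reward_loss_multiplier": 1.0,
    #     }
    #     algo = TrajectoryBalance(hps, environment, context, max_len=128)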
    def _corrupt_actions(
        self, actions: List[Tuple[int, int, int]], cat: GraphActionCategorical
    ):
        """Replace sampled actions with uniform random actions, in place, each with probability random_action_prob.

        Args:
            actions: list of sampled actions, as (action type, row, column) index triples.
            cat: the action categorical the actions were sampled from.
        """
if self.random_action_prob <= 0:
return
(corrupted,) = (
self.rng.uniform(size=len(actions)) < self.random_action_prob
).nonzero()
for i in corrupted:
n_in_batch = [int((b == i).sum()) for b in cat.batch]
n_each = np.array(
[
float(logit.shape[1]) * nb
for logit, nb in zip(cat.logits, n_in_batch)
]
)
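            # n_each weights the action-type choice by the number of concrete
            # (row, column) logits of that type available for graph i, so the
            # composite (which, row, col) draw below is uniform over all of
            # graph i's concrete actions.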
which = self.rng.choice(len(n_each), p=n_each / n_each.sum())
row = self.rng.choice(n_in_batch[which])
col = self.rng.choice(cat.logits[which].shape[1])
actions[i] = (which, row, col)
    def create_training_data_from_own_samples(
        self, model: Union[nn.Module, TrajectoryBalanceModel], n: int, cond_info: Tensor
    ) -> List[Dict]:
        """Generate trajectories by sampling a model.

        Args:
            model: the model being sampled, used with a certain algorithm (i.e. trajectory balance).
            n: number of trajectories to sample.
            cond_info: conditional information, shape (n, n_info).

        Returns:
            data: a list of trajectories. Each trajectory is a dict with keys:
                - traj: List[Tuple[Graph, GraphAction]].
                - reward_pred: exp(illegal_action_logreward) if an illegal action is taken, predicted R(x) if bootstrapping, None otherwise.
                - fwd_logprob: sum of logprobs P_F.
                - bck_logprob: sum of logprobs P_B.
                - log_z: predicted log_z.
                - loss: predicted loss (if bootstrapping).
                - is_valid: is the generated graph valid according to the environment and context.
        """
ctx = self.ctx
env = self.env
dev = self.ctx.device
cond_info = cond_info.to(dev)
        # Predict the log partition function from the conditional information
        log_z_pred = model.log_z(cond_info)  # type: ignore
# This will be returned as training data
data: List[Dict] = []
for i in range(n):
data.append({"traj": [], "reward_pred": None, "is_valid": True})
# Let's also keep track of trajectory statistics according to the model
zero = torch.tensor([0], device=dev).float()
fwd_logprob: List[List[Tensor]] = [[] for i in range(n)]
bck_logprob: List[List[Tensor]] = [
[zero] for i in range(n)
] # zero in case there is a single invalid action
graphs = [env.new() for i in range(n)]
done = [False] * n
def not_done(lst):
return [e for i, e in enumerate(lst) if not done[i]]
# TODO report these stats:
mol_too_big = 0
mol_not_sane = 0
invalid_act = 0
logprob_of_illegal = []
illegal_action_logreward = torch.tensor(
[self.illegal_action_logreward], device=dev
)
if self.epsilon is not None:
epsilon = torch.tensor([self.epsilon], device=dev).float()
for t in range(self.max_len) if self.max_len is not None else count(0):
# Construct graphs for the trajectories that aren't yet done
torch_graphs = [ctx.graph_to_data(i) for i in not_done(graphs)]
not_done_mask = torch.tensor(done, device=dev).logical_not()
# Forward pass to get GraphActionCategorical
fwd_cat, log_reward_preds = model(
ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]
)
if self.sample_temp != 1:
sample_cat = copy.copy(fwd_cat)
sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
actions = sample_cat.sample()
else:
actions = fwd_cat.sample()
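            # Note: dividing the logits by sample_temp flattens (temp > 1) or
            # sharpens (temp < 1) the sampling distribution; log-probabilities
            # for training are still taken from the unscaled fwd_cat below.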
self._corrupt_actions(actions, fwd_cat)
graph_actions = [
ctx.aidx_to_graph_action(g, a) for g, a in zip(torch_graphs, actions)
]
log_probs = fwd_cat.log_prob(actions)
for i, j in zip(not_done(range(n)), range(n)):
# Step each trajectory, and accumulate statistics
fwd_logprob[i].append(log_probs[j].unsqueeze(0))
data[i]["traj"].append((graphs[i], graph_actions[j]))
# Check if we're done
if graph_actions[j].action is GraphActionType.Stop:
done[i] = True
if self.sanitize_samples and not ctx.is_sane(graphs[i]):
# check if the graph is sane (e.g. RDKit can
# construct a molecule from it) otherwise
# treat the done action as illegal
mol_not_sane += 1
data[i]["reward_pred"] = illegal_action_logreward.exp()
data[i]["is_valid"] = False
elif self.bootstrap_own_reward:
# if we're bootstrapping, extract reward prediction
data[i]["reward_pred"] = log_reward_preds[j].detach().exp()
else: # If not done, try to step the environment
gp = graphs[i]
try:
# env.step can raise AssertionError if the action is illegal
gp = env.step(graphs[i], graph_actions[j])
if self.max_nodes is not None:
assert len(gp.nodes) <= self.max_nodes
                except AssertionError:
                    if self.max_nodes is not None and len(gp.nodes) > self.max_nodes:
                        mol_too_big += 1
                    else:
                        invalid_act += 1
done[i] = True
data[i]["reward_pred"] = illegal_action_logreward.exp()
data[i]["is_valid"] = False
continue
# Add to the trajectory
# P_B = uniform backward
n_back = env.count_backward_transitions(gp)
bck_logprob[i].append(torch.tensor([1 / n_back], device=dev).log())
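                    # With a uniform backward policy every predecessor of gp is
                    # equally likely, so each step contributes log(1 / n_back),
                    # e.g. log(1 / 3) for a graph with three backward transitions.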
graphs[i] = gp
if all(done):
break
for i in range(n):
# If we're not bootstrapping, we could query the reward
# model here, but this is expensive/impractical. Instead
# just report forward and backward flows
data[i]["log_z"] = log_z_pred[i].item()
data[i]["fwd_logprob"] = sum(fwd_logprob[i])
data[i]["bck_logprob"] = sum(bck_logprob[i])
if self.bootstrap_own_reward:
if not data[i]["is_valid"]:
logprob_of_illegal.append(data[i]["fwd_logprob"].item())
# If we are bootstrapping, we can report the theoretical loss as well
numerator = data[i]["fwd_logprob"] + log_z_pred[i]
denominator = data[i]["bck_logprob"] + data[i]["reward_pred"].log()
if self.epsilon is not None:
numerator = torch.logaddexp(numerator, epsilon)
denominator = torch.logaddexp(denominator, epsilon)
data[i]["loss"] = (numerator - denominator).pow(2)
return data
    def create_training_data_from_graphs(
        self, graphs: List[Graph]
    ) -> List[Dict[str, List[Tuple[Graph, GraphAction]]]]:
        """Generate trajectories from known endpoints.

        Args:
            graphs: list of Graph endpoints.

        Returns:
            a list of trajectories.
        """
        return [{"traj": generate_forward_trajectory(i)} for i in graphs]
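
    # Offline usage sketch (hedged: `dataset_graphs`, `cond_info` and `rewards`
    # are assumed to be supplied by the caller; note that compute_batch_losses
    # additionally reads `batch.num_offline`, which construct_batch does not
    # set, so the caller is expected to attach it):
    #
    #     trajs = algo.create_training_data_from_graphs(dataset_graphs)
    #     batch = algo.construct_batch(trajs, cond_info, rewards)
    #     batch.num_offline = len(trajs)
    #     loss, info = algo.compute_batch_losses(model, batch)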
    def construct_batch(
        self,
        trajs: List[Dict[str, List[Tuple[Graph, GraphAction]]]],
        cond_info: Tensor,
        rewards: Tensor,
    ) -> gd.Batch:
        """Construct a batch from a list of trajectories and their information.

        Args:
            trajs: a list of N trajectories.
            cond_info: the conditional info that is considered for each trajectory. Shape (N, n_info).
            rewards: the transformed reward (e.g. R(x) ** beta) for each trajectory. Shape (N,).

        Returns:
            a (CPU) Batch object with relevant attributes added.
        """
torch_graphs = [
self.ctx.graph_to_data(i[0]) for tj in trajs for i in tj["traj"]
]
actions = [
self.ctx.graph_action_to_aidx(g, a)
for g, a in zip(torch_graphs, [i[1] for tj in trajs for i in tj["traj"]])
]
num_backward = torch.tensor(
[
# Count the number of backward transitions from s_{t+1},
# unless t+1 = T is the last time step
self.env.count_backward_transitions(tj["traj"][i + 1][0])
if i + 1 < len(tj["traj"])
else 1
for tj in trajs
for i in range(len(tj["traj"]))
]
)
batch = self.ctx.collate(torch_graphs)
batch.traj_lens = torch.tensor([len(i["traj"]) for i in trajs])
batch.num_backward = num_backward
batch.actions = torch.tensor(actions)
batch.rewards = rewards
batch.cond_info = cond_info
batch.is_valid = torch.tensor([i.get("is_valid", True) for i in trajs]).float()
return batch
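
    # Worked layout example (illustrative): two trajectories of lengths 3 and
    # 2 yield 5 collated graphs, batch.traj_lens == [3, 2], batch.actions has
    # one (type, row, col) triple per graph, and batch.num_backward has one
    # entry per step, fixed to 1 at each trajectory's terminal step.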
    def compute_batch_losses(
        self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the losses over trajectories contained in the batch.

        Args:
            model: a GNN taking in a batch of graphs as constructed by self.construct_batch.
                Must have a `log_z` attribute, itself a model, which predicts the log of Z(cond_info).
            batch: batch of graph inputs as constructed by self.construct_batch.
            num_bootstrap: the number of trajectories for which the reward loss is computed. Ignored if 0.

        Returns:
            a tuple containing the loss for each trajectory and relevant info.
        """
dev = batch.x.device
# A single trajectory is comprised of many graphs
num_trajs = int(batch.traj_lens.shape[0])
rewards = batch.rewards
cond_info = batch.cond_info
# This index says which trajectory each graph belongs to, so
# it will look like [0,0,0,0,1,1,1,2,...] if trajectory 0 is
# of length 4, trajectory 1 of length 3, and so on.
batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(
batch.traj_lens
)
# The position of the last graph of each trajectory
final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1
# Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions
fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx])
        # Retrieve the reward predictions for the full graphs,
        # i.e. the final graph of each trajectory
log_reward_preds = log_reward_preds[final_graph_idx, 0]
# Compute trajectory balance objective
logz = model.log_z(cond_info)[:, 0]
# This is the log prob of each action in the trajectory
log_prob = fwd_cat.log_prob(batch.actions)
# The log prob of each backward action
log_p_B = (1 / batch.num_backward).log()
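        # Note: construct_batch sets num_backward to 1 at each trajectory's
        # terminal step, so terminal steps contribute log(1) = 0 here.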
# Take log rewards, and clip
Rp = torch.maximum(rewards.log(), torch.tensor(-100.0, device=dev))
# This is the log probability of each trajectory
traj_log_prob = scatter(
log_prob, batch_idx, dim=0, dim_size=num_trajs, reduce="sum"
)
# Compute log numerator and denominator of the TB objective
numerator = logz + traj_log_prob
denominator = Rp + scatter(
log_p_B, batch_idx, dim=0, dim_size=num_trajs, reduce="sum"
)
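        # Per the cited paper, this is the trajectory balance objective
        #   L(tau) = (log Z(c) + sum_t log P_F(s_{t+1}|s_t)
        #             - log R(x) - sum_t log P_B(s_t|s_{t+1}))^2,
        # with the two log-sums assembled as numerator and denominator above.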
if self.epsilon is not None:
# Numerical stability epsilon
epsilon = torch.tensor([self.epsilon], device=dev).float()
numerator = torch.logaddexp(numerator, epsilon)
denominator = torch.logaddexp(denominator, epsilon)
invalid_mask = 1 - batch.is_valid
        if self.mask_invalid_rewards:
            # Instead of being rude to the model and giving a
            # logreward of -100, what if we say: whatever you think the
            # log-probability of this trajectory is, it should be smaller
            # (thus the `numerator - 1`). Why 1? Intuition?
            denominator = denominator * (1 - invalid_mask) + invalid_mask * (
                numerator.detach() - 1
            )
if self.tb_loss_is_mae:
traj_losses = abs(numerator - denominator)
        elif self.tb_loss_is_huber:
            # TODO: not implemented; fail loudly rather than leave
            # traj_losses undefined below.
            raise NotImplementedError("Huber TB loss is not implemented.")
else:
traj_losses = (numerator - denominator).pow(2)
# Normalize losses by trajectory length
if self.length_normalize_losses:
traj_losses = traj_losses / batch.traj_lens
if self.bootstrap_own_reward:
num_bootstrap = num_bootstrap or len(rewards)
if self.reward_loss_is_mae:
reward_losses = abs(
rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap].exp()
)
else:
reward_losses = (
rewards[:num_bootstrap] - log_reward_preds[:num_bootstrap].exp()
).pow(2)
reward_loss = reward_losses.mean()
else:
reward_loss = 0
loss = traj_losses.mean() + reward_loss * self.reward_loss_multiplier
info = {
"offline_loss": traj_losses[: batch.num_offline].mean(),
"online_loss": traj_losses[batch.num_offline :].mean(),
"reward_loss": reward_loss,
"invalid_trajectories": invalid_mask.mean() * 2,
"invalid_logprob": (invalid_mask * traj_log_prob).sum()
/ (invalid_mask.sum() + 1e-4),
"invalid_losses": (invalid_mask * traj_losses).sum()
/ (invalid_mask.sum() + 1e-4),
"log_z": logz.mean(),
}
if not torch.isfinite(traj_losses).all():
raise ValueError("loss is not finite")
return loss, info