gt4sd.frameworks.gflownet.envs.graph_building_env module

Summary

Classes:

Graph

Subclassing nx.Graph for debugging purposes.

GraphAction

Actions on a graph environment for GFlowNet.

GraphActionCategorical

A multi-type categorical action class compatible with generating structured actions on a graph.

GraphActionType

Action types for graph building environments.

GraphBuildingEnv

A graph-building environment that induces a DAG state space compatible with GFlowNet.

GraphBuildingEnvContext

A context environment that defines what the graphs are, how they map to and from data.

Functions:

generate_forward_trajectory

Sample (uniformly) a trajectory that generates g.

graph_without_edge

Build a new graph without an edge.

graph_without_edge_attr

Build a new graph without an edge attribute.

graph_without_node

Build a new graph without a node.

graph_without_node_attr

Build a new graph without a node attribute.

Reference

class Graph(incoming_graph_data=None, **attr)

Bases: networkx.Graph

Subclassing nx.Graph for debugging purposes.

__str__()

Returns a short summary of the graph.

Returns

info – Graph information including the graph name (if any), graph type, and the number of nodes and edges.

Return type

string

Examples

>>> import networkx as nx
>>> G = nx.Graph(name="foo")
>>> str(G)
"Graph named 'foo' with 0 nodes and 0 edges"
>>> G = nx.path_graph(3)
>>> str(G)
'Graph with 3 nodes and 2 edges'

__repr__()

Return repr(self).

graph_without_edge(g, e)

Build a new graph without an edge.

Parameters
  • g (Graph) – a graph.

  • e (Graph) – an edge to remove.

Return type

Graph

Returns

a new graph without an edge.

graph_without_node(g, n)

Build a new graph without a node.

Parameters
  • g (Graph) – a graph.

  • n (Graph) – a node to remove.

Return type

Graph

Returns

a new graph without a node.

graph_without_node_attr(g, n, a)

Build a new graph without a node attribute.

Parameters
  • g (Graph) – a graph.

  • n (Graph) – a node.

  • a (Any) – a node attribute.

Return type

Graph

Returns

a new graph without a node attribute.

graph_without_edge_attr(g, e, a)

Build a new graph without an edge attribute.

Parameters
  • g (Graph) – a graph.

  • e (Graph) – an edge.

  • a (Any) – an edge attribute.

Return type

Graph

Returns

a new graph without an edge attribute.
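The four graph_without_* helpers share one pattern: copy the input graph and remove something from the copy, so the caller's graph is left untouched. The sketch below illustrates that pattern on a tiny dictionary-based graph so that it runs without networkx. MiniGraph and the without_* functions are hypothetical stand-ins, not part of gt4sd.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, FrozenSet


@dataclass
class MiniGraph:
    # Hypothetical stand-in for nx.Graph: node id -> attribute dict,
    # undirected edge (frozenset of two node ids) -> attribute dict.
    nodes: Dict[int, Dict[str, Any]] = field(default_factory=dict)
    edges: Dict[FrozenSet[int], Dict[str, Any]] = field(default_factory=dict)


def without_edge(g: MiniGraph, u: int, v: int) -> MiniGraph:
    gp = copy.deepcopy(g)  # never mutate the caller's graph
    del gp.edges[frozenset((u, v))]
    return gp


def without_node(g: MiniGraph, n: int) -> MiniGraph:
    gp = copy.deepcopy(g)
    del gp.nodes[n]
    # Removing a node also drops every edge incident to it.
    gp.edges = {e: a for e, a in gp.edges.items() if n not in e}
    return gp


def without_node_attr(g: MiniGraph, n: int, attr: str) -> MiniGraph:
    gp = copy.deepcopy(g)
    del gp.nodes[n][attr]
    return gp


def without_edge_attr(g: MiniGraph, u: int, v: int, attr: str) -> MiniGraph:
    gp = copy.deepcopy(g)
    del gp.edges[frozenset((u, v))][attr]
    return gp


g = MiniGraph(
    nodes={0: {"v": "C"}, 1: {"v": "O", "charge": -1}},
    edges={frozenset((0, 1)): {"type": "single"}},
)
h = without_node(g, 1)
print(len(g.nodes), len(h.nodes), len(h.edges))  # 2 1 0
```

Copying first keeps each helper side-effect free, which matters when the same parent state is reused to enumerate several candidate transitions.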

class GraphActionType(value)

Bases: Enum

Action types for graph building environments.

Stop = 1
AddNode = 2
AddEdge = 3
SetNodeAttr = 4
SetEdgeAttr = 5
RemoveNode = 6
RemoveEdge = 7
RemoveNodeAttr = 8
RemoveEdgeAttr = 9
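Judging by their names, the Remove* members look like backward counterparts of the forward AddNode/AddEdge/SetNodeAttr/SetEdgeAttr actions. The sketch below encodes that pairing as an assumption; the INVERSE table and the inverse() helper are hypothetical and not part of gt4sd. The enum values are copied from the listing above.

```python
from enum import Enum


class GraphActionType(Enum):
    # Values copied from the documented enum.
    Stop = 1
    AddNode = 2
    AddEdge = 3
    SetNodeAttr = 4
    SetEdgeAttr = 5
    RemoveNode = 6
    RemoveEdge = 7
    RemoveNodeAttr = 8
    RemoveEdgeAttr = 9


# Assumed pairing of each forward action with the backward action that undoes it.
_FORWARD_TO_BACKWARD = {
    GraphActionType.AddNode: GraphActionType.RemoveNode,
    GraphActionType.AddEdge: GraphActionType.RemoveEdge,
    GraphActionType.SetNodeAttr: GraphActionType.RemoveNodeAttr,
    GraphActionType.SetEdgeAttr: GraphActionType.RemoveEdgeAttr,
}
# Make the table symmetric so it works in both directions.
INVERSE = {**_FORWARD_TO_BACKWARD, **{b: f for f, b in _FORWARD_TO_BACKWARD.items()}}


def inverse(t: GraphActionType) -> GraphActionType:
    """Return the action type that undoes t (Stop has no inverse)."""
    if t not in INVERSE:
        raise ValueError(f"{t.name} has no inverse")
    return INVERSE[t]


print(inverse(GraphActionType.AddEdge).name)  # RemoveEdge
```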
class GraphAction(action, source=None, target=None, value=None, attr=None, relabel=None)

Bases: object

Actions on a graph environment for GFlowNet.

__init__(action, source=None, target=None, value=None, attr=None, relabel=None)

Initialize a single graph-building action.

Parameters
  • action (GraphActionType) – the action type.

  • source (Optional[int, None]) – the source node this action is applied on.

  • target (Optional[int, None]) – the target node (i.e. if specified this is an edge action).

  • value (Optional[Any, None]) – the value (e.g. new node type) applied.

  • attr (Optional[str, None]) – the name of the node/edge attribute being set.

  • relabel (Optional[int, None]) – for AddNode actions, relabels the new node with that id.
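A rough sketch of how the documented constructor arguments fit together, written as a standalone dataclass. The is_edge_action property follows the parameter description (a target marks an edge action). SketchGraphAction and its local ActionType enum are hypothetical illustrations, not the gt4sd implementation.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class ActionType(Enum):
    # Subset of GraphActionType, repeated so the sketch is self-contained.
    Stop = 1
    AddNode = 2
    AddEdge = 3
    SetNodeAttr = 4
    SetEdgeAttr = 5


@dataclass
class SketchGraphAction:
    action: ActionType
    source: Optional[int] = None   # node the action is applied on
    target: Optional[int] = None   # set only for edge actions
    value: Any = None              # e.g. a new node type or an attribute value
    attr: Optional[str] = None     # name of the node/edge attribute being set
    relabel: Optional[int] = None  # AddNode only: id to give the new node

    @property
    def is_edge_action(self) -> bool:
        return self.target is not None


actions = [
    SketchGraphAction(ActionType.AddNode, value="C"),
    SketchGraphAction(ActionType.AddNode, source=0, value="O"),
    SketchGraphAction(ActionType.SetEdgeAttr, source=0, target=1, attr="type", value="double"),
    SketchGraphAction(ActionType.Stop),
]
print([a.is_edge_action for a in actions])  # [False, False, True, False]
```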

__repr__()

Return repr(self).

class GraphBuildingEnv(allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True)

Bases: object

A graph-building environment that induces a DAG state space compatible with GFlowNet.

__init__(allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True)

Initialize a GraphBuildingEnv.

Supports forward and backward actions, with a parents function that lists the parents of forward actions. Edges and nodes can have attributes added to them in a key:value style. Edges and nodes are created with implicit default attribute values (e.g. chirality, single/double bondness) so that:

  • an agent gets to do an extra action to set that attribute, but only if it is still default-valued (DAG property preserved);

  • we can generate a legal action for any attribute that isn’t a default one.

Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/envs

Parameters
  • allow_add_edge (bool) – if True, allows this action and computes AddEdge parents (i.e. if False, this env only allows for tree generation).

  • allow_node_attr (bool) – if True, allows this action and computes SetNodeAttr parents.

  • allow_edge_attr (bool) – if True, allows this action and computes SetEdgeAttr parents.

new()

Create a new, empty graph.

Return type

Graph

step(g, action)

Step forward the given graph state with an action.

Parameters
  • g (Graph) – the graph to be modified.

  • action (GraphAction) – the action taken on the graph, indices must match.

Returns

the new graph.

Return type

Graph
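The contract of step (take a graph and an action, return a new graph without modifying the input) can be illustrated with a dictionary-based sketch covering only the forward action types. The string action kinds, the "v" key for the node type, the max-id-plus-one rule for new node ids, and the rule that AddNode with a source also adds an edge to that source are all assumptions made for this illustration, not details taken from gt4sd.

```python
import copy
from typing import Any, Dict, FrozenSet, Optional, Tuple

Nodes = Dict[int, Dict[str, Any]]
Edges = Dict[FrozenSet[int], Dict[str, Any]]
State = Tuple[Nodes, Edges]


def step(g: State, kind: str, source: Optional[int] = None,
         target: Optional[int] = None, attr: Optional[str] = None,
         value: Any = None) -> State:
    """Return a new state; the input state is never modified."""
    nodes, edges = copy.deepcopy(g)
    if kind == "AddNode":
        new_id = max(nodes) + 1 if nodes else 0
        nodes[new_id] = {"v": value}  # assumed: node type stored under "v"
        if source is not None:  # assumed: the new node is attached to source
            edges[frozenset((source, new_id))] = {}
    elif kind == "AddEdge":
        key = frozenset((source, target))
        if key in edges:
            raise ValueError("edge already exists")
        edges[key] = {}
    elif kind == "SetNodeAttr":
        nodes[source][attr] = value
    elif kind == "SetEdgeAttr":
        edges[frozenset((source, target))][attr] = value
    else:
        # Stop and the Remove* action types are left out of this sketch.
        raise ValueError(f"unsupported action: {kind}")
    return nodes, edges


s0: State = ({}, {})
s1 = step(s0, "AddNode", value="C")
s2 = step(s1, "AddNode", source=0, value="O")
s3 = step(s2, "SetEdgeAttr", source=0, target=1, attr="type", value="double")
print(len(s3[0]), s3[1])  # 2 {frozenset({0, 1}): {'type': 'double'}}
```

If only AddNode steps are allowed, each new node brings exactly one edge, so the result is always a tree; this is the situation that allow_add_edge=False describes.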

count_backward_transitions(g)

Counts the number of parents of g without checking for isomorphisms.

Return type

int

generate_forward_trajectory(g, max_nodes=None)

Sample (uniformly) a trajectory that generates g.

Parameters
  • g (Graph) – the graph to be generated.

  • max_nodes (Optional[int, None]) – the maximum number of nodes to generate.

Return type

List[Tuple[Graph, GraphAction]]

Returns

a list of (graph, action) pairs.
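A simplified sketch of the idea behind generate_forward_trajectory for connected, attribute-free graphs: start from one node, then repeatedly pick uniformly at random among the currently legal construction steps (attach a new neighbour to a placed node, or add a missing edge between two placed nodes), and finish with Stop. The function name, the Step tuple encoding, and this particular sampling scheme are assumptions for illustration. The sketch returns only actions rather than (graph, action) pairs and ignores max_nodes.

```python
import random
from typing import Dict, List, NamedTuple, Optional, Set, Tuple


class Step(NamedTuple):
    kind: str          # "AddNode", "AddEdge" or "Stop"
    u: Optional[int]   # existing node the step starts from (None for the first node)
    v: Optional[int]   # new node (AddNode) or other endpoint (AddEdge)


def sample_forward_trajectory(adj: Dict[int, Set[int]], rng: random.Random) -> List[Step]:
    """Randomly order the construction of a connected, attribute-free graph."""
    first = rng.choice(sorted(adj))
    placed: Set[int] = {first}
    built: Set[Tuple[int, int]] = set()
    traj = [Step("AddNode", None, first)]
    while True:
        legal: List[Step] = []
        for u in sorted(placed):
            for v in sorted(adj[u]):
                if v not in placed:
                    legal.append(Step("AddNode", u, v))  # attach new node v to u
                elif u < v and (u, v) not in built:
                    legal.append(Step("AddEdge", u, v))  # close a missing edge
        if not legal:
            break  # every node placed and every edge built
        step = rng.choice(legal)  # uniform over the currently legal steps
        placed.add(step.v)
        built.add((min(step.u, step.v), max(step.u, step.v)))
        traj.append(step)
    traj.append(Step("Stop", None, None))
    return traj


# A triangle (0, 1, 2) with a pendant node 3 attached to 2.
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
traj = sample_forward_trajectory(adj, random.Random(0))
print([s.kind for s in traj].count("AddEdge"))  # 1
```

Whatever the random order, a connected graph with n nodes and m edges needs n - 1 AddNode steps after the first node and m - (n - 1) AddEdge steps, which is why the count above is always 1.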

class GraphActionCategorical(graphs, logits, keys, types, deduplicate_edge_index=True)

Bases: object

A multi-type categorical action class compatible with generating structured actions on a graph.

__init__(graphs, logits, keys, types, deduplicate_edge_index=True)

Initialize a GraphActionCategorical.

Multi-type here means that there are several types of mutually exclusive actions (e.g. AddNode and AddEdge). Their logits are produced by different variable-sized tensors, one per kind of graph element (e.g. nodes or edges), so it is inconvenient to stack them all into a single Categorical. This class bridges torch_geometric Batch objects and such lists of logit tensors.

Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/envs

Parameters
  • graphs (Batch) – a Batch of graphs to which the logits correspond.

  • logits (List[Tensor]) – a list of tensors of shape (n, m) representing logits over a variable number of graph elements (e.g. nodes) for which there are m possible actions. n should thus be equal to the sum of the number of such elements for each graph in the Batch object. The length of the logits list should thus be equal to the number of element types (in other words there should be one tensor per type).

  • keys (List[Optional[str, None]]) – the keys corresponding to the Graph elements for each tensor in the logits list. Used to extract the _batch and slice attributes. For example, if the first logit tensor is a per-node action logit, and the second is a per-edge, keys could be [‘x’, ‘edge_index’]. If keys[i] is None, the corresponding logits are assumed to be graph-level (i.e. if there are k graphs in the Batch object, this logit tensor would have shape (k, m)).

  • types (List[GraphActionType]) – the action type each logit corresponds to.

  • deduplicate_edge_index (bool) – if True, the ‘edge_index’ keys have been reduced by e_i[::2] (presumably because the graphs are undirected).
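To make the logits/keys layout concrete: for a Batch of two graphs, a per-node logit tensor has one row per node across the whole batch, and a batch vector maps each row back to its graph. The sketch below builds such per-row graph indices with plain lists, including the e_i[::2] deduplication for an edge_index that stores each undirected edge as two consecutive directed columns (an assumption about column ordering). The names node_batch, edge_batch and rows_per_graph are illustrative; in torch_geometric the node-level vector is Batch.batch.

```python
from typing import List, Tuple

# Two graphs: graph 0 is a path 0-1-2, graph 1 is a single edge 0-1.
num_nodes = [3, 2]
# Each undirected edge is stored as two consecutive directed columns (u, v), (v, u),
# with node ids already offset into the batch (graph 1's nodes are 3 and 4).
edge_index: List[Tuple[int, int]] = [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]

# Graph id of every node row, i.e. what torch_geometric stores as Batch.batch.
node_batch = [g for g, n in enumerate(num_nodes) for _ in range(n)]

# deduplicate_edge_index=True corresponds to keeping every other column.
dedup_edges = edge_index[::2]
# An edge belongs to the graph of its first endpoint.
edge_batch = [node_batch[u] for u, _ in dedup_edges]

# Rows contributed per graph by per-node logits ("x"), per-edge logits
# ("edge_index") and graph-level logits (key None: one row per graph).
rows_per_graph = {
    "x": [node_batch.count(g) for g in range(len(num_nodes))],
    "edge_index": [edge_batch.count(g) for g in range(len(num_nodes))],
    None: [1] * len(num_nodes),
}
print(node_batch, edge_batch)  # [0, 0, 0, 1, 1] [0, 0, 1]
```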

detach()

Detach the logits from the autograd computation graph.

to(device)

Move everything to the specified device.

Parameters

device – the device to move to.

logsoftmax()

Compute log-probabilities given logits.

Return type

List[Tensor]

Returns

a list of log-probabilities.
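Because the action types are mutually exclusive, one natural reading of logsoftmax is that the normalizer of each graph sums over every row and column of every type tensor belonging to that graph. The sketch below implements that reading with plain lists (one list of rows per type, plus one batch list per type). The function name and the data layout are assumptions; the documented method works on torch tensors.

```python
import math
from typing import List

Matrix = List[List[float]]


def joint_log_softmax(logits: List[Matrix], batch: List[List[int]], num_graphs: int) -> List[Matrix]:
    """Log-softmax over all action types at once, separately for each graph.

    logits[t][r][c]: logit of column c in row r of type t.
    batch[t][r]: graph that row r of type t belongs to.
    Assumes every graph owns at least one row (e.g. a graph-level Stop row).
    """
    # Per-graph max for numerical stability (log-sum-exp trick).
    maxes = [-math.inf] * num_graphs
    for rows, b in zip(logits, batch):
        for row, g in zip(rows, b):
            maxes[g] = max(maxes[g], max(row))
    # Per-graph sum of exp(logit - max) across every type, row and column.
    sums = [0.0] * num_graphs
    for rows, b in zip(logits, batch):
        for row, g in zip(rows, b):
            sums[g] += sum(math.exp(x - maxes[g]) for x in row)
    log_z = [m + math.log(s) for m, s in zip(maxes, sums)]
    return [[[x - log_z[g] for x in row] for row, g in zip(rows, b)]
            for rows, b in zip(logits, batch)]


# Type 0: per-node logits (two nodes in graph 0, one in graph 1), 2 columns.
# Type 1: graph-level logits (one row per graph), 1 column, e.g. Stop.
logits = [[[0.0, 1.0], [2.0, 0.5], [1.0, 1.0]], [[0.0], [3.0]]]
batch = [[0, 0, 1], [0, 1]]
logp = joint_log_softmax(logits, batch, num_graphs=2)
total_g0 = sum(math.exp(x) for t, b in zip(logp, batch) for row, g in zip(t, b) if g == 0 for x in row)
print(round(total_g0, 6))  # 1.0
```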

sample()

Use the Gumbel-max trick to sample categoricals.

If $X = \arg\max_i \left(\ell_i - \log(-\log U_i)\right)$, where $\ell$ are the logits and the $U_i \sim \mathrm{Uniform}(0, 1)$ are drawn independently, then $p(X = i) = \exp(\ell_i) / Z$ with $Z = \sum_j \exp(\ell_j)$. Here the argmax is taken first over the variable number of rows of each element type for each graph in the minibatch, then over the different types (since they are mutually exclusive).

Return type

List[Tuple[int, int, int]]

Returns

A list of tuples specifying the action (type, row, column).
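The Gumbel-max trick described above can be sketched with the standard library: add independent Gumbel(0, 1) noise, -log(-log(U)), to every logit, then take the argmax per graph over all types, rows and columns at once. In this sketch the returned row is counted within its own graph; that convention, the function name, and the nested-list layout are assumptions for illustration.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def gumbel_sample(logits: List[Matrix], batch: List[List[int]], num_graphs: int,
                  rng: random.Random) -> List[Tuple[int, int, int]]:
    """Return one (type, row, column) per graph; row is counted within its graph."""
    best = [(-math.inf, (-1, -1, -1))] * num_graphs
    for t, (rows, b) in enumerate(zip(logits, batch)):
        seen = [0] * num_graphs  # rows of type t seen so far for each graph
        for row, g in zip(rows, b):
            for c, logit in enumerate(row):
                u = rng.random() or 1e-12  # avoid log(0)
                score = logit - math.log(-math.log(u))  # logit + Gumbel(0, 1) noise
                if score > best[g][0]:
                    best[g] = (score, (t, seen[g], c))
            seen[g] += 1
    return [action for _, action in best]


rng = random.Random(0)
logits = [[[0.0, 1.0], [2.0, 0.5], [1.0, 1.0]], [[0.0], [3.0]]]
batch = [[0, 0, 1], [0, 1]]
print(gumbel_sample(logits, batch, 2, rng))
```

Taking a single argmax over the noisy scores of all types of a graph is what makes the sampled action follow the jointly normalized softmax over mutually exclusive action types.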

log_prob(actions)

The log-probability of a list of action tuples.

Parameters

actions (List[Tuple[int, int, int]]) – A list of action tuples (type, row, column).

Return type

Tensor

Returns

A tensor of shape (minibatch_size,) containing the log-probability of each action.
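Given per-graph log-probabilities laid out by type, row and column, log_prob amounts to an indexing operation: for graph i and its action (type, row, column), look up that entry, with the row counted within graph i. The sketch below shows this lookup on nested lists. The layout, the within-graph row convention, and the function name are assumptions, and the documented method returns a torch Tensor rather than a list.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def action_log_probs(logprobs: List[Matrix], batch: List[List[int]],
                     actions: List[Tuple[int, int, int]]) -> List[float]:
    """Pick out log p(action_i) for each graph i from per-type log-probability rows."""
    out = []
    for g, (t, row, col) in enumerate(actions):
        # Rows of type t that belong to graph g, in batch order.
        rows_of_g = [r for r, b in zip(logprobs[t], batch[t]) if b == g]
        out.append(rows_of_g[row][col])
    return out


# Probabilities per graph (each graph's entries sum to 1 across both types).
p = [[[0.1, 0.2], [0.3, 0.3], [0.5, 0.25]], [[0.1], [0.25]]]
batch = [[0, 0, 1], [0, 1]]
logprobs = [[[math.log(x) for x in row] for row in rows] for rows in p]
lp = action_log_probs(logprobs, batch, [(0, 1, 0), (1, 0, 0)])
print([round(math.exp(x), 2) for x in lp])  # [0.3, 0.25]
```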

class GraphBuildingEnvContext

Bases: object

A context environment that defines what the graphs are, how they map to and from data.

device: str
num_node_dim: int
num_edge_dim: int
num_cond_dim: int
num_new_node_values: int
num_node_attr_logits: int
num_edge_attr_logits: int
action_type_order: List[GraphActionType]
aidx_to_graph_action(g, action_idx)

Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction.

Parameters
  • g (Data) – the graph to which the action is being applied.

  • action_idx (Tuple[int, int, int]) – the tensor indices for the corresponding action.

Return type

GraphAction

Returns

a graph action that could be applied to the original graph corresponding to g.

graph_action_to_aidx(g, action)

Translate a GraphAction to an action index (e.g. from a GraphActionCategorical).

Parameters
  • g (Data) – the graph to which the action is being applied.

  • action (GraphAction) – a graph action that could be applied to the original graph corresponding to g.

Return type

Tuple[int, int, int]

Returns

the tensor indices for the corresponding action.
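aidx_to_graph_action and graph_action_to_aidx are meant to invert each other. The sketch below shows one possible encoding for a toy context: the first index is the position of the action type in action_type_order, the second is a node index (0 for graph-level actions such as Stop), and the third indexes a value table. Everything here (ToyAction, NODE_VALUES, CHARGES, and the column conventions) is a hypothetical illustration; real contexts define their own mapping.

```python
from enum import Enum
from typing import Any, List, NamedTuple, Optional, Tuple


class ActionType(Enum):
    # Subset of GraphActionType, repeated so the sketch is self-contained.
    Stop = 1
    AddNode = 2
    SetNodeAttr = 4


class ToyAction(NamedTuple):
    action: ActionType
    source: Optional[int] = None
    value: Any = None


ACTION_TYPE_ORDER: List[ActionType] = [ActionType.Stop, ActionType.AddNode, ActionType.SetNodeAttr]
NODE_VALUES = ["C", "N", "O"]  # hypothetical vocabulary for new node types
CHARGES = [0, 1, -1]           # hypothetical values of a single node attribute


def aidx_to_action(aidx: Tuple[int, int, int]) -> ToyAction:
    type_idx, row, col = aidx
    t = ACTION_TYPE_ORDER[type_idx]
    if t is ActionType.Stop:
        return ToyAction(t)
    if t is ActionType.AddNode:
        return ToyAction(t, source=row, value=NODE_VALUES[col])
    return ToyAction(t, source=row, value=CHARGES[col])


def action_to_aidx(a: ToyAction) -> Tuple[int, int, int]:
    type_idx = ACTION_TYPE_ORDER.index(a.action)
    if a.action is ActionType.Stop:
        return (type_idx, 0, 0)
    table = NODE_VALUES if a.action is ActionType.AddNode else CHARGES
    return (type_idx, a.source, table.index(a.value))


print(aidx_to_action((1, 2, 0)))  # ToyAction(action=<ActionType.AddNode: 2>, source=2, value='C')
```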

graph_to_mol(g)

Translate a graph to a molecule.

Parameters

g (Graph) – the graph to translate.

Return type

Mol

Returns

the molecule corresponding to the graph.

sample_conditional_information()

Sample conditional information.

graph_to_data(g)

Convert a networkx Graph to a torch_geometric Data instance.

Parameters

g (Graph) – a graph instance.

Return type

Data

Returns

the corresponding torch_geometric graph data.

collate(graphs)

Convert a list of torch_geometric Data instances to a Batch instance. This exists so that environment contexts can set custom batching attributes, e.g. by using follow_batch.

Parameters

graphs (List[Data]) – graph instances.

Return type

Batch

Returns

the corresponding torch_geometric batch.
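The core of such batching can be sketched without torch_geometric: node features are concatenated, the edge indices of each graph are shifted by the number of nodes that precede it in the batch, and a batch vector records which graph each node came from. This mirrors the usual torch_geometric batching scheme; ToyData, ToyBatch and this collate function are illustrative stand-ins for Data, Batch and the documented method, and follow_batch-style extra attributes are not modelled.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyData:
    x: List[List[float]]               # one feature row per node
    edge_index: List[Tuple[int, int]]  # (source, target) pairs with local node ids


@dataclass
class ToyBatch:
    x: List[List[float]]
    edge_index: List[Tuple[int, int]]
    batch: List[int]                   # graph id of every node row
    num_graphs: int


def collate(graphs: List[ToyData]) -> ToyBatch:
    x: List[List[float]] = []
    edge_index: List[Tuple[int, int]] = []
    batch: List[int] = []
    offset = 0
    for g_id, g in enumerate(graphs):
        x.extend(g.x)
        # Shift local node ids so they index into the concatenated x.
        edge_index.extend((u + offset, v + offset) for u, v in g.edge_index)
        batch.extend([g_id] * len(g.x))
        offset += len(g.x)
    return ToyBatch(x, edge_index, batch, len(graphs))


b = collate([
    ToyData(x=[[1.0], [0.0]], edge_index=[(0, 1), (1, 0)]),
    ToyData(x=[[0.0], [1.0], [1.0]], edge_index=[(0, 2), (2, 0)]),
])
print(b.edge_index, b.batch)  # [(0, 1), (1, 0), (2, 4), (4, 2)] [0, 0, 1, 1, 1]
```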

is_sane(g)

Verifies whether a graph is sane according to the context. This can catch, e.g. impossible molecules.

Parameters

g (Graph) – a graph.

Return type

bool

Returns

True if the environment considers g to be sane.

mol_to_graph(mol)

Translate an RDKit molecule to a graph.

Parameters

mol (Mol) – an RDKit molecule.

Return type

Graph

Returns

the corresponding Graph representation of that molecule.
