gt4sd.frameworks.gflownet.envs.graph_building_env module
Summary
Classes:
- Graph – Subclassing nx.Graph for debugging purposes.
- GraphAction – Actions on a graph environment for gflownet.
- GraphActionCategorical – A multi-type categorical action class compatible with generating structured actions on a graph.
- GraphActionType – Action types for graph building environments.
- GraphBuildingEnv – A graph building environment that induces a DAG state space compatible with gflownet.
- GraphBuildingEnvContext – A context environment that defines what the graphs are and how they map to and from data.
Functions:
- generate_forward_trajectory – Sample (uniformly) a trajectory that generates g.
- graph_without_edge – Build a new graph without an edge.
- graph_without_edge_attr – Build a new graph without an edge attribute.
- graph_without_node – Build a new graph without a node.
- graph_without_node_attr – Build a new graph without a node attribute.
Reference
- class Graph(incoming_graph_data=None, **attr)
Bases: networkx.Graph
Subclassing nx.Graph for debugging purposes.
- __str__()
Returns a short summary of the graph.
- Returns
info – Graph information including the graph name (if any), graph type, and the number of nodes and edges.
- Return type
string
Examples
>>> G = nx.Graph(name="foo")
>>> str(G)
"Graph named 'foo' with 0 nodes and 0 edges"
>>> G = nx.path_graph(3)
>>> str(G)
'Graph with 3 nodes and 2 edges'
- graph_without_edge(g, e)
Build a new graph without an edge.
- Parameters
g (Graph) – a graph.
e (Graph) – an edge to remove.
- Return type
Graph
- Returns
a new graph without an edge.
- graph_without_node(g, n)
Build a new graph without a node.
- Parameters
g (Graph) – a graph.
n (Graph) – a node to remove.
- Return type
Graph
- Returns
a new graph without a node.
- graph_without_node_attr(g, n, a)
Build a new graph without a node attribute.
- Parameters
g (Graph) – a graph.
n (Graph) – a node.
a (Any) – a node attribute.
- Return type
Graph
- Returns
a new graph without a node attribute.
- graph_without_edge_attr(g, e, a)
Build a new graph without an edge attribute.
- Parameters
g (Graph) – a graph.
e (Graph) – an edge.
a (Any) – an edge attribute.
- Return type
Graph
- Returns
a new graph without an edge attribute.
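All four helpers describe the same copy-then-remove pattern: the input graph is left intact and a modified copy is returned. The dependency-free sketch below illustrates that pattern on a hypothetical dict-based graph (node ids mapped to attribute dicts, undirected edges keyed by their sorted endpoint pair) instead of networkx; with networkx the analogous steps would be `g.copy()` followed by `remove_edge`, `remove_node`, or a `del` on the node/edge attribute dict. The helper names are illustrative only, not the module's code.

```python
import copy
from typing import Any, Dict, Tuple

# Hypothetical stand-in for a graph: {"nodes": {id: attrs}, "edges": {(a, b): attrs}},
# with each undirected edge stored under its sorted (a, b) key.
ToyGraph = Dict[str, Dict[Any, Dict[str, Any]]]


def _key(e: Tuple[int, int]) -> Tuple[int, int]:
    a, b = sorted(e)
    return (a, b)


def without_edge(g: ToyGraph, e: Tuple[int, int]) -> ToyGraph:
    gp = copy.deepcopy(g)  # work on a copy so the input graph is never mutated
    del gp["edges"][_key(e)]
    return gp


def without_node(g: ToyGraph, n: int) -> ToyGraph:
    gp = copy.deepcopy(g)
    del gp["nodes"][n]
    # Edges incident to the removed node disappear with it.
    gp["edges"] = {k: v for k, v in gp["edges"].items() if n not in k}
    return gp


def without_node_attr(g: ToyGraph, n: int, a: str) -> ToyGraph:
    gp = copy.deepcopy(g)
    del gp["nodes"][n][a]
    return gp


def without_edge_attr(g: ToyGraph, e: Tuple[int, int], a: str) -> ToyGraph:
    gp = copy.deepcopy(g)
    del gp["edges"][_key(e)][a]
    return gp


g = {
    "nodes": {0: {"v": "C"}, 1: {"v": "O"}, 2: {"v": "N"}},
    "edges": {(0, 1): {"type": "double"}, (1, 2): {}},
}
print(without_node(g, 1)["edges"])                     # {}: both edges touched node 1
print(without_edge_attr(g, (1, 0), "type")["edges"])  # {(0, 1): {}, (1, 2): {}}
print(g["edges"][(0, 1)])                              # {'type': 'double'}: g is unchanged
```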
- class GraphActionType(value)
Bases: Enum
Action types for graph building environments.
- Stop = 1
- AddNode = 2
- AddEdge = 3
- SetNodeAttr = 4
- SetEdgeAttr = 5
- RemoveNode = 6
- RemoveEdge = 7
- RemoveNodeAttr = 8
- RemoveEdgeAttr = 9
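The member values listed above can be reproduced with Python's standard `enum` module, which also shows how a type is recovered from its integer value or its name. The sketch below redefines the enum locally so it runs without gt4sd, and adds a hypothetical `FORWARD_TO_BACKWARD` table pairing each constructive action with the removal that would undo it; that pairing is inferred from the member names, not taken from the module.

```python
from enum import Enum


class GraphActionType(Enum):
    # Values copied from the member list above.
    Stop = 1
    AddNode = 2
    AddEdge = 3
    SetNodeAttr = 4
    SetEdgeAttr = 5
    RemoveNode = 6
    RemoveEdge = 7
    RemoveNodeAttr = 8
    RemoveEdgeAttr = 9


# Hypothetical pairing of forward (constructive) actions with the backward
# actions that undo them; inferred from the names, not from gt4sd itself.
FORWARD_TO_BACKWARD = {
    GraphActionType.AddNode: GraphActionType.RemoveNode,
    GraphActionType.AddEdge: GraphActionType.RemoveEdge,
    GraphActionType.SetNodeAttr: GraphActionType.RemoveNodeAttr,
    GraphActionType.SetEdgeAttr: GraphActionType.RemoveEdgeAttr,
}

print(GraphActionType(3).name)                            # AddEdge (lookup by value)
print(GraphActionType["SetNodeAttr"].value)               # 4 (lookup by name)
print(FORWARD_TO_BACKWARD[GraphActionType.AddEdge].name)  # RemoveEdge
```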
- class GraphAction(action, source=None, target=None, value=None, attr=None, relabel=None)
Bases: object
Actions on a graph environment for gflownet.
- __init__(action, source=None, target=None, value=None, attr=None, relabel=None)
Initialize a single graph-building action.
- Parameters
action (GraphActionType) – the action type.
source (Optional[int]) – the source node this action is applied on.
target (Optional[int]) – the target node (i.e. if specified, this is an edge action).
value (Optional[Any]) – the value applied (e.g. the new node type).
attr (Optional[str]) – the node/edge attribute being set.
relabel (Optional[int]) – for AddNode actions, relabels the new node with that id.
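How the constructor arguments combine depends on the action type: per the parameter descriptions, `source` alone points at a node, `source` plus `target` points at an edge, and `value`/`attr` carry the payload. The sketch below uses a local dataclass that mirrors the documented signature (a stand-in, not the real class, with a trimmed copy of the enum) to show how a few typical actions might be written; the node type `"C"`, the `"charge"` and `"type"` attributes, and the `is_edge_action` helper are hypothetical.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class GraphActionType(Enum):
    # Trimmed copy of the enum documented above.
    Stop = 1
    AddNode = 2
    AddEdge = 3
    SetNodeAttr = 4
    SetEdgeAttr = 5


@dataclass
class GraphAction:
    # Local stand-in mirroring the documented constructor signature.
    action: GraphActionType
    source: Optional[int] = None
    target: Optional[int] = None
    value: Optional[Any] = None
    attr: Optional[str] = None
    relabel: Optional[int] = None

    def is_edge_action(self) -> bool:
        # Per the parameter docs, a specified target marks an edge action.
        return self.target is not None


add_node = GraphAction(GraphActionType.AddNode, source=0, value="C")   # attach a "C" node to node 0
add_edge = GraphAction(GraphActionType.AddEdge, source=0, target=2)    # connect nodes 0 and 2
set_attr = GraphAction(GraphActionType.SetNodeAttr, source=1, attr="charge", value=-1)
set_bond = GraphAction(GraphActionType.SetEdgeAttr, source=0, target=2, attr="type", value="double")
stop = GraphAction(GraphActionType.Stop)

for a in (add_node, add_edge, set_attr, set_bond, stop):
    print(a.action.name, a.is_edge_action())
```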
- class GraphBuildingEnv(allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True)
Bases: object
A graph building environment that induces a DAG state space compatible with gflownet.
- __init__(allow_add_edge=True, allow_node_attr=True, allow_edge_attr=True)
Initialize a GraphBuildingEnv.
Supports forward and backward actions, with a parents function that lists the parents of forward actions. Edges and nodes can have attributes added to them in a key:value style. Edges and nodes are created with _implicit_ default attribute values (e.g. chirality, single/double bondness) so that:
- an agent gets to do an extra action to set that attribute, but only if it is still default-valued (DAG property preserved);
- we can generate a legal action for any attribute that isn't a default one.
Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/envs
- Parameters
allow_add_edge (bool) – if True, allows this action and computes AddEdge parents (i.e. if False, this env only allows for tree generation).
allow_node_attr (bool) – if True, allows this action and computes SetNodeAttr parents.
allow_edge_attr (bool) – if True, allows this action and computes SetEdgeAttr parents.
- step(g, action)
Step forward the given graph state with an action.
- Parameters
g (Graph) – the graph to be modified.
action (GraphAction) – the action taken on the graph; indices must match.
- Returns
gp – the new graph.
- Return type
Graph
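To make the forward step concrete, the toy `step` below applies three forward action types to a deep copy of a dict-based graph and returns the new graph, enforcing the rule from the constructor docstring that an attribute can only be set while it is still default-valued. The `Action` record, the dict layout, the `"v"` node-type key and the AddNode attachment rule (the first node starts an empty graph; later nodes are linked to `source`) are assumptions loosely modelled on the upstream recursionpharma code linked above, not a description of gt4sd's implementation.

```python
import copy
from typing import Any, NamedTuple, Optional


class Action(NamedTuple):
    # Hypothetical, simplified action record; `kind` is a plain string here.
    kind: str
    source: Optional[int] = None
    target: Optional[int] = None
    attr: Optional[str] = None
    value: Any = None


def step(g: dict, action: Action) -> dict:
    """Return a new toy graph with `action` applied; `g` itself is not modified."""
    gp = copy.deepcopy(g)
    nodes, edges = gp["nodes"], gp["edges"]
    if action.kind == "AddNode":
        was_empty = not nodes
        if not was_empty and action.source not in nodes:
            raise ValueError("source must be an existing node")
        new = 0 if was_empty else max(nodes) + 1
        nodes[new] = {"v": action.value}
        if not was_empty:
            edges[(action.source, new)] = {}  # source < new, so the key is sorted
    elif action.kind == "AddEdge":
        key = tuple(sorted((action.source, action.target)))
        if key in edges:
            raise ValueError("edge already exists")
        edges[key] = {}
    elif action.kind == "SetNodeAttr":
        # DAG rule from the docstring: only a still-default (unset) attribute may be set.
        if action.attr in nodes[action.source]:
            raise ValueError("attribute already set")
        nodes[action.source][action.attr] = action.value
    else:
        raise ValueError(f"action not covered by this sketch: {action.kind}")
    return gp


g0 = {"nodes": {}, "edges": {}}
g1 = step(g0, Action("AddNode", value="C"))
g2 = step(g1, Action("AddNode", source=0, value="O"))
g3 = step(g2, Action("AddNode", source=1, value="C"))
g4 = step(g3, Action("AddEdge", source=2, target=0))
g5 = step(g4, Action("SetNodeAttr", source=1, attr="charge", value=-1))
print(sorted(g5["edges"]), g5["nodes"][1])  # [(0, 1), (0, 2), (1, 2)] {'v': 'O', 'charge': -1}
```

Returning a copy rather than mutating `g` keeps every intermediate state of a trajectory available, which is what (graph, action) pairs need.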
- count_backward_transitions(g)
Counts the number of parents of g without checking for isomorphisms.
- Return type
int
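Since isomorphisms are ignored, counting parents amounts to counting the elements whose removal is a legal backward step. The sketch below uses a hypothetical, simplified legality rule on a dict-based graph: a leaf node (degree 1) or a lone single node can be removed, an edge can be removed when both endpoints have degree greater than 1, and each explicitly set attribute can be unset. The module's exact rule may differ.

```python
from collections import Counter
from typing import Dict, Tuple

Edge = Tuple[int, int]


def count_backward_transitions(nodes: Dict[int, dict], edges: Dict[Edge, dict]) -> int:
    """Hypothetical parent count: one per removable node, edge or set attribute."""
    if len(nodes) == 1 and not edges:
        (attrs,) = nodes.values()
        return 1 + len(attrs)  # the lone node, plus each attribute it carries
    degree: Counter = Counter()
    for a, b in edges:
        degree[a] += 1
        degree[b] += 1
    count = 0
    for n, attrs in nodes.items():
        count += len(attrs)            # each set attribute could be unset
        if degree[n] == 1:
            count += 1                 # a leaf node could be removed
    for (a, b), attrs in edges.items():
        count += len(attrs)
        if degree[a] > 1 and degree[b] > 1:
            count += 1                 # removing it leaves neither endpoint isolated
    return count


# A triangle 0-1-2 with a pendant node 3 attached to node 2.
nodes = {0: {}, 1: {"charge": -1}, 2: {}, 3: {}}
edges = {(0, 1): {}, (1, 2): {}, (0, 2): {}, (2, 3): {}}
print(count_backward_transitions(nodes, edges))  # 5: leaf 3, three cycle edges, one attribute
```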
- generate_forward_trajectory(g, max_nodes=None)
Sample (uniformly) a trajectory that generates g.
- Parameters
g (Graph) – the graph to be generated.
max_nodes (Optional[int]) – the maximum number of nodes to generate.
- Return type
List[Tuple[Graph, GraphAction]]
- Returns
a list of (graph, action) pairs.
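One way to picture such a trajectory: pick a random start node, grow the graph one AddNode at a time along edges of g (each new node attached to an already-built neighbour), add the remaining edges with AddEdge, and finish with Stop. The sketch below does exactly that for a connected, attribute-free graph given as an edge list, recording each state before its action. It ignores `max_nodes` and attribute actions, and its randomisation is a simple illustration that makes no claim to reproduce the module's sampling distribution; `forward_trajectory` and the tuple-based states/actions are hypothetical.

```python
import random
from typing import FrozenSet, List, Optional, Tuple

Edge = Tuple[int, int]
State = Tuple[int, FrozenSet[Edge]]                # (number of nodes, edge set)
Action = Tuple[str, Optional[int], Optional[int]]  # (kind, source, target)


def forward_trajectory(edges: List[Edge], rng: random.Random) -> List[Tuple[State, Action]]:
    """Randomised (state, action) trajectory that rebuilds a connected, attribute-free graph."""
    nodes = sorted({x for e in edges for x in e})
    new_id = {}   # original node id -> id in the graph being built
    built = set()  # edges of the graph being built, as sorted new-id pairs
    traj: List[Tuple[State, Action]] = []

    def snapshot() -> State:
        return (len(new_id), frozenset(built))

    # First node: AddNode on the empty graph (no source).
    start = rng.choice(nodes)
    traj.append((snapshot(), ("AddNode", None, None)))
    new_id[start] = 0
    # Grow along edges that have exactly one endpoint already built.
    while len(new_id) < len(nodes):
        frontier = [(a, b) for a, b in edges if (a in new_id) != (b in new_id)]
        a, b = rng.choice(frontier)
        old, fresh = (a, b) if a in new_id else (b, a)
        traj.append((snapshot(), ("AddNode", new_id[old], None)))
        new_id[fresh] = len(new_id)
        built.add((new_id[old], new_id[fresh]))  # old id < fresh id, so already sorted
    # Close the remaining edges (cycles) in random order, then stop.
    for a, b in rng.sample(edges, len(edges)):
        e = tuple(sorted((new_id[a], new_id[b])))
        if e not in built:
            traj.append((snapshot(), ("AddEdge", e[0], e[1])))
            built.add(e)
    traj.append((snapshot(), ("Stop", None, None)))
    return traj


for state, action in forward_trajectory([(0, 1), (1, 2), (0, 2)], random.Random(0)):
    print(state[0], sorted(state[1]), action)
```

For a connected graph with N nodes and E edges this always yields N AddNode steps, E - N + 1 AddEdge steps and one Stop.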
- class GraphActionCategorical(graphs, logits, keys, types, deduplicate_edge_index=True)
Bases: object
A multi-type categorical action class compatible with generating structured actions on a graph.
- __init__(graphs, logits, keys, types, deduplicate_edge_index=True)
Initialize a GraphActionCategorical.
"Multi-type" here means that there are multiple types of mutually exclusive actions, e.g. AddNode and AddEdge are mutually exclusive. Since their logits are produced by different variable-sized tensors (corresponding to different elements of the graph, e.g. nodes or edges), it is inconvenient to stack them all into a single Categorical. This class provides a convenient interface between torch_geometric Batch objects and lists of logit tensors.
Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/envs
- Parameters
graphs (Batch) – a Batch of graphs to which the logits correspond.
logits (List[Tensor]) – a list of tensors of shape (n, m) representing logits over a variable number of graph elements (e.g. nodes) for which there are m possible actions. n should thus equal the total number of such elements across all graphs in the Batch object, and the logits list should contain one tensor per element type.
keys (List[Optional[str]]) – the keys corresponding to the Graph elements for each tensor in the logits list, used to extract the _batch and slice attributes. For example, if the first logit tensor is a per-node action logit and the second is a per-edge one, keys could be ['x', 'edge_index']. If keys[i] is None, the corresponding logits are assumed to be graph-level (i.e. if there are k graphs in the Batch object, this logit tensor has shape (k, m)).
types (List[GraphActionType]) – the action type each logit tensor corresponds to.
deduplicate_edge_index – if True, the 'edge_index' keys have been reduced by e_i[::2] (presumably because the graphs are undirected).
- to(device)
Move everything to the specified device.
- Parameters
device – the device to move to.
- logsoftmax()
Compute log-probabilities given logits.
- Return type
List[Tensor]
- Returns
a list of log-probabilities.
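Because the action types are mutually exclusive, normalisation has to run over every type at once, separately for each graph in the batch. The dependency-free sketch below spells that out with nested lists in place of tensors and a hand-written `batch` list per type standing in for torch_geometric's batch vectors (an assumption about layout for illustration); it shows the idea, not the class's tensor code.

```python
import math
from typing import List

Rows = List[List[float]]  # one list of logits per graph element (row)


def logsoftmax(logits: List[Rows], batch: List[List[int]], num_graphs: int) -> List[Rows]:
    """Log-softmax over all action types jointly, normalised separately per graph.

    logits[t][r] holds the type-t logits of element r; batch[t][r] is the index
    of the graph that element belongs to (mimicking torch_geometric batch vectors).
    """
    # Per-graph maximum, subtracted before exponentiating for numerical stability.
    maxl = [-math.inf] * num_graphs
    for t_logits, t_batch in zip(logits, batch):
        for row, g in zip(t_logits, t_batch):
            maxl[g] = max(maxl[g], max(row))
    # Per-graph partition function, summed over every type, row and column.
    sums = [0.0] * num_graphs
    for t_logits, t_batch in zip(logits, batch):
        for row, g in zip(t_logits, t_batch):
            sums[g] += sum(math.exp(x - maxl[g]) for x in row)
    log_z = [m + math.log(s) for m, s in zip(maxl, sums)]
    return [
        [[x - log_z[g] for x in row] for row, g in zip(t_logits, t_batch)]
        for t_logits, t_batch in zip(logits, batch)
    ]


# Two graphs. Type 0: a graph-level Stop logit (one row per graph).
# Type 1: AddNode logits over two hypothetical node types (one row per node);
# graph 0 has two nodes, graph 1 has one.
stop = [[0.0], [1.0]]
add_node = [[0.5, -1.0], [2.0, 0.0], [0.3, 0.3]]
batch = [[0, 1], [0, 0, 1]]
logp = logsoftmax([stop, add_node], batch, num_graphs=2)
for g in range(2):
    mass = sum(
        math.exp(x)
        for t_logp, t_batch in zip(logp, batch)
        for row, b in zip(t_logp, t_batch)
        if b == g
        for x in row
    )
    print(g, round(mass, 6))  # 1.0 for each graph
```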
- sample()
Use the Gumbel-max trick to sample categoricals.
If $X = \arg\max_i \left(\ell_i - \log(-\log U_i)\right)$, where $\ell$ are the logits and the $U_i \sim \mathrm{Uniform}(0, 1)$ are drawn independently, then $p(X = i) = \exp(\ell_i) / Z$ with $Z = \sum_j \exp(\ell_j)$. Here the argmax has to be taken first over the variable number of rows of each element type for each graph in the minibatch, and then over the different types (since they are mutually exclusive).
- Return type
List[Tuple[int, int, int]]
- Returns
A list of tuples specifying the action (type, row, column).
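The Gumbel-max trick can be checked empirically in a few lines. The sketch below treats a single graph, with nested lists indexed as `logits[type][row][column]` standing in for the per-type tensors, and takes one argmax over every perturbed logit of that graph, which selects the same index as the rows-then-types order described above; in a minibatch the procedure would run independently for each graph. `gumbel_sample` is an illustrative name, not part of the class.

```python
import math
import random
from typing import List, Tuple


def gumbel_sample(logits: List[List[List[float]]], rng: random.Random) -> Tuple[int, int, int]:
    """Sample one (type, row, column) index for a single graph via Gumbel-max."""
    best, best_idx = -math.inf, (-1, -1, -1)
    for t, rows in enumerate(logits):
        for r, row in enumerate(rows):
            for c, logit in enumerate(row):
                u = rng.random()
                # Perturb with Gumbel noise -log(-log(U)); U == 0 maps to -inf.
                score = logit - math.log(-math.log(u)) if u > 0 else -math.inf
                if score > best:
                    best, best_idx = score, (t, r, c)
    return best_idx


rng = random.Random(0)
# One graph, two action types: a single Stop logit and 2 nodes x 3 node types.
example = [[[0.2]], [[0.1, -0.5, 1.3], [0.0, 0.0, 0.4]]]
print(gumbel_sample(example, rng))  # a (type, row, column) tuple; varies with the seed
```

Repeating the draw many times and counting outcomes should approach `exp(logit) / Z` for each entry, which is a cheap sanity check of the trick.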
- log_prob(actions)
The log-probability of a list of action tuples.
- Parameters
actions (List[Tuple[int, int, int]]) – A list of action tuples (type, row, column).
- Return type
Tensor
- Returns
A tensor of shape (minibatch_size,) containing the log-probability of each action.
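For a single graph, the log-probability of an action tuple is the selected logit minus the log-partition over all types, rows and columns of that graph, i.e. the same normaliser used by logsoftmax. The sketch below assumes nested Python lists indexed as `logits[type][row][column]` in place of the per-type tensors and handles one graph and one action; the real method returns one value per graph in the minibatch.

```python
import math
from typing import List, Tuple


def log_prob(logits: List[List[List[float]]], action: Tuple[int, int, int]) -> float:
    """Log-probability of one (type, row, column) action for a single graph."""
    flat = [x for rows in logits for row in rows for x in row]
    m = max(flat)  # stabilise the log-sum-exp
    log_z = m + math.log(sum(math.exp(x - m) for x in flat))
    t, r, c = action
    return logits[t][r][c] - log_z


# Type 0: one Stop logit; type 1: one node with two possible values.
logits = [[[0.0]], [[1.0, 2.0]]]
print(log_prob(logits, (1, 0, 1)))  # 2 - log(1 + e + e**2)
```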
- class GraphBuildingEnvContext
Bases: object
A context environment that defines what the graphs are and how they map to and from data.
- device: str
- num_node_dim: int
- num_edge_dim: int
- num_cond_dim: int
- num_new_node_values: int
- num_node_attr_logits: int
- num_edge_attr_logits: int
- action_type_order: List[GraphActionType]
- aidx_to_graph_action(g, action_idx)
Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction.
- Parameters
g (Data) – the graph to which the action is being applied.
action_idx (Tuple[int, int, int]) – the tensor indices for the corresponding action.
- Return type
GraphAction
- Returns
a graph action that could be applied to the original graph corresponding to g.
- graph_action_to_aidx(g, action)
Translate a GraphAction to an action index (e.g. from a GraphActionCategorical).
- Parameters
g (Data) – the graph to which the action is being applied.
action (GraphAction) – a graph action that could be applied to the original graph corresponding to g.
- Return type
Tuple[int, int, int]
- Returns
the tensor indices for the corresponding action.
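The two translation methods are meant to be inverses of each other. The sketch below shows one hypothetical convention for a context supporting only Stop and AddNode: the first index selects an entry of `action_type_order`, the row selects the node (always 0 for the graph-level Stop), and the column indexes a hypothetical `node_values` vocabulary. Real contexts define their own layouts and also use the graph g (e.g. to map edge rows to endpoints); `node_values`, `to_action`, `to_aidx` and the simplified `Action` record are illustrative only.

```python
from enum import Enum
from typing import Any, List, NamedTuple, Optional, Tuple


class GraphActionType(Enum):
    # Trimmed copy of the documented enum.
    Stop = 1
    AddNode = 2


class Action(NamedTuple):
    # Simplified stand-in for GraphAction.
    action: GraphActionType
    source: Optional[int] = None
    value: Any = None


# Hypothetical context configuration.
action_type_order: List[GraphActionType] = [GraphActionType.Stop, GraphActionType.AddNode]
node_values = ["C", "N", "O"]  # column index -> new node type


def to_action(aidx: Tuple[int, int, int]) -> Action:
    type_idx, row, col = aidx
    t = action_type_order[type_idx]
    if t is GraphActionType.Stop:
        return Action(t)
    if t is GraphActionType.AddNode:
        return Action(t, source=row, value=node_values[col])
    raise NotImplementedError(t)


def to_aidx(action: Action) -> Tuple[int, int, int]:
    type_idx = action_type_order.index(action.action)
    if action.action is GraphActionType.Stop:
        return (type_idx, 0, 0)
    if action.action is GraphActionType.AddNode:
        return (type_idx, action.source, node_values.index(action.value))
    raise NotImplementedError(action.action)


a = to_action((1, 2, 1))
print(a.action.name, a.source, a.value)  # AddNode 2 N
print(to_aidx(a))                        # (1, 2, 1)
```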
- graph_to_mol(g)
Translate a graph to a molecule.
- Parameters
g (Graph) – the graph to translate.
- Return type
Mol
- Returns
the molecule corresponding to the graph.
- graph_to_data(g)
Convert a networkx Graph to a torch geometric Data instance.
- Parameters
g (Graph) – a graph instance.
- Return type
Data
- Returns
the corresponding torch_geometric graph data.
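torch_geometric stores edges as a 2 x num_edges `edge_index`, and undirected graphs are usually stored with both directions of every edge. The sketch below (plain lists instead of tensors; the one-hot `x` over a hypothetical `node_values` vocabulary is an assumption) builds such an index with the two directions of each edge interleaved, which is one layout under which the `e_i[::2]` reduction mentioned for `deduplicate_edge_index` keeps exactly one direction per edge.

```python
from typing import Dict, List, Tuple

node_values = ["C", "N", "O"]  # hypothetical node-type vocabulary


def graph_to_data(nodes: Dict[int, str], edges: List[Tuple[int, int]]):
    """Return (x, edge_index) with one-hot node features and both edge directions."""
    order = sorted(nodes)                          # node id -> row in x
    row_of = {n: i for i, n in enumerate(order)}
    x = [[1.0 if v == nodes[n] else 0.0 for v in node_values] for n in order]
    pairs = []
    for a, b in edges:
        # Interleave a->b and b->a so that every second column is one direction.
        pairs += [(row_of[a], row_of[b]), (row_of[b], row_of[a])]
    # Transpose the list of pairs into the 2 x num_edges layout.
    edge_index = [[p[0] for p in pairs], [p[1] for p in pairs]]
    return x, edge_index


x, ei = graph_to_data({0: "C", 1: "O", 2: "N"}, [(0, 1), (1, 2)])
print(ei)                        # [[0, 1, 1, 2], [1, 0, 2, 1]]
print([row[::2] for row in ei])  # [[0, 1], [1, 2]]: one direction per edge
```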
- collate(graphs)
Convert a list of torch geometric Data instances to a Batch instance. This exists so that environment contexts can set custom batching attributes, e.g. by using follow_batch.
- Parameters
graphs (List[Data]) – graph instances.
- Return type
Batch
- Returns
the corresponding torch_geometric batch.
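Batching in torch_geometric concatenates node features, offsets each graph's `edge_index` by the number of nodes in the preceding graphs, and records a `batch` vector mapping every node to its graph. The sketch below reproduces that bookkeeping with plain lists as a simplified model of `Batch.from_data_list` (it does not call torch_geometric, and `ToyData` is a hypothetical stand-in for Data); this per-element graph assignment is the kind of information GraphActionCategorical relies on when it splits logits by graph.

```python
from typing import List, Tuple

ToyData = Tuple[List[List[float]], List[List[int]]]  # (x, edge_index) as plain lists


def collate(graphs: List[ToyData]):
    """Concatenate graphs into one disconnected graph plus a node -> graph `batch` list."""
    xs, src, dst, batch = [], [], [], []
    offset = 0
    for gi, (x, (s, d)) in enumerate(graphs):
        xs.extend(x)
        src.extend(i + offset for i in s)  # shift node indices past earlier graphs
        dst.extend(j + offset for j in d)
        batch.extend([gi] * len(x))
        offset += len(x)
    return xs, [src, dst], batch


g1 = ([[1.0], [2.0]], [[0, 1], [1, 0]])
g2 = ([[3.0], [4.0], [5.0]], [[0, 1, 1, 2], [1, 0, 2, 1]])
xs, ei, batch = collate([g1, g2])
print(ei)     # [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(batch)  # [0, 0, 1, 1, 1]
```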
- is_sane(g)
Verify whether a graph is sane according to the context; this can catch, e.g., impossible molecules.
- Parameters
g (Graph) – a graph.
- Return type
bool
- Returns
True if the environment considers g to be sane.
- mol_to_graph(mol)
Translate an RDKit molecule to a graph.
- Parameters
mol (Mol) – an RDKit molecule.
- Return type
Graph
- Returns
the corresponding Graph representation of that molecule.