gt4sd.frameworks.gflownet.ml.models.graph_transformer module

Summary

Classes:

GraphTransformer

Graph transformer that encodes a batch of graphs into node embeddings and pooled graph-level features.

GraphTransformerGFN

GFlowNet model built on GraphTransformer that returns action logits and rewards.

Functions:

mlp

Build a fully-connected network (multilayer perceptron).

Reference

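mlp(n_in, n_hid, n_out, n_layer, act=torch.nn.LeakyReLU)

Build a fully-connected network (multilayer perceptron).

Parameters
  • n_in (int) – dimension of the input features.

  • n_hid (int) – dimension of the hidden layers.

  • n_out (int) – dimension of the output features.

  • n_layer (int) – number of hidden layers.

  • act – activation function class (default: torch.nn.LeakyReLU).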
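A minimal sketch of such a builder follows. It assumes, as in the upstream recursionpharma helper this module was adapted from, that the layer widths are n_in, then n_hid repeated n_layer times, then n_out, and that no activation follows the output layer (so n_layer=0 reduces to a single linear layer). It is an illustration, not the gt4sd source.

```python
import torch
from torch import nn


def mlp(n_in: int, n_hid: int, n_out: int, n_layer: int, act=nn.LeakyReLU) -> nn.Sequential:
    # Layer widths: n_in -> n_hid (repeated n_layer times) -> n_out.
    widths = [n_in] + [n_hid] * n_layer + [n_out]
    layers = []
    for i in range(n_layer + 1):
        layers.append(nn.Linear(widths[i], widths[i + 1]))
        # Assumption: no activation after the final (output) layer.
        if i < n_layer:
            layers.append(act())
    return nn.Sequential(*layers)


net = mlp(n_in=16, n_hid=64, n_out=8, n_layer=2)
out = net(torch.randn(5, 16))
print(len(net), tuple(out.shape))  # 5 (5, 8)
```

With n_layer=2 the network holds three linear layers separated by two activations, which is why len(net) is 5.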

class GraphTransformer(configuration, context, x_dim=64, e_dim=64, g_dim=64, num_emb=64, num_layers=3, num_heads=2)

Bases: Module

Graph transformer that encodes a batch of graphs into node embeddings and pooled graph-level features.

__init__(configuration, context, x_dim=64, e_dim=64, g_dim=64, num_emb=64, num_layers=3, num_heads=2)

Construct GraphTransformer.

Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/models.

Parameters
  • configuration (Dict[str, Any]) – model configuration.

  • context (GraphBuildingEnvContext) – context environment.

  • x_dim (int) – dimension of the input node features.

  • e_dim (int) – dimension of the input edge features.

  • g_dim (int) – dimension of the input graph (conditioning) features.

  • num_emb (int) – dimension of the embedding layer.

  • num_layers (int) – number of layers in the graph transformer.

  • num_heads (int) – number of attention heads in the graph transformer.
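The dimension parameters describe the raw feature sizes. Assuming, as in the upstream recursionpharma code linked above, that node, edge and graph features are each first projected to a shared num_emb-dimensional space by a small MLP, the projection stage can be sketched as follows. InputProjections is a hypothetical name, not a gt4sd class, and the two-hidden-layer MLPs are an assumption.

```python
import torch
from torch import nn


def mlp(n_in, n_hid, n_out, n_layer, act=nn.LeakyReLU):
    # Assumed MLP layout: no activation after the last linear layer.
    widths = [n_in] + [n_hid] * n_layer + [n_out]
    layers = []
    for i in range(n_layer + 1):
        layers.append(nn.Linear(widths[i], widths[i + 1]))
        if i < n_layer:
            layers.append(act())
    return nn.Sequential(*layers)


class InputProjections(nn.Module):
    """Hypothetical stage: map raw node, edge and graph features to num_emb dims."""

    def __init__(self, x_dim=64, e_dim=64, g_dim=64, num_emb=64):
        super().__init__()
        self.x2h = mlp(x_dim, num_emb, num_emb, 2)  # node features
        self.e2h = mlp(e_dim, num_emb, num_emb, 2)  # edge features
        self.c2h = mlp(g_dim, num_emb, num_emb, 2)  # per-graph conditioning

    def forward(self, x, edge_attr, cond):
        return self.x2h(x), self.e2h(edge_attr), self.c2h(cond)


proj = InputProjections(x_dim=12, e_dim=4, g_dim=7, num_emb=32)
h, e, c = proj(torch.randn(10, 12), torch.randn(18, 4), torch.randn(3, 7))
print(tuple(h.shape), tuple(e.shape), tuple(c.shape))  # (10, 32) (18, 32) (3, 32)
```

Projecting everything to the same width lets the subsequent num_layers transformer layers, each with num_heads attention heads, operate on a uniform embedding size.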

forward(g, cond)

Forward pass.

Parameters
  • g (Batch) – batch of graphs.

  • cond (Tensor) – conditioning tensor with one row of g_dim features per graph in the batch.

Return type

Tuple[Tensor, Tensor]

Returns

node embeddings and pooled graph-level features.
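The second return value summarizes each graph as a single vector. Assuming, roughly as in the upstream implementation, that node embeddings are mean-pooled per graph and concatenated with a per-graph conditioning embedding, the pooling step can be sketched with plain PyTorch (no torch_geometric). mean_pool is a hypothetical helper written for illustration.

```python
import torch


def mean_pool(node_emb: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    # batch[i] is the index of the graph that node i belongs to.
    dim = node_emb.shape[1]
    sums = torch.zeros(num_graphs, dim, dtype=node_emb.dtype).index_add_(0, batch, node_emb)
    ones = torch.ones(batch.shape[0], dtype=node_emb.dtype)
    counts = torch.zeros(num_graphs, dtype=node_emb.dtype).index_add_(0, batch, ones)
    return sums / counts.clamp(min=1).unsqueeze(1)  # avoid division by zero for empty graphs


node_emb = torch.arange(10.0).reshape(5, 2)  # 5 nodes, embedding size 2
batch = torch.tensor([0, 0, 1, 1, 1])        # nodes 0-1 -> graph 0, nodes 2-4 -> graph 1
cond_emb = torch.ones(2, 2)                  # one embedded conditioning row per graph

pooled = mean_pool(node_emb, batch, num_graphs=2)
graph_emb = torch.cat([pooled, cond_emb], dim=1)
print(pooled.tolist())          # [[1.0, 2.0], [6.0, 7.0]]
print(tuple(graph_emb.shape))   # (2, 4)
```

Under this assumption the pooled features are wider than the node embeddings, since the conditioning embedding is appended to the mean.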

class GraphTransformerGFN(configuration, context, num_emb=64, num_layers=3, num_heads=2)

Bases: Module

GFlowNet model built on GraphTransformer that returns action logits and rewards.

__init__(configuration, context, num_emb=64, num_layers=3, num_heads=2)

Construct GraphTransformerGFN.

Parameters
  • configuration (Dict[str, Any]) – model configuration.

  • context (GraphBuildingEnvContext) – context environment.

  • num_emb (int) – dimension of the embedding layer.

  • num_layers (int) – number of layers in the graph transformer.

  • num_heads (int) – number of attention heads in the graph transformer.

forward(g, cond)

Forward pass. Given a batch of graphs and a conditioning tensor, return the action logits and rewards.

Parameters
  • g (Batch) – batch of graphs.

  • cond (Tensor) – conditioning tensor with one row per graph in the batch.

Returns

categorical action distribution (action logits) and rewards.
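The action output is a categorical distribution over the edits available in each graph. The toy sketch below shows how per-node and per-graph logits can be combined into one distribution per graph, alongside a per-graph reward head. It assumes a simplified action space (one graph-level "stop" action plus num_node_actions actions per node) and linear heads; ToyActionHeads and this action layout are hypothetical and do not reproduce the richer action object returned by the actual model.

```python
import torch
from torch import nn


class ToyActionHeads(nn.Module):
    """Hypothetical heads: one graph-level stop logit, per-node logits, and a reward."""

    def __init__(self, node_dim: int, graph_dim: int, num_node_actions: int):
        super().__init__()
        self.stop = nn.Linear(graph_dim, 1)
        self.add_node = nn.Linear(node_dim, num_node_actions)
        self.reward = nn.Linear(graph_dim, 1)

    def forward(self, node_emb, graph_emb, batch):
        stop_logits = self.stop(graph_emb).squeeze(1)  # (num_graphs,)
        node_logits = self.add_node(node_emb)          # (num_nodes, num_node_actions)
        rewards = self.reward(graph_emb).squeeze(1)    # (num_graphs,)
        dists = []
        for gi in range(graph_emb.shape[0]):
            # Index 0 is "stop"; the rest are the flattened node actions of graph gi.
            logits = torch.cat([stop_logits[gi:gi + 1], node_logits[batch == gi].flatten()])
            dists.append(torch.distributions.Categorical(logits=logits))
        return dists, rewards


torch.manual_seed(0)
heads = ToyActionHeads(node_dim=8, graph_dim=16, num_node_actions=3)
batch = torch.tensor([0, 0, 1, 1, 1])
dists, rewards = heads(torch.randn(5, 8), torch.randn(2, 16), batch)
action = dists[1].sample()  # index into graph 1's 10 actions (0 = stop)
print([d.probs.shape[0] for d in dists], tuple(rewards.shape))  # [7, 10] (2,)
```

Building one distribution per graph keeps the softmax normalized within each graph, so graphs with more nodes simply expose more actions.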
