gt4sd.frameworks.gflownet.ml.models.graph_transformer module
Summary
Classes:
- GraphTransformer
- GraphTransformerGFN
Functions:
Reference
- class GraphTransformer(configuration, context, x_dim=64, e_dim=64, g_dim=64, num_emb=64, num_layers=3, num_heads=2)
 Bases: Module
 GraphTransformer.
- __init__(configuration, context, x_dim=64, e_dim=64, g_dim=64, num_emb=64, num_layers=3, num_heads=2)
 Construct GraphTransformer.
Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/models.
- Parameters
 configuration (Dict[str, Any]) – model configuration.
 context (GraphBuildingEnvContext) – context environment.
 x_dim – dimension of the input node features.
 e_dim – dimension of the input edge features.
 g_dim – dimension of the input graph features.
 num_emb – dimension of the embedding layer.
 num_layers – number of layers in the graph transformer.
 num_heads – number of attention heads in the graph transformer.
- forward(g, cond)
 Forward pass.
- Parameters
 g (Batch) – graph data.
 cond (Tensor) – conditioning.
- Return type
 Tuple[Tensor, Tensor]
- Returns
 node embeddings and pooled graph-level features.
- class GraphTransformerGFN(configuration, context, num_emb=64, num_layers=3, num_heads=2)
 Bases: Module
 GraphTransformerGFN.
- __init__(configuration, context, num_emb=64, num_layers=3, num_heads=2)
 Construct GraphTransformerGFN.
- Parameters
 configuration (Dict[str, Any]) – model configuration.
 context (GraphBuildingEnvContext) – context environment.
 num_emb – dimension of the embedding layer.
 num_layers – number of layers in the graph transformer.
 num_heads – number of attention heads in the graph transformer.
- forward(g, cond)
 Forward pass. Given a graph and a conditioning tensor, return the action logits and rewards.
- Parameters
 g (Batch) – graph data.
 cond (Tensor) – conditioning.
- Returns
 categorical action distribution and rewards.
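The categorical action output can be pictured as a single softmax over every action available in a graph. The sketch below assumes, for illustration only, that node actions are scored from node embeddings and edge actions from the sum of the two endpoint embeddings (a symmetric choice that makes an edge's score independent of its direction). The linear weights `w_node` and `w_edge` and the `action_distribution` helper are hypothetical stand-ins for the model's learned output heads, and the reward prediction, presumably a separate head over the pooled graph features, is left out.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def action_distribution(
    node_emb: Sequence[Vector],
    edges: Sequence[Tuple[int, int]],
    w_node: Sequence[float],
    w_edge: Sequence[float],
) -> Dict[Tuple[str, int], float]:
    """Softmax over node-level and edge-level actions of a single graph.

    Hypothetical sketch: w_node / w_edge are linear scoring heads standing in
    for the per-action output layers of the real model.
    """
    logits: Dict[Tuple[str, int], float] = {}
    for i, emb in enumerate(node_emb):
        logits[("node", i)] = dot(w_node, emb)
    for j, (u, v) in enumerate(edges):
        # Assumed edge embedding: sum of its endpoint node embeddings.
        edge_emb = [a + b for a, b in zip(node_emb[u], node_emb[v])]
        logits[("edge", j)] = dot(w_edge, edge_emb)
    # Numerically stable softmax across all actions of the graph.
    top = max(logits.values())
    exps = {k: math.exp(v - top) for k, v in logits.items()}
    total = sum(exps.values())
    return {k: v / total for k, v in exps.items()}
```

Subtracting the largest logit before exponentiating leaves the probabilities unchanged but avoids overflow when logits are large.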