gt4sd.frameworks.gflownet.ml.models.graph_transformer module
Summary
Classes:
- GraphTransformer
- GraphTransformerGFN
Functions:
Reference
- class GraphTransformer(configuration, context, x_dim=64, e_dim=64, g_dim=64, num_emb=64, num_layers=3, num_heads=2)
Bases: Module
GraphTransformer.
- __init__(configuration, context, x_dim=64, e_dim=64, g_dim=64, num_emb=64, num_layers=3, num_heads=2)
Construct GraphTransformer.
Code adapted from: https://github.com/recursionpharma/gflownet/tree/trunk/src/gflownet/models.
- Parameters
configuration (Dict[str, Any]) – model configuration.
context (GraphBuildingEnvContext) – context environment.
x_dim – dimension of the input node features.
e_dim – dimension of the input edge features.
g_dim – dimension of the input graph features.
num_emb – dimension of the embedding layer.
num_layers – number of layers in the graph transformer.
num_heads – number of attention heads in the graph transformer.
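The `num_heads` parameter refers to multi-head attention inside the transformer layers. As a rough illustration of what a head does, here is a dependency-free sketch of one multi-head attention step restricted to graph neighborhoods. It is a generic textbook version, not the gt4sd implementation: the weight matrices, the `neighbors` mapping and all helper names are hypothetical, and edge features (`e_dim`) and graph conditioning (`g_dim`) are left out.

```python
import math
from typing import Dict, List, Sequence

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Multiply an (out x in) weight matrix by an input vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def softmax(scores: Sequence[float]) -> Vector:
    """Numerically stable softmax (subtracts the max before exponentiating)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_graph_attention(
    x: List[Vector],
    neighbors: Dict[int, List[int]],
    w_q: List[Matrix],
    w_k: List[Matrix],
    w_v: List[Matrix],
) -> List[Vector]:
    """One multi-head attention step over graph neighborhoods (illustrative sketch).

    x[i] is the embedding of node i; neighbors[i] lists the nodes that node i
    attends to (it must be non-empty, e.g. include i itself as a self-loop);
    w_q[h], w_k[h], w_v[h] are hypothetical projection matrices of head h.
    Per-head outputs are concatenated, so each result has num_heads * head_dim entries.
    """
    out: List[Vector] = []
    for i, x_i in enumerate(x):
        concat: Vector = []
        for h in range(len(w_q)):
            q = matvec(w_q[h], x_i)
            keys = [matvec(w_k[h], x[j]) for j in neighbors[i]]
            values = [matvec(w_v[h], x[j]) for j in neighbors[i]]
            # Scaled dot-product scores between node i and each of its neighbors.
            scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
            alpha = softmax(scores)
            # Attention-weighted sum of the neighbors' value vectors for head h.
            head_dim = len(values[0])
            concat.extend(sum(a * v[d] for a, v in zip(alpha, values)) for d in range(head_dim))
        out.append(concat)
    return out
```

Each head has its own query, key and value projections, and the per-head results are concatenated, so the output width grows with the number of heads; transformer implementations typically follow this with a linear projection back to the embedding size.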
- forward(g, cond)
Forward pass.
- Parameters
g (Batch) – graph data.
cond (Tensor) – conditioning tensor.
- Return type
Tuple[Tensor, Tensor]
- Returns
node embeddings and pooled graph-level features.
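A `Batch` packs several graphs into one set of nodes together with a per-node graph index, and graph-level features are obtained by pooling node embeddings within each graph. The plain-Python sketch below shows mean pooling as one way such pooled features can be computed; whether `forward` uses exactly this pooling is an assumption, and `global_mean_pool` here is a stand-in written for illustration (PyTorch Geometric ships a tensor version under the same name).

```python
from typing import List, Sequence


def global_mean_pool(
    node_emb: Sequence[Sequence[float]], batch: Sequence[int], num_graphs: int
) -> List[List[float]]:
    """Average node embeddings per graph (illustrative sketch).

    batch[i] is the index of the graph that node i belongs to, mirroring the
    layout of the per-node graph index stored in a batched graph.
    """
    dim = len(node_emb[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for emb, g in zip(node_emb, batch):
        counts[g] += 1
        for d, value in enumerate(emb):
            sums[g][d] += value
    # Graphs without nodes keep a zero vector instead of dividing by zero.
    return [[s / c if c else 0.0 for s in row] for row, c in zip(sums, counts)]
```

For two graphs with two and one nodes respectively, `global_mean_pool([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]], [0, 0, 1], 2)` averages the first two rows and keeps the third as is.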
- class GraphTransformerGFN(configuration, context, num_emb=64, num_layers=3, num_heads=2)
Bases: Module
GraphTransformerGFN.
- __init__(configuration, context, num_emb=64, num_layers=3, num_heads=2)
Construct GraphTransformerGFN.
- Parameters
configuration (Dict[str, Any]) – model configuration.
context (GraphBuildingEnvContext) – context environment.
num_emb – dimension of the embedding layer.
num_layers – number of layers in the graph transformer.
num_heads – number of attention heads in the graph transformer.
- forward(g, cond)
Forward pass. Given a batch of graphs and a conditioning tensor, return the action logits and the rewards.
- Parameters
g (Batch) – graph data.
cond (Tensor) – conditioning tensor.
- Returns
the categorical action distribution and the rewards.
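The categorical action groups logits of different action types, for example graph-level, node-level and edge-level actions. Below is a plain-Python sketch, for a single graph, of one way to turn such per-type logits into one categorical distribution: apply a single softmax over all of them jointly and sample a (type, index) pair. The action-type names and helper functions are hypothetical and do not reproduce gt4sd's own action classes.

```python
import math
import random
from typing import Dict, List, Tuple

Action = Tuple[str, int]  # (action type, index within that type); hypothetical encoding


def action_probabilities(logits: Dict[str, List[float]]) -> Dict[Action, float]:
    """Joint softmax over the logits of every action type of one graph."""
    flat = [((kind, i), value) for kind, row in logits.items() for i, value in enumerate(row)]
    m = max(value for _, value in flat)
    exps = [(action, math.exp(value - m)) for action, value in flat]
    total = sum(e for _, e in exps)
    return {action: e / total for action, e in exps}


def sample_action(logits: Dict[str, List[float]], rng: random.Random) -> Action:
    """Draw one (type, index) pair from the joint categorical distribution."""
    probs = action_probabilities(logits)
    actions = list(probs)
    return rng.choices(actions, weights=[probs[a] for a in actions], k=1)[0]


if __name__ == "__main__":
    # Hypothetical logits: one graph-level "stop", two node-level and one edge-level action.
    logits = {"stop": [0.0], "add_node": [1.0, -1.0], "add_edge": [0.5]}
    print(action_probabilities(logits))
    print(sample_action(logits, random.Random(0)))
```

Normalizing across all types at once lets a single sampling step choose both what kind of action to take and where to apply it.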