SalamanderXing/graph-transformer-jax

JAX Graph Transformer

This is a JAX/Flax port of Lucidrains' PyTorch Graph Transformer. Note that this implementation does not yet support positional embeddings.
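For orientation, the core idea of the PyTorch original is attention in which every ordered pair of nodes (i, j) carries an edge feature vector. That vector is projected and added to node j's key and value before node i attends to it. The sketch below is a minimal single-head version in plain JAX, written under that assumption. It is not taken from this repository, and the function `edge_aware_attention` and its weight arguments are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def edge_aware_attention(nodes, edges, mask, w_q, w_k, w_v, w_e):
    """Single-head attention where edge features are added to keys and values.

    Hypothetical helper (not this repo's API). Shapes:
      nodes: (n, d_node)   edges: (n, n, d_edge)   mask: (n,) bool
      w_q, w_k, w_v: (d_node, d)   w_e: (d_edge, d)
    """
    q = nodes @ w_q                       # (n, d): one query per node
    k = nodes @ w_k                       # (n, d): one key per node
    v = nodes @ w_v                       # (n, d): one value per node
    e = edges @ w_e                       # (n, n, d): projected features of edge (i, j)
    # Node i sees neighbour j through j's key/value shifted by edge (i, j).
    k_ij = k[None, :, :] + e              # (n, n, d)
    v_ij = v[None, :, :] + e              # (n, n, d)
    scores = jnp.einsum("id,ijd->ij", q, k_ij) / jnp.sqrt(q.shape[-1])
    # Real nodes only attend to real nodes; padded pairs get a huge negative score.
    # (Outputs for padded rows are finite but not meaningful.)
    pair_mask = mask[:, None] & mask[None, :]
    scores = jnp.where(pair_mask, scores, jnp.finfo(scores.dtype).min)
    attn = jax.nn.softmax(scores, axis=-1)       # each row sums to 1
    return jnp.einsum("ij,ijd->id", attn, v_ij)  # (n, d)


# Toy example sized like the README call: 9 nodes, 10 node and 10 edge features.
n, d_node, d_edge, d = 9, 10, 10, 16
k0, k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 6)
nodes = jax.random.normal(k0, (n, d_node))
edges = jax.random.normal(k1, (n, n, d_edge))
mask = jnp.arange(n) < 7                  # last two nodes are padding
w_q, w_k, w_v = (jax.random.normal(k, (d_node, d)) for k in (k2, k3, k4))
w_e = jax.random.normal(k5, (d_edge, d))

out = edge_aware_attention(nodes, edges, mask, w_q, w_k, w_v, w_e)
print(out.shape)  # (9, 16)
```

Because edge features enter both the keys and the values, an edge can change how strongly node i attends to node j as well as what information flows from j to i.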

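```python
from jax import random

from graph_transformer import GraphTransformer

# Build the model and initialize its parameters from a fixed PRNG key.
model, params = GraphTransformer.initialize(
    random.PRNGKey(0),
    number_of_nodes=9,    # nodes per graph
    num_layers=2,         # stacked graph transformer layers
    in_edge_features=10,  # size of each input edge feature vector
    in_node_features=10,  # size of each input node feature vector
)

print(model)
```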
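To make the `initialize` arguments concrete, the sketch below builds a random batch whose shapes line up with them. It assumes, without confirmation from this repository, that the port keeps the PyTorch original's input layout: node features of shape (batch, nodes, features), edge features of shape (batch, nodes, nodes, features), and a boolean node mask. `make_dummy_batch` is a hypothetical helper, and the sketch stops at building inputs because this README does not document the model's forward signature.

```python
import jax
import jax.numpy as jnp

# Assumed input layout, mirroring the PyTorch original (not confirmed for this port):
#   nodes: (batch, NUMBER_OF_NODES, IN_NODE_FEATURES)
#   edges: (batch, NUMBER_OF_NODES, NUMBER_OF_NODES, IN_EDGE_FEATURES)
#   mask:  (batch, NUMBER_OF_NODES), True where a node is real rather than padding
NUMBER_OF_NODES = 9
IN_NODE_FEATURES = 10
IN_EDGE_FEATURES = 10


def make_dummy_batch(key, batch_size, num_real_nodes):
    """Hypothetical helper: random features for graphs padded to NUMBER_OF_NODES."""
    node_key, edge_key = jax.random.split(key)
    nodes = jax.random.normal(node_key, (batch_size, NUMBER_OF_NODES, IN_NODE_FEATURES))
    edges = jax.random.normal(
        edge_key, (batch_size, NUMBER_OF_NODES, NUMBER_OF_NODES, IN_EDGE_FEATURES)
    )
    # num_real_nodes[b] is how many leading nodes of graph b are real.
    positions = jnp.arange(NUMBER_OF_NODES)
    mask = positions[None, :] < jnp.asarray(num_real_nodes)[:, None]
    # Zero the features of padded nodes and of every edge touching a padded node.
    nodes = nodes * mask[:, :, None]
    pair_mask = mask[:, :, None] & mask[:, None, :]
    edges = edges * pair_mask[..., None]
    return nodes, edges, mask


nodes, edges, mask = make_dummy_batch(jax.random.PRNGKey(1), 2, [9, 6])
print(nodes.shape, edges.shape, mask.shape)  # (2, 9, 10) (2, 9, 9, 10) (2, 9)
```

Graphs with fewer than `number_of_nodes` nodes are padded up to that size, and the mask records which positions hold real nodes.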