GraPhens is a phenotype graph construction library and training pipeline. It builds graphs from HPO phenotype sets and supports dataset generation and model training on a Keras+JAX stack, with fixed-shape NPZ graph data consumed through Keras data sequences.
```python
from graphens import GraPhens

# Create a graph from phenotypes
graphens = GraPhens()
graph = graphens.create_graph_from_phenotypes(["HP:0001250", "HP:0001251"])  # Seizure phenotypes

# Export to a format you need
graph_json = graphens.export_graph(graph, format="json")
```

```python
# Search for phenotypes by keyword
seizure_phenotypes = graphens.phenotype_lookup("seizure")
for phenotype in seizure_phenotypes[:3]:
    print(f"{phenotype.id} - {phenotype.name}")
```

```python
# Chain methods for clear, readable configuration
graphens = (GraPhens()
            .with_embedding_model("openai", "text-embedding-3-small")
            .with_augmentation(include_ancestors=True)
            .with_visualization(enabled=True))

# Create and visualize your graph
graph = graphens.create_graph_from_phenotypes(phenotype_ids)
graphens.visualize(graph=graph, phenotypes=graph.metadata["phenotypes"])
```

```python
# Load biomedical embeddings directly from file
graphens = GraPhens().with_lookup_embeddings("data/embeddings/hpo_biobert.pkl")

# Create graph with domain-specific embeddings
graph = graphens.create_graph_from_phenotypes(phenotype_ids)
```

```python
# NetworkX
nx_graph = graphens.export_graph(graph, format="networkx")

# JSON
graphens.export_graph(graph, format="json", output_path="graph.json")
```

```python
patient_data = {
    "patient_1": ["HP:0001250", "HP:0002066"],
    "patient_2": ["HP:0000407", "HP:0001263"],
    # ... more patients
}

# Create graphs for multiple patients
patient_graphs = graphens.create_graphs_from_multiple_patients(patient_data)
```

```bash
pip install graphens
```

Current training dataset path:
Simulation JSON -> NPZ shards -> Keras/JAX loader
- Dataset builder: `src/simulation/phenotype_simulation/create_hpo_dataset.py` — two-pass process: collect `max_nodes`/`max_edges` statistics, then write padded NPZ shards.
- NPZ shard writer: `src/simulation/phenotype_simulation/jax_npz_writer.py` — stores fixed-shape arrays and masks: `x`, `node_mask`, `edge_index`, `edge_mask`, `y`.
- Dataset loaders: `training/datasets/jax_npz_graph_dataset.py`, `training/datasets/keras_npz_sequence.py`
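The two-pass pad-and-shard flow described above can be sketched with plain NumPy. This is a schematic illustration of the idea, not the repo's actual implementation; only the array names (`x`, `node_mask`, `edge_index`, `edge_mask`, `y`) come from the documented schema, everything else is a toy assumption:

```python
import numpy as np

# Toy graphs of varying size: node features (n, feat) and COO edge indices (2, e).
graphs = [
    {"x": np.ones((3, 4), dtype=np.float32),
     "edge_index": np.array([[0, 1], [1, 2]], dtype=np.int32).T, "y": 0},
    {"x": np.ones((5, 4), dtype=np.float32),
     "edge_index": np.array([[0, 2], [2, 3], [3, 4]], dtype=np.int32).T, "y": 1},
]

# Pass 1: scan all graphs to find the padding targets.
max_nodes = max(g["x"].shape[0] for g in graphs)
max_edges = max(g["edge_index"].shape[1] for g in graphs)

# Pass 2: pad every graph to (max_nodes, max_edges) and record validity masks.
def pad_graph(g):
    n, e = g["x"].shape[0], g["edge_index"].shape[1]
    x = np.zeros((max_nodes, g["x"].shape[1]), dtype=np.float32)
    x[:n] = g["x"]
    node_mask = np.zeros(max_nodes, dtype=bool)
    node_mask[:n] = True
    edge_index = np.zeros((2, max_edges), dtype=np.int32)
    edge_index[:, :e] = g["edge_index"]
    edge_mask = np.zeros(max_edges, dtype=bool)
    edge_mask[:e] = True
    return x, node_mask, edge_index, edge_mask

padded = [pad_graph(g) for g in graphs]
shard = {
    "x": np.stack([p[0] for p in padded]),           # (num_graphs, max_nodes, feat)
    "node_mask": np.stack([p[1] for p in padded]),   # (num_graphs, max_nodes)
    "edge_index": np.stack([p[2] for p in padded]),  # (num_graphs, 2, max_edges)
    "edge_mask": np.stack([p[3] for p in padded]),   # (num_graphs, max_edges)
    "y": np.array([g["y"] for g in graphs]),
}
# np.savez("shard_000.npz", **shard) would persist the padded shard to disk.
```

Because every graph shares the same padded shape after pass 2, all shards stack into dense arrays with no ragged dimensions.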
Example command:

```bash
python -m src.simulation.phenotype_simulation.create_hpo_dataset --input <simulated_json> --output-dir <dataset_dir> --shard-size 2048 --create-splits
```
Validate runtime stack:

```bash
KERAS_BACKEND=jax python scripts/validate_jax_stack.py
```
See docs/dataset_keras_jax.md for schema, workflow, and dependencies.
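On the loader side, once shards are fixed-shape, batching reduces to slicing the leading (graph) axis. A minimal NumPy-only illustration of that logic — the actual `Sequence` implementations live in the loader files listed above, and all sizes here are made-up toy values:

```python
import numpy as np

# An already-padded shard in the documented schema: 4 graphs, max_nodes=6, max_edges=8.
num_graphs, max_nodes, max_edges, feat = 4, 6, 8, 4
shard = {
    "x": np.zeros((num_graphs, max_nodes, feat), dtype=np.float32),
    "node_mask": np.ones((num_graphs, max_nodes), dtype=bool),
    "edge_index": np.zeros((num_graphs, 2, max_edges), dtype=np.int32),
    "edge_mask": np.ones((num_graphs, max_edges), dtype=bool),
    "y": np.arange(num_graphs, dtype=np.int32),
}

def iter_batches(shard, batch_size):
    """Yield fixed-shape batches by slicing the leading graph axis."""
    n = shard["y"].shape[0]
    for start in range(0, n, batch_size):
        yield {k: v[start:start + batch_size] for k, v in shard.items()}

batches = list(iter_batches(shard, batch_size=2))
# All batches share node/edge dimensions, so a compiled train step can be reused.
```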
- Graph samples are pre-sharded as fixed-shape NPZ tensors and consumed via Keras `Sequence` for `model.fit`.
- Batches use static padded tensors with explicit masks (`node_mask`, `edge_mask`) so JAX/XLA can compile stable programs.
- Set backend to JAX: `export KERAS_BACKEND=jax`.
- Use larger `--shard-size` and training `batch_size` when memory allows to improve device utilization.
- Keep graph tensors fixed-shape (already done via NPZ + masks) so XLA can compile stable TPU programs.
- Avoid frequent shape changes between runs; keep `max_nodes`, `max_edges`, and model config stable for better compile reuse.
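The masks matter because padded entries must not contribute to any aggregate. As an illustration only (not the repo's model code), here is a NumPy sketch of mask-aware mean pooling over padded nodes:

```python
import numpy as np

# Batch of 2 graphs padded to max_nodes=4, with 3 features per node.
x = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
node_mask = np.array([[1, 1, 0, 0],    # graph 0 has 2 real nodes
                      [1, 1, 1, 0]],   # graph 1 has 3 real nodes
                     dtype=np.float32)

# Zero out padded nodes, then divide by the true node count per graph.
masked_sum = (x * node_mask[:, :, None]).sum(axis=1)
counts = node_mask.sum(axis=1, keepdims=True)
graph_mean = masked_sum / counts  # (2, 3); padding never leaks into the result
```

The output shape is independent of how many real nodes each graph has, which is exactly what lets XLA compile one program and reuse it across batches.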
GraPhens prioritizes clear APIs, reproducible preprocessing, and explicit training data contracts across the graph construction and model training pipeline.
```python
# Examples in the repo show you everything you need
from examples import quick_start, custom_embeddings, visualization_demo

# Methods have helpful docstrings
help(GraPhens.create_graph_from_phenotypes)
```

```python
# Custom embedding models
graphens.with_embedding_model("tfidf", max_features=512)

# Save configurations for reproducibility
graphens.save_config("my_config.json")
loaded_graphens = GraPhens().with_config_from_file("my_config.json")
```

See the examples directory for end-to-end usage examples.