A Graph Neural Ordinary Differential Equation (ODE) framework for modeling and predicting gene expression trajectories over time. RiTINI combines Graph Attention Networks (GATs) with Neural ODEs to capture the continuous temporal dynamics of gene regulatory networks.
RiTINI (Regulatory Temporal Interaction Network Inference) is designed to:
- Model temporal gene expression data using graph neural networks
- Learn attention-based gene regulatory networks from trajectory data
- Predict future gene expression states through continuous-time ODEs
- Visualize learned attention patterns and regulatory interactions
- Python >= 3.10
- PyTorch >= 2.8.0
- PyTorch Geometric >= 2.6.1
- Clone the repository:

```bash
git clone git@github.com:KrishnaswamyLab/RiTINI.git
cd RiTINI
```

- Install dependencies using uv (recommended):

```bash
uv venv
source .venv/bin/activate
uv sync
```

Or using pip:

```bash
pip install -e .
```

To run preprocessing and training sequentially:
```bash
python main.py
```

Each step can also be run as a standalone script, which is useful when iterating on a specific stage.

```bash
python preprocess.py
```

Preprocessing performs the following steps:
- Load raw trajectory data from a `.npy` file
- Filter genes using the interest genes list
- Compute a prior adjacency matrix (Granger causality by default)
- Average all trajectories into a single representative trajectory
- Normalize features with z-score normalization
- Save the preprocessed data to the `data/preprocessed/` directory
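The preprocessing steps above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the repository's `preprocess.py`: the helpers `rss`, `granger_f_matrix` and `preprocess` are hypothetical, the Granger test is a simple lag-1 F-test on each ordered gene pair, the prior is the average of the per-trajectory F-statistics (no thresholding), and z-scoring is applied per gene over time.

```python
import numpy as np

# Hypothetical sketch of the preprocessing pipeline; not the repository's code.


def rss(design, y):
    """Residual sum of squares of an ordinary least-squares fit."""
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    resid = y - design @ coef
    return float(resid @ resid)


def granger_f_matrix(X):
    """Lag-1 Granger F-statistics for one trajectory X of shape (n_timepoints, n_genes).

    F[j, i] measures how much gene j's past improves prediction of gene i
    beyond gene i's own past (edge j -> i). The diagonal stays 0.
    """
    T, G = X.shape
    F = np.zeros((G, G))
    ones = np.ones(T - 1)
    for i in range(G):
        y = X[1:, i]
        restricted = np.column_stack([ones, X[:-1, i]])  # intercept + own lag
        rss_r = rss(restricted, y)
        for j in range(G):
            if j == i:
                continue
            full = np.column_stack([restricted, X[:-1, j]])  # add lag of gene j
            rss_f = rss(full, y)
            dof = (T - 1) - full.shape[1]
            F[j, i] = (rss_r - rss_f) / (rss_f / dof)  # one restriction (lag 1)
    return F


def preprocess(traj, gene_names, interest_genes):
    """traj: (n_timepoints, n_trajectories, n_genes) -> (z-scored mean trajectory, prior, kept names)."""
    interest = set(interest_genes)
    keep = [k for k, name in enumerate(gene_names) if name in interest]
    traj = traj[:, :, keep]
    # Assumed prior: average the lag-1 Granger F-statistics over trajectories.
    prior = np.mean([granger_f_matrix(traj[:, k, :]) for k in range(traj.shape[1])], axis=0)
    mean_traj = traj.mean(axis=1)  # (n_timepoints, n_kept_genes)
    std = mean_traj.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant genes
    z = (mean_traj - mean_traj.mean(axis=0)) / std
    return z, prior, [gene_names[k] for k in keep]


# Toy data: g_driver drives g_target with a one-step lag; g_noise is filtered out.
rng = np.random.default_rng(0)
T, n_traj = 50, 3
driver = rng.normal(size=(T, n_traj))
target = np.zeros((T, n_traj))
target[1:] = 0.9 * driver[:-1] + 0.1 * rng.normal(size=(T - 1, n_traj))
noise = rng.normal(size=(T, n_traj))
raw = np.stack([driver, target, noise], axis=2)  # (T, n_traj, 3 genes)
z, prior, kept = preprocess(raw, ["g_driver", "g_target", "g_noise"], ["g_driver", "g_target"])
print(z.shape, prior.shape, kept)  # (50, 2) (2, 2) ['g_driver', 'g_target']
```

In this toy example the averaged prior scores the edge g_driver -> g_target far higher than the reverse direction.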
```bash
python train.py
```

Training performs the following steps:
- Load the preprocessed data
- Create a temporal graph dataset with sliding time windows
- Initialize and train the RiTINI model
- Apply graph regularization based on the prior network
- Save the best model based on total loss
- Store the training history with all loss components
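A minimal sketch of how sliding time windows and mini-batches could be built from the averaged trajectory. It assumes windows with a stride of one time step, that `time_window=None` yields a single window spanning all timepoints (as the configuration comment states), and that batches are consecutive groups of windows; `make_windows` and `iter_batches` are hypothetical names, not the repository's dataset class.

```python
import numpy as np

# Hypothetical sketch of sliding-window dataset construction; not the repository's code.


def make_windows(traj, time_window=5):
    """Split traj (n_timepoints, n_genes) into overlapping windows.

    Returns an array of shape (n_windows, window_length, n_genes).
    With time_window=None the whole trajectory is a single window.
    """
    n_timepoints = traj.shape[0]
    if time_window is None:
        return traj[np.newaxis]
    if time_window > n_timepoints:
        raise ValueError("time_window is longer than the trajectory")
    starts = range(n_timepoints - time_window + 1)  # assumed stride of 1 time step
    return np.stack([traj[s:s + time_window] for s in starts])


def iter_batches(windows, batch_size=4):
    """Yield consecutive batches of windows (the last batch may be smaller)."""
    for start in range(0, len(windows), batch_size):
        yield windows[start:start + batch_size]


traj = np.arange(30, dtype=float).reshape(10, 3)  # 10 timepoints, 3 genes
windows = make_windows(traj, time_window=5)
batches = list(iter_batches(windows, batch_size=4))
print(windows.shape, [b.shape[0] for b in batches])  # (6, 5, 3) [4, 2]
```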
To visualize the results, run:

```bash
python gene_inference_viz.py
python gene_trajectory_viz.py
```

The preprocessing script requires three input files:
- Trajectory data (`raw_trajectory_file`): `.npy` file containing gene expression trajectories with shape `(n_timepoints, n_trajectories, n_genes)`
- Gene names (`raw_gene_names_file`): `.txt` file with the names of all genes
- Interest genes (`interest_genes_file`): `.txt` file with the subset of genes to analyze
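To try the pipeline without real data, a small set of input files matching this layout can be generated. This sketch assumes one gene name per line in both `.txt` files and that the order of the gene names file matches the last axis of the trajectory array; the file names match those in the configuration, but `write_toy_inputs` and the random data are purely illustrative.

```python
import tempfile
from pathlib import Path

import numpy as np

# Hypothetical helper for generating toy inputs; the one-name-per-line format is an assumption.


def write_toy_inputs(out_dir, n_timepoints=10, n_trajectories=3, n_genes=6, n_interest=3, seed=0):
    """Write traj_data.npy, gene_names.txt and interest_genes.txt into out_dir."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(seed)
    traj = rng.normal(size=(n_timepoints, n_trajectories, n_genes))
    np.save(out / "traj_data.npy", traj)
    names = [f"gene_{k}" for k in range(n_genes)]  # same order as the last array axis
    (out / "gene_names.txt").write_text("\n".join(names) + "\n")
    (out / "interest_genes.txt").write_text("\n".join(names[:n_interest]) + "\n")
    return out


out_dir = write_toy_inputs(tempfile.mkdtemp())
traj = np.load(out_dir / "traj_data.npy")
gene_names = (out_dir / "gene_names.txt").read_text().split()
print(traj.shape, len(gene_names))  # (10, 3, 6) 6
```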
The input paths are set as follows:

```python
raw_trajectory_file = 'data/raw/traj_data.npy'
raw_gene_names_file = 'data/raw/gene_names.txt'
interest_genes_file = 'data/raw/interest_genes.txt'
```

To run the toy-data example:

```bash
python test_toy_data_ritini.py
```

RiTINI consists of three main components:
- GAT Convolutional Layer (`gatConvwithAttention.py`)
  - Multi-head attention mechanism
  - Learns edge weights between genes
  - Aggregates neighbor information
- Graph Differential Equation (`graphDifferentialEquation.py`)
  - Wraps the GAT layer as an ODE function
  - Computes derivatives for continuous-time evolution
- ODE Block (`ode.py`)
  - Integrates dynamics using Neural ODE solvers
  - Supports multiple integration methods (RK4, Dopri5, etc.)
  - Optional adjoint method for memory-efficient backpropagation
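How the three components fit together can be illustrated with a small NumPy sketch: a single-head graph attention layer on scalar node features serves as the derivative function, and a classic fourth-order Runge-Kutta loop integrates it. This is not the repository's implementation. The weights `w`, `a_src` and `a_dst` are fixed toy values rather than learned parameters, there is no dropout or multi-head attention, self-loops are added as in standard GAT layers, the solver is a plain fixed-step RK4 loop instead of a `torchdiffeq` call, and the adjacency convention `adj[j, i] = True` for an edge j -> i is an assumption.

```python
import numpy as np

# Hypothetical NumPy sketch of "GAT layer as ODE function + ODE solver"; not the repository's code.


def leaky_relu(x, negative_slope=0.2):
    return np.where(x > 0, x, negative_slope * x)


def gat_attention(h, adj, a_src, a_dst, negative_slope=0.2):
    """Single-head GAT attention for scalar node features h of shape (n,).

    adj[j, i] is True for an edge j -> i. Returns alpha of shape (n, n), where
    alpha[i, j] is the weight node i assigns to source j (rows sum to 1).
    """
    n = h.shape[0]
    mask = adj.T | np.eye(n, dtype=bool)  # incoming edges plus self-loops
    scores = leaky_relu(a_dst * h[:, None] + a_src * h[None, :], negative_slope)
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


def make_gat_ode(adj, w=1.5, a_src=0.5, a_dst=-0.3):
    """Return f(t, x) = tanh(sum_j alpha_ij * w * x_j): the attention layer wrapped as an ODE function."""
    def f(t, x):
        h = w * x
        alpha = gat_attention(h, adj, a_src, a_dst)
        return np.tanh(alpha @ h)
    return f


def rk4(f, x0, t):
    """Integrate dx/dt = f(t, x) on the time grid t with classic fixed-step RK4."""
    xs = [x0]
    x = x0
    for t0, t1 in zip(t[:-1], t[1:]):
        dt = t1 - t0
        k1 = f(t0, x)
        k2 = f(t0 + dt / 2, x + dt / 2 * k1)
        k3 = f(t0 + dt / 2, x + dt / 2 * k2)
        k4 = f(t1, x + dt * k3)
        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        xs.append(x)
    return np.stack(xs)


# Toy regulatory chain gene0 -> gene1 -> gene2 -> gene3.
adj = np.zeros((4, 4), dtype=bool)
adj[0, 1] = adj[1, 2] = adj[2, 3] = True
x0 = np.array([1.0, -0.5, 0.2, 0.0])
t = np.linspace(0.0, 1.0, 11)
traj = rk4(make_gat_ode(adj), x0, t)
print(traj.shape)  # (11, 4)
```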
Model details:
- Input features: 1 (gene expression value per node)
- Output features: 1 (predicted expression value)
- Architecture: temporal graph attention network
- Attention mechanism: multi-head attention with a configurable number of heads
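With one input feature per node, a single timepoint of expression data becomes a node feature matrix of shape `(n_genes, 1)`. The sketch below shows how such features and a dense prior adjacency could be converted into the coordinate (COO) `edge_index` layout used by PyTorch Geometric, with source nodes in row 0 and target nodes in row 1. It uses NumPy only; the helper `to_graph_inputs`, the score threshold, and leaving self-loops to the attention layer (as PyG's `GATConv` does by default) are assumptions, not the repository's graph construction.

```python
import numpy as np

# Hypothetical helper: dense prior scores -> node features + COO edge list.


def to_graph_inputs(expr_t, prior, threshold=0.0):
    """Build node features and a COO edge list for one timepoint.

    expr_t: (n_genes,) expression values at a single timepoint.
    prior:  (n_genes, n_genes) scores where prior[j, i] rates the edge j -> i.
    Returns x of shape (n_genes, 1) and edge_index of shape (2, n_edges).
    """
    x = expr_t.reshape(-1, 1).astype(np.float32)
    scores = prior.copy()
    np.fill_diagonal(scores, 0.0)  # assume the attention layer adds self-loops itself
    src, dst = np.nonzero(scores > threshold)
    edge_index = np.stack([src, dst]).astype(np.int64)
    return x, edge_index


prior = np.array([[0.0, 2.0, 0.0],
                  [0.0, 0.0, 0.5],
                  [0.1, 0.0, 0.0]])
x, edge_index = to_graph_inputs(np.array([0.3, -1.2, 0.8]), prior, threshold=0.2)
print(x.shape, edge_index.tolist())  # (3, 1) [[0, 1], [1, 2]]
```

Passing both arrays through `torch.from_numpy` would give tensors suitable for a PyTorch Geometric `Data(x=..., edge_index=...)` object.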
Attention layer parameters:

```python
import torch.nn as nn

n_heads = 1                  # Number of attention heads
feat_dropout = 0.1           # Feature dropout rate
attn_dropout = 0.1           # Attention dropout rate
activation_func = nn.Tanh()  # Activation function
residual = False             # Use residual connections
negative_slope = 0.2         # LeakyReLU negative slope
```

The RiTINI model uses Neural ODEs for continuous-time modeling:
```python
ode_method = 'rk4'   # ODE solver (rk4, dopri5, etc.)
atol = 1e-3          # Absolute tolerance (used by adaptive solvers such as dopri5)
rtol = 1e-4          # Relative tolerance (used by adaptive solvers such as dopri5)
use_adjoint = False  # Use the adjoint method for memory-efficient backpropagation
```

Training parameters:

```python
n_epochs = 200         # Number of training epochs
learning_rate = 0.001  # Initial learning rate
batch_size = 4         # Batch size
time_window = 5        # Temporal window length (None = all timepoints)
```

Graph regularization:

```python
graph_reg_weight = 0.1  # Weight for the graph regularization loss
```

Total loss = feature reconstruction loss + graph_reg_weight × graph regularization loss
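The total-loss formula can be written out as a short sketch. The source does not specify the two component losses, so this assumes mean squared error between predicted and observed expression for the reconstruction term and, hypothetically, mean squared error between a learned attention matrix and the prior adjacency for the graph regularization term. Returning the components separately mirrors the training history, which stores every loss component.

```python
import numpy as np

# Hypothetical loss sketch: both component losses are assumptions, only the weighted sum is from the docs.


def ritini_style_loss(pred, target, attention, prior_adj, graph_reg_weight=0.1):
    """Return (total, feature_loss, graph_loss).

    feature_loss: MSE between predicted and observed trajectories (assumed).
    graph_loss:   MSE between attention weights and the prior adjacency (hypothetical regularizer).
    """
    feature_loss = float(np.mean((pred - target) ** 2))
    graph_loss = float(np.mean((attention - prior_adj) ** 2))
    total = feature_loss + graph_reg_weight * graph_loss
    return total, feature_loss, graph_loss


pred = np.array([[0.0, 1.0], [1.0, 2.0]])
target = np.array([[0.0, 1.0], [1.0, 0.0]])
attention = np.full((2, 2), 0.5)
prior_adj = np.array([[0.0, 1.0], [1.0, 0.0]])
total, feat, graph = ritini_style_loss(pred, target, attention, prior_adj)
print(round(total, 6), feat, graph)  # 1.025 1.0 0.25
```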
Learning-rate scheduling:

```python
lr_factor = 0.5   # LR reduction factor
lr_patience = 10  # Epochs to wait before reducing LR
```

Key hyperparameters in training:
```python
import torch.nn as nn

# Data parameters
n_top_genes = 20    # Number of genes to model
time_window = 5     # Length of temporal sequences
batch_size = 4

# Model parameters
n_heads = 1         # GAT attention heads
feat_dropout = 0.1  # Feature dropout rate
attn_dropout = 0.1  # Attention dropout rate
activation_func = nn.Tanh()
residual = False    # Residual connections

# Training parameters
n_epochs = 200
learning_rate = 0.001
lr_factor = 0.5     # Scheduler reduction factor
lr_patience = 10    # Scheduler patience
```

Expected trajectory data format: `(n_timepoints, n_trajectories, n_genes)`
```python
from ritini.data.trajectory_loader import prepare_trajectories_data

data = prepare_trajectories_data(
    trajectory_file='data/trajectories/traj_data.pkl',
    prior_graph_file='data/trajectories/prior_graph.pkl',
    gene_names_file='data/trajectories/gene_names.txt',
    n_top_genes=20,
    use_mean_trajectory=True
)
```

Run the test suite:
```bash
# Test on toy data
pytest tests/test_toy_data_ritini.py

# Test on real data
pytest tests/test_real_data_gat.py

# Run all tests
pytest tests/
```

Core dependencies:
- `torch >= 2.8.0`: deep learning framework
- `torch-geometric >= 2.6.1`: graph neural network library
- `torchdiffeq >= 0.2.3`: Neural ODE solvers
- `networkx >= 3.0`: graph manipulation
- `numpy >= 1.24.0`: numerical computing
- `matplotlib >= 3.10.6`: plotting
- `seaborn >= 0.13.2`: statistical visualization
- `scanpy >= 1.11.4`: single-cell analysis
- `scikit-misc >= 0.5.1`: scientific computing utilities
See pyproject.toml for full dependency list.
If you use RiTINI in your research, please cite:
```bibtex
@misc{https://doi.org/10.48550/arxiv.2306.07803,
  doi = {10.48550/ARXIV.2306.07803},
  url = {https://arxiv.org/abs/2306.07803},
  author = {Bhaskar, Dhananjay and Magruder, Sumner and De Brouwer, Edward and Venkat, Aarthi and Wenkel, Frederik and Wolf, Guy and Krishnaswamy, Smita},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Inferring dynamic regulatory interaction graphs from time series data with perturbations},
  publisher = {arXiv},
  year = {2023},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

License: Yale License