A reusable collection of high-performance neural network layers and models in JAX, aiming to match or exceed the capabilities available in the PyTorch ecosystem.
JAXgarden was created to provide the JAX ecosystem with a comprehensive library of well-documented, thoroughly tested, and numerically accurate implementations of neural network layers and models. The project aims to:
- Provide both functional APIs and Flax NNX wrappers for maximum flexibility
- Ensure seamless integration with the broader JAX ecosystem, especially Flax
- Facilitate easy upstreaming of implementations to core libraries
- Maintain rigorous testing and documentation standards
- Match or exceed the performance of equivalent PyTorch implementations
 
Started within the ML GDE group, the project began with a high-performance MultiHeadAttention implementation supporting multiple attention backends, with plans to expand to more layers and models.
- MultiHeadAttention: A Flax NNX-compatible implementation with support for different attention backends.
- Supports JAX's native Flash Attention implementation through cuDNN
- Seamlessly integrates with Flax NNX's module system
- Provides a simple interface for switching between attention implementations (see the sketch below)
 
 
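As a minimal sketch of backend switching (an assumption: that the module mirrors the `implementation` argument of `jax.nn.dot_product_attention`, where `"xla"` is the portable default):

```python
import flax.nnx as nnx
from jaxgarden.attention import MultiHeadAttention

# Same module, two backends: "xla" runs on any device, while "cudnn"
# dispatches to cuDNN's Flash Attention on supported NVIDIA GPUs.
attn_xla = MultiHeadAttention(
    num_heads=8, in_features=512, implementation="xla", rngs=nnx.Rngs(0)
)
attn_cudnn = MultiHeadAttention(
    num_heads=8, in_features=512, implementation="cudnn", rngs=nnx.Rngs(0)
)
```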
```bash
# Install from source
git clone https://github.com/ml-gde/jax-layers.git
cd jax-layers
pip install -e .
```

Generate text with a pretrained LLaMA model:
```python
from jaxgarden import LlamaConfig, LlamaForCausalLM, Tokenizer
from flax import nnx
# HF repo id of the LLaMA variant that you want to use
model_id = "meta-llama/Llama-3.2-1B"
# initialize the LLaMA architecture
config = LlamaConfig()
model = LlamaForCausalLM(config, rngs=nnx.Rngs(0))
# This one-liner downloads the checkpoint from the Hugging Face Hub,
# converts it to the jaxgarden format,
# saves it as an Orbax checkpoint,
# and then removes the downloaded HF checkpoint.
model.from_hf(model_id)
# This works just like `transformers.AutoTokenizer`,
# but without depending on the whole `transformers` library.
# Instead, we simply extend the `tokenizers` package with some convenience code for JAX.
tokenizer = Tokenizer.from_pretrained(model_id)
text = "The meaning of life is"
model_inputs = tokenizer.encode(text)
output = model.generate(**model_inputs, max_length=20, do_sample=True)
output_text = tokenizer.decode(output)
print(output_text)
```

Use the Flax NNX `MultiHeadAttention` module directly:
```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jaxgarden.attention import MultiHeadAttention
# Create a MultiHeadAttention module with Flash Attention support
attention = MultiHeadAttention(
    num_heads=8,
    in_features=512,
    implementation="cudnn",  # Use cuDNN's Flash Attention if available
    rngs=nnx.Rngs(0),
)
# Create input data
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 128, 512))  # (batch, seq_length, hidden_dim)
# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128)))  # (batch, 1, q_len, kv_len)
# Apply the model
output = attention(x, mask=mask)
```
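For intuition, the `jnp.tril` mask above is lower-triangular, so each query position attends only to itself and earlier positions. A tiny length-4 example:

```python
import jax.numpy as jnp

print(jnp.tril(jnp.ones((4, 4))))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```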
The Linen-based `RoPEMultiHeadAttention` module adds rotary position embeddings:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention
# 1. Setup
key = jax.random.PRNGKey(0)
batch_size, seq_len = 2, 16
num_heads, head_dim = 4, 32
embed_dim = num_heads * head_dim
x = jnp.ones((batch_size, seq_len, embed_dim))
# 2. Instantiate Module
attention = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)
# 3. Initialize Parameters
params = attention.init(key, x)['params']
# 4. Apply Module (Forward Pass)
output = attention.apply({'params': params}, x)
```
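For background, here is a rough sketch of what rotary position embeddings do under the hood. This is the standard RoPE computation, not necessarily jaxgarden's exact internals, and `apply_rope`/`base` are illustrative names:

```python
import jax.numpy as jnp

def apply_rope(x, base=10000.0):
    """Standard rotary embedding sketch. x: (batch, seq_len, num_heads, head_dim)."""
    b, s, h, d = x.shape
    pos = jnp.arange(s)[:, None]                          # (seq_len, 1)
    inv_freq = 1.0 / (base ** (jnp.arange(0, d, 2) / d))  # (head_dim / 2,)
    angles = pos * inv_freq                               # (seq_len, head_dim / 2)
    cos = jnp.cos(angles)[None, :, None, :]
    sin = jnp.sin(angles)[None, :, None, :]
    # Rotate each consecutive feature pair by a position-dependent angle, so
    # relative position shows up as a phase difference in the q.k dot product.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(b, s, h, d)
```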
The functional API exposes the same attention computation without a module wrapper:

```python
import jax
import jax.numpy as jnp
from jaxgarden.functional import dot_product_attention
# Create random query, key, and value tensors with independent PRNG keys
key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
query = jax.random.normal(q_key, (2, 128, 8, 64))  # (batch, seq_len, heads, head_dim)
key_tensor = jax.random.normal(k_key, (2, 128, 8, 64))
value = jax.random.normal(v_key, (2, 128, 8, 64))
# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128)))  # (batch, 1, q_len, kv_len)
# Apply dot product attention with Flash Attention implementation
output = dot_product_attention(
    query=query,
    key=key_tensor,
    value=value,
    mask=mask,
    implementation="cudnn",  # Use cuDNN's Flash Attention implementation
)
```
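For comparison, JAX itself ships `jax.nn.dot_product_attention`, which is presumably what the cuDNN backend mentioned above builds on. Assuming jaxgarden follows the built-in's shape conventions, the equivalent call with the stock function looks like this:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
query = jax.random.normal(q_key, (2, 128, 8, 64))  # (batch, seq_len, heads, head_dim)
k = jax.random.normal(k_key, (2, 128, 8, 64))
v = jax.random.normal(v_key, (2, 128, 8, 64))
mask = jnp.tril(jnp.ones((2, 1, 128, 128), dtype=bool))  # boolean (batch, 1, q_len, kv_len)

# implementation="cudnn" requires a compatible NVIDIA GPU; pass
# implementation=None to let JAX pick a backend automatically.
output = jax.nn.dot_product_attention(query, k, v, mask=mask, implementation="cudnn")
```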
To set up a development environment:

- Please fork the repository to your account first.
- Follow the instructions below.
 
```bash
# Clone the repository
git clone https://github.com/yourusername/jax-layers.git
cd jax-layers
# Install development dependencies
pip install -e ".[dev]"
```

This project uses pre-commit hooks to ensure code quality and consistency. Pre-commit automatically runs linting and formatting tools (such as ruff) before each commit, helping to catch issues early.
```bash
# Install Pre-commit Hooks
pre-commit install
# Run Pre-commit on All Files
pre-commit run --all-files
```

Every time you attempt to commit, pre-commit automatically runs the configured hooks (e.g., ruff). If any issues are detected, the commit is blocked until they are resolved.
The project maintains a comprehensive test suite to ensure correctness and numerical accuracy:
```bash
# Run all tests
pytest
# Run tests with coverage
pytest tests/ --cov=jaxgarden
# Run specific test file
pytest tests/test_multi_head_attention.py
```

We maintain high code quality standards through automated checks:
```bash
# Run linting
ruff check .
# Run type checking
mypy jaxgarden
# Run tests
pytest
```

Documentation is automatically generated from docstrings:
```bash
# Build documentation
cd docs
make html
```

Since JAX doesn't natively support CUDA on Windows, we provide a development container configuration:
- Install Docker Desktop with the WSL 2 backend
- Install the NVIDIA Container Toolkit
- Install Visual Studio Code with the Remote - Containers extension
- Open the project in VS Code
- Click the green icon in the bottom-left corner and select "Reopen in Container"
 
The container provides:
- Python 3.10
- CUDA 12.4 with cuDNN 9
- JAX with CUDA support
- All dependencies from the project's pyproject.toml
 
See .devcontainer/README.md for more details.
Contributions are more than welcome! Whether it's:
- Adding new layer implementations
- Improving documentation
- Adding tests
- Reporting bugs
- Suggesting improvements
 
Please feel free to open issues and pull requests.
This project is licensed under the MIT License - see the LICENSE file for details.
The Google AI Developer Programs team supported this work by providing Google Cloud credits.
- Thanks to the JAX and Flax teams for their excellent libraries.
- Special thanks to the ML GDE group for initiating this project.
 
