JAX

repo: n2cholas/awesome-jax
category: Computer Science


Awesome JAX <img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="JAX Logo" align="right" height="100">

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
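As a minimal illustration of that combination, the sketch below uses only core `jax.numpy`, `jax.grad`, and `jax.jit`. The least-squares loss and the input arrays are illustrative assumptions, not taken from any listed project. A function written with the NumPy-like API is differentiated, and the resulting gradient function is compiled with XLA:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model, written with the NumPy-like jax.numpy API.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function computing d(loss)/d(w), the first argument by default.
grad_loss = jax.grad(loss)

# jax.jit traces the function for the given input shapes/dtypes and compiles it with XLA.
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# Equals (2 / len(y)) * x.T @ (x @ w - y), i.e. [-7., -10.] for these inputs.
g = fast_grad_loss(w, x, y)
```

JAX places arrays on the default available backend, so the same function typically runs unchanged on CPU, GPU, or TPU.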

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!


<a name="libraries" />

Libraries

  • Neural Network Libraries
    • Flax - Centered on flexibility and clarity. <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
    • Flax NNX - An evolution of Flax by the same team. <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
    • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind. <img src="https://img.shields.io/github/stars/deepmind/dm-haiku?style=social" align="center">
    • Objax - Has an object-oriented design similar to PyTorch. <img src="https://img.shields.io/github/stars/google/objax?style=social" align="center">
    • Elegy - A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax. <img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center">
    • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads. <img src="https://img.shields.io/github/stars/google/trax?style=social" align="center">
    • Jraph - Lightweight graph neural network library. <img src="https://img.shields.io/github/stars/deepmind/jraph?style=social" align="center">
    • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width. <img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center">
    • HuggingFace Transformers - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax). <img src="https://img.shields.io/github/stars/huggingface/transformers?style=social" align="center">
    • Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center">
    • Scenic - A JAX library for computer vision research and beyond. <img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center">
    • Penzai - Prioritizes legibility, visualization, and easy editing of neural network models with composable tools and a simple mental model. <img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center">
  • Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX. <img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center">
  • EasyLM - LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax. <img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center">
  • NumPyro - Probabilistic programming based on the Pyro library. <img src="https://img.shields.io/github/stars/pyro-ppl/numpyro?style=social" align="center">
  • Chex - Utilities to write and test reliable JAX code. <img src="https://img.shields.io/github/stars/deepmind/chex?style=social" align="center">
  • Optax - Gradient processing and optimization library. <img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center">
  • RLax - Library for implementing reinforcement learning agents. <img src="https://img.shields.io/github/stars/deepmind/rlax?style=social" align="center">
  • JAX, M.D. - Accelerated, differentiable molecular dynamics. <img src="https://img.shields.io/github/stars/google/jax-md?style=social" align="center">
  • Coax - Turn RL papers into code, the easy way. <img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center">
  • Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors. <img src="https://img.shields.io/github/stars/deepmind/distrax?style=social" align="center">
  • cvxpylayers - Construct differentiable convex optimization layers. <img src="https://img.shields.io/github/stars/cvxgrp/cvxpylayers?style=social" align="center">
  • TensorLy - Tensor learning made simple. <img src="https://img.shields.io/github/stars/tensorly/tensorly?style=social" align="center">
  • NetKet - Machine Learning toolbox for Quantum Physics. <img src="https://img.shields.io/github/stars/netket/netket?style=social" align="center">
  • Fortuna - AWS library for Uncertainty Quantification in Deep Learning. <img src="https://img.shields.io/github/stars/awslabs/fortuna?style=social" align="center">
  • BlackJAX - Library of samplers for JAX. <img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center">
  • Dynamax - Probabilistic state space models. <img src="https://img.shields.io/github/stars/probml/dynamax?style=social" align="center">
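Several libraries above (for example Equinox, Haiku, Flax, and Optax) represent model parameters as PyTrees: nested Python containers such as dicts, lists, and tuples whose leaves are arrays. The sketch below uses plain JAX only, not any listed library's API. `jax.grad` returns gradients with the same tree structure as the parameters, and `jax.tree_util.tree_map` applies a vanilla SGD update leaf by leaf, roughly the kind of step an optimizer library such as Optax packages up. The parameter names `w` and `b`, the data, and the learning rate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative step size

# Parameters as a PyTree: a dict whose leaves are arrays (names are illustrative).
params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}

def loss(params, x, y):
    # Linear model on a dict of parameters; mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # grads is a dict with exactly the same keys and shapes as params.
    grads = jax.grad(loss)(params, x, y)
    # Apply p - lr * g to every leaf; the dict structure is preserved.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, -1.0])

new_params = sgd_step(params, x, y)
print(loss(params, x, y), loss(new_params, x, y))  # loss drops from 1.0 to about 0.81
```

Swapping the `tree_map` line for a library optimizer changes only the update rule; the PyTree structure of `params` stays the same.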

<a name="new-libraries" />

New Libraries

This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.

  • Neural Network Libraries
    • FedJAX - Federated learning in JAX, built on Optax and Haiku. <img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center">
    • Equivariant MLP - Construct equivariant neural network layers. <img src="https://img.shields.io/github/stars/mfinzi/equivariant-MLP?style=social" align="center">
    • jax-resnet - Implementations and checkpoints for ResNet variants in Flax. <img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center">
    • jax-raft - JAX/Flax port of the RAFT optical flow estimator. <img src="https://img.shields.io/github/stars/alebeck/jax-raft?style=social" align="center">
    • Parallax - Immutable Torch Modules for JAX. <img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center">
  • Nonlinear Optimization
    • Optimistix - Root finding, minimisation, fixed points, and least squares. <img src="https://img.shields.io/github/stars/patrick-kidger/optimistix?style=social" align="center">
    • JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. <img src="https://img.shields.io/github/stars/google/jaxopt?style=social" align="center">
  • jax-unirep - Library implementing the UniRep model for protein machine learning applications. <img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center">
  • flowjax - Distributions and normalizing flows built as equinox modules. <img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center">
  • flaxdiff - Framework and library for building and training diffusion models in multi-node, multi-device distributed settings (TPUs). <img src="https://img.shields.io/github/stars/AshishKumar4/FlaxDiff?style=social" align="center">
  • jax-flows - Normalizing flows in JAX. <img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center">
  • sklearn-jax-kernels - scikit-learn kernel matrices using JAX. <img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center">
  • jax-cosmo - Differentiable cosmology library. <img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center">
  • efax - Exponential Families in JAX. <img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center">
  • mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs. <img src="https://img.shields.io/github/stars/PhilipVinc/mpi4jax?style=social" align="center">
  • imax - Image augmentations and transformations. <img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center">
  • FlaxVision - Flax version of TorchVision. <img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center">
  • Oryx - Probabilistic programming language based on program transformations.
  • Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
  • delta PV - A photovoltaic simulator with automatic differentiation. <img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center">
  • jaxlie - Lie theory library for rigid body transformations and optimization. <img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center">
  • BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments. <img src="https://img.shields.io/github/stars/google/brax?style=social" align="center">
  • flaxmodels - Pretrained models for JAX/Flax. <img src="https://img.shields.io/github/stars/matthias-wright/flaxmodels?style=social" align="center">
  • CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing. <img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center">
  • exojax - Automatically differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX. <img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center">
  • PIX - PIX is an image processing library in JAX, for JAX. <img src="https://img.shields.io/github/stars/deepmind/dm_pix?style=social" align="center">
  • bayex - Bayesian Optimization powered by JAX. <img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center">
  • JaxDF - Framework for differentiable simulators with arbitrary discretizations. <img src="https://img.shields.io/github/stars/ucl-bug/jaxdf?style=social" align="center">
  • tree-math - Convert functions that operate on arrays into functions that operate on PyTrees. <img src="https://img.shields.io/github/stars/google/tree-math?style=social" align="center">
  • jax-models - Implementations of research papers that originally had no code, or whose code was written in frameworks other than JAX. <img src="https://img.shields.io/github/stars/DarshanDeshpande/jax-models?style=social" align="center">
  • PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX. <img src="https://img.shields.io/github/stars/vicariousinc/pgmax?style=social" align="center">
  • EvoJAX - Hardware-Accelerated Neuroevolution <img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center">
  • evosax - JAX-Based Evolution Strategies <img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center">
  • SymJAX - Symbolic CPU/GPU/TPU programming. <img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center">
  • mcx - Express & compile probabilistic programs for performant inference. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center">
  • Einshape - DSL-based reshaping library for JAX and other frameworks. <img src="https://img.shields.io/github/stars/deepmind/einshape?style=social" align="center">
  • ALX - Open-source library for distributed matrix factorization using Alternating Least Squares. More info in the paper "ALX: Large Scale Matrix Factorization on TPUs".
  • Diffrax - Numerical differential equation solvers in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center">
  • tinygp - The tiniest of Gaussian process libraries in JAX. <img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center">
  • gymnax - Reinforcement Learning Environments with the well-known gym API. <img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center">
  • Mctx - Monte Carlo tree search algorithms in native JAX. <img src="https://img.shields.io/github/stars/deepmind/mctx?style=social" align="center">
  • KFAC-JAX - Second Order Optimization with Approximate Curvature for NNs. <img src="https://img.shields.io/github/stars/deepmind/kfac-jax?style=social" align="center">
  • TF2JAX - Convert functions/graphs to JAX functions. <img src="https://img.shields.io/github/stars/deepmind/tf2jax?style=social" align="center">
  • jwave - A library for differentiable acoustic simulations <img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center">
  • GPJax - Gaussian processes in JAX.
  • Jumanji - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX. <img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center">
  • Eqxvision - Equinox version of Torchvision. <img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center">
  • JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper). <img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center">
  • econpizza - Solve macroeconomic models with heterogeneous agents using JAX. <img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center">
  • SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation). <img src="https://img.shields.io/github/stars/secretflow/spu?style=social" align="center">
  • jax-tqdm - Add a tqdm progress bar to JAX scans and loops. <img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center">
  • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗safetensors. <img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center">
  • Kernex - Differentiable stencil decorators in JAX. <img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center">
  • MaxText - A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs. <img src="https://img.shields.io/github/stars/google/maxtext?style=social" align="center">
  • Pax - A JAX-based machine learning framework for training large-scale models. <img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center">
  • Praxis - The layer library for Pax with a goal to be usable by other JAX-based ML projects. <img src="https://img.shields.io/github/stars/google/praxis?style=social" align="center">
  • purejaxrl - Vectorisable, end-to-end RL algorithms in JAX. <img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center">
  • Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.)
  • SCICO - Scientific computational imaging in JAX. <img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center">
  • Spyx - Spiking Neural Networks in JAX for machine learning on neuromorphic hardware. <img src="https://img.shields.io/github/stars/kmheckel/spyx?style=social" align="center">
  • Brain Dynamics Programming Ecosystem
    • BrainPy - Brain Dynamics Programming in Python. <img src="https://img.shields.io/github/stars/brainpy/BrainPy?style=social" align="center">
    • brainunit - Physical units and unit-aware mathematical system in JAX. <img src="https://img.shields.io/github/stars/chaobrain/brainunit?style=social" align="center">
    • dendritex - Dendritic Modeling in JAX. <img src="https://img.shields.io/github/stars/chaobrain/dendritex?style=social" align="center">
    • brainstate - State-based Transformation System for Program Compilation and Augmentation. <img src="https://img.shields.io/github/stars/chaobrain/brainstate?style=social" align="center">
    • braintaichi - Leveraging Taichi Lang to customize brain dynamics operators. <img src="https://img.shields.io/github/stars/chaobrain/braintaichi?style=social" align="center">
  • OTT-JAX - Optimal transport tools in JAX. <img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center">
  • QDax - Quality Diversity optimization in JAX. <img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center">
  • JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine. <img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center">
  • Pgx - Vectorized board game environments for RL with an AlphaZero example. <img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center">
  • EasyDeL - Open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX. <img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center">
  • XLB - A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning. <img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center">
  • dynamiqs - High-performance and differentiable simulations of quantum systems with JAX. <img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center">
  • foragax - Agent-Based modelling framework in JAX. <img src="https://img.shields.io/github/stars/i-m-iron-man/Foragax?style=social" align="center">
  • tmmax - Vectorized calculation of optical properties in thin-film structures using JAX; a Swiss Army knife for thin-film optics research. <img src="https://img.shields.io/github/stars/bahremsd/tmmax?style=social" align="center">
  • Coreax - Algorithms for finding coresets to compress large datasets while retaining their statistical properties. <img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center">
  • NAVIX - A reimplementation of MiniGrid, a Reinforcement Learning environment, in JAX <img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center">
  • FDTDX - Finite-Difference Time-Domain Electromagnetic Simulations in JAX <img src="https://img.shields.io/github/stars/ymahlau/fdtdx?style=social" align="center">
  • DiffeRT - Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem. <img src="https://img.shields.io/github/stars/jeertmans/DiffeRT?style=social" align="center">
  • JAX-in-Cell - Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields <img src="https://img.shields.io/github/stars/uwplasma/JAX-in-Cell?style=social" align="center">
  • kvax - A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism. <img src="https://img.shields.io/github/stars/nebius/kvax?style=social" align="center">
  • astronomix - Differentiable (magneto)hydrodynamics for astrophysics in JAX. <img src="https://img.shields.io/github/stars/leo1200/astronomix?style=social" align="center">
  • vivsim - Fluid-structure interaction simulations using Immersed Boundary-Lattice Boltzmann Method. <img src="https://img.shields.io/github/stars/haimingz/vivsim?style=social" align="center">
  • MBIRJAX - High-performance tomographic reconstruction. <img src="https://img.shields.io/github/stars/cabouman/mbirjax?style=social" align="center">
  • torchax - Library that lets JAX interoperate with model code written in PyTorch. <img src="https://img.shields.io/github/stars/google/torchax?style=social" align="center">
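Many entries above advertise vectorized or batchable computation (JAXopt, purejaxrl, Pgx, tmmax) or operate on JAX loops and scans (jax-tqdm). In core JAX, such code commonly builds on `jax.vmap`, which maps a per-example function over a batch axis, and `jax.lax.scan`, which expresses a sequential loop that XLA compiles once. The minimal sketch below uses only those two functions; the toy linear system and the `NUM_STEPS` constant are hypothetical and not taken from any listed library.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 4  # hypothetical rollout length; scan needs it as a static Python int

def step(state, _):
    # One step of a toy linear system: x_{t+1} = 0.5 * x_t + 1.0 (hypothetical dynamics).
    new_state = 0.5 * state + 1.0
    # Return (carry for the next step, per-step output that scan stacks).
    return new_state, new_state

def rollout(x0):
    # lax.scan runs `step` NUM_STEPS times as one compiled loop; xs=None means no per-step inputs.
    final, trajectory = jax.lax.scan(step, x0, xs=None, length=NUM_STEPS)
    return final, trajectory

# vmap turns the single-trajectory rollout into a batched one; jit compiles the result with XLA.
batched_rollout = jax.jit(jax.vmap(rollout))

x0s = jnp.array([0.0, 2.0, 4.0])  # three initial states
finals, trajectories = batched_rollout(x0s)
# trajectories has shape (3, NUM_STEPS); each trajectory moves toward the fixed point x = 2.
```

Writing the per-example logic once and adding the batch dimension with `vmap` keeps `step` and `rollout` free of batch-handling code.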

<a name="models-and-projects" />

Models and Projects

JAX

Flax

Haiku

Trax

  • Reformer - Implementation of the Reformer (efficient transformer) architecture.

NumPyro

Equinox

<a name="videos" />

Videos

<a name="papers" />

Papers

This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models and Projects section.

<a name="tutorials-and-blog-posts" />

Tutorials and Blog Posts

<a name="books" />

Books

<a name="community" />

Community

Contributing

Contributions welcome! Read the contribution guidelines first.
