lockwo/awesome-jax

Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
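The following minimal sketch shows that combination. It assumes only the core `jax` package is installed, and the toy linear model and data are made up for illustration. `jax.numpy` supplies the NumPy-like array API, `jax.grad` differentiates an ordinary Python function, and `jax.jit` compiles the result with XLA:

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a toy linear model, written with the NumPy-like jax.numpy API
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad returns a function computing d(loss)/dw (gradient w.r.t. the first argument);
# jax.jit compiles that function with XLA for the available backend (CPU, GPU, or TPU).
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient of the MSE at w = 0: [-7., -10.]
print(g)
```

At `w = 0` the residuals are `-y`, so the gradient is `(2 / n) * x.T @ (-y)` with `n = 2`, which gives `[-7, -10]`. The same code runs on a GPU or TPU backend when one is available, without modification.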

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!

Be sure to check out our (experimental) interactive web version: https://lockwo.github.io/awesome-jax/.

Why do we need another "awesome-jax" list? The existing ones are inactive, and this list is directly based on the no-longer-active Awesome JAX repos https://github.com/n2cholas/awesome-jax/ and https://github.com/mhlr/awesome-jax.

Contents

Libraries

  • Neural Network Libraries

    • Flax - Flax is a neural network library for JAX that is designed for flexibility.
    • Equinox - Elegant easy-to-use neural networks + scientific computing in JAX.
  • Reinforcement Learning Libraries

    • JaxMARL - Multi-Agent Reinforcement Learning with JAX.
    • Algorithms
      • cleanrl - High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG).
      • rlax - a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents.
      • purejaxrl - Really Fast End-to-End Jax RL Implementations.
      • Mava - 🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX.
      • Stoix - 🏛️ A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX • End-to-End JAX RL.
    • Environments
      • pgx - Vectorized RL game environments in JAX.
      • jumanji - 🕹️ A diverse suite of scalable reinforcement learning environments in JAX.
      • gymnax - RL Environments in JAX 🌍.
      • brax - Massively parallel rigidbody physics simulation on accelerator hardware.
      • craftax - (Crafter + NetHack) in JAX. ICML 2024 Spotlight.
      • navix - Accelerated minigrid environments with JAX.
      • JaxGCRL - Goal-Conditioned Reinforcement Learning with JAX.
      • Kinetix - Reinforcement learning on general 2D physics environments in JAX. ICLR 2025 Oral.
      • XLand-MiniGrid - JAX-accelerated Meta-Reinforcement Learning Environments Inspired by XLand and MiniGrid 🏎️.
  • Natural Language Processing Libraries

    • levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax.
    • maxtext - A simple, performant and scalable Jax LLM!
    • EasyLM - Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
  • JAX Utilities Libraries

    • jaxtyping - Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.
    • chex - a library of utilities for helping to write reliable JAX code.
    • mpi4jax - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡.
    • jax-tqdm - Add a tqdm progress bar to your JAX scans and loops.
    • JAX-Toolbox - JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs.
    • penzai - A JAX research toolkit for building, editing, and visualizing neural networks.
    • orbax - Orbax provides common checkpointing and persistence utilities for JAX users.
  • Computer Vision Libraries

    • Scenic - A JAX library for computer vision research and beyond.
    • dm_pix - PIX is an image processing library in JAX, for JAX.
  • Distributions, Sampling, and Probabilistic Libraries

    • distreqx - Distrax, but in Equinox. Lightweight JAX library of probability distributions and bijectors.
    • distrax - a lightweight library of probability distributions and bijectors.
    • flowjax - Distributions, bijections and normalizing flows using Equinox and JAX.
    • blackjax - BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
    • bayex - Minimal Implementation of Bayesian Optimization in JAX.
    • efax - Exponential families for JAX.
    • jaxns - Probabilistic Programming and Nested sampling in JAX.
  • GPJax - Gaussian processes in JAX.
  • tinygp - The tiniest of Gaussian Process libraries.
  • Diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
  • probdiffeq - Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
  • jax-md - Differentiable, Hardware Accelerated, Molecular Dynamics.
  • lineax - Linear solvers in JAX and Equinox.
  • optimistix - Nonlinear optimisation (root-finding, least squares, etc.) in JAX+Equinox.
  • sympy2jax - Turn SymPy expressions into trainable JAX expressions.
  • quax - Multiple dispatch over abstract array types in JAX.
  • interpax - Interpolation and function approximation with JAX.
  • quadax - Numerical quadrature with JAX.
  • optax - Optax is a gradient processing and optimization library for JAX.
  • dynamax - State Space Models library in JAX.
  • dynamiqs - High-performance quantum systems simulation with JAX (GPU-accelerated & differentiable solvers).
  • scico - Scientific Computational Imaging COde.
  • exojax - 🐈 Automatic differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and Optax/JAXopt.
  • PGMax - Loopy belief propagation for factor graphs on discrete variables in JAX.
  • evosax - Evolution Strategies in JAX 🦎.
  • evojax - EvoJAX is a scalable, general-purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPUs/GPUs.
  • mctx - Monte Carlo tree search in JAX.
  • kfac-jax - Second-Order Optimization and Curvature Estimation with K-FAC in JAX.
  • jwave - A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs.
  • jax_cosmo - A differentiable cosmology library in JAX.
  • jaxlie - Rigid transforms + Lie groups in JAX.
  • ott - Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
  • XLB - Accelerated Lattice Boltzmann (XLB) for Physics-based ML.
  • EasyDeL - Accelerate and optimize performance with streamlined training and serving options in JAX.
  • QDax - Accelerated Quality-Diversity.
  • paxml - Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
  • econpizza - Solve nonlinear heterogeneous agent models.
  • fedjax - FedJAX is a JAX-based open-source library for Federated Learning simulations that emphasizes ease of use in research.
  • neural-tangents - Fast and Easy Infinite Neural Networks in Python.
  • jax-fem - Differentiable Finite Element Method with JAX.
  • veros - The versatile ocean simulator, in pure Python, powered by JAX.
  • JAXFLUIDS - Differentiable Fluid Dynamics Package.
  • klujax - Solve sparse linear systems in JAX using the KLU algorithm.
  • coreax - A library for coreset algorithms, written in JAX for fast execution and GPU support.
  • fdtdx - Electromagnetic FDTD Simulations in JAX.
  • Jaxley - Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU.
  • torch2jax - Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
  • cola - Compositional Linear Algebra.
  • laplax - Laplace approximations in JAX.
  • thrml - Thermodynamic Hypergraphical Model Library.
  • astronomix - Differentiable (magneto)hydrodynamics for astrophysics in JAX.
  • memax - Deep memory and sequence models in JAX.
  • JAXMg - A multi-GPU linear solver in JAX.
  • exponax - Efficient Differentiable n-d PDE Solvers in JAX.

Up and Coming Libraries

  • traceax - Stochastic trace estimation using JAX.
  • graphax - Cross-Country Elimination in JAX.
  • cd_dynamax - Extension of dynamax repo to cases with continuous-time dynamics with measurements sampled at possibly irregular discrete times. Allows generic inference of dynamical systems parameters from partial noisy observations via auto-differentiable filtering, SGD, and HMC.
  • jumpax - Jump Processes in JAX.
  • driftax - Drifting Generative Models - JAX/Flax implementation of Generative Modeling via Drifting.
  • ParamRF - Parametric Radio Frequency Modelling, Fitting and Sampling.

Inactive Libraries

  • Haiku - JAX-based neural network library.
  • jraph - A Graph Neural Network Library in Jax.
  • SymJAX - symbolic CPU/GPU/TPU programming.
  • coax - Modular framework for Reinforcement Learning in Python.
  • eqxvision - A Python package of computer vision models for the Equinox ecosystem.
  • jaxfit - GPU/TPU accelerated nonlinear least-squares curve fitting using JAX.
  • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗safetensors.
  • kernex - Stencil computations in JAX.
  • lorax - LoRA for arbitrary JAX models and functions.
  • mcx - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
  • einshape - DSL-based reshaping library for JAX and other frameworks.
  • sklearn-jax-kernels - Composable kernels for scikit-learn implemented in JAX.
  • deltapv - A photovoltaic simulator with automatic differentiation.
  • cr-sparse - Functional models and algorithms for sparse signal processing.
  • flaxvision - A selection of neural network models ported from torchvision for JAX & Flax.
  • imax - Image augmentation library for Jax.
  • jax-unirep - Reimplementation of the UniRep protein featurization model.
  • parallax - Immutable Torch Modules for JAX.
  • jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
  • elegy - A High Level API for Deep Learning in JAX.
  • objax - Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base.
  • jaxrl - JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.

Models and Projects

  • whisper-jax - JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
  • esm2quinox - An implementation of ESM2 in Equinox+JAX.

Tutorials and Blog Posts

Videos

Community
