Add JAX inference backend with multiple model architectures#2250

Draft
efsiatras wants to merge 9 commits into oumi-ai:main from efsiatras:add-jax-backend

Conversation

@efsiatras
Contributor

@efsiatras efsiatras commented Mar 8, 2026

Summary

Adds a JAX inference backend to Oumi, enabling high-performance LLM inference using JAX model implementations from the jax-llm-examples repository.

Models

| Architecture | Models | Features |
| --- | --- | --- |
| Llama 3 | 3.1 8B/70B/405B Instruct | GQA, RoPE, INT8 quantization |
| Llama 4 | Scout, Maverick | MoE, NoPE layers, chunked attention |
| DeepSeek R1 | R1 | Multi-head Latent Attention (MLA), MoE |
| Qwen 3 | 0.6B through 235B (dense + MoE) | GQA/MLA, MoE routing |
| Kimi K2 | K2 (1T params) | MLA + MoE |
| GPT-OSS | 20B, 120B | Sliding window attention, MoE |
| Nemotron 3 | Nano (8B) | Hybrid Mamba-Transformer, SSM layers |

What's included

Core integration:

  • JAXInferenceEngine extending BaseInferenceEngine: full prefill-then-decode inference loop with architecture-aware mesh creation, tensor parallelism, and INT8 quantization support
  • JAXModelManager: unified download, convert, and load pipeline using HuggingFace Hub + orbax checkpoints
  • Model registry with entries for each supported architecture
  • CLI interface (python -m oumi.models.experimental.jax_models [list|download|convert|run])
  • JAX utility functions for device management, tensor conversion, and checkpoint I/O
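The prefill-then-decode loop at the heart of the engine can be illustrated with a toy greedy generator (a simplified sketch: `toy_model`, its tiny vocab, and the list-based cache are illustrative stand-ins, not the actual Oumi/JAX implementation):

```python
import numpy as np

def toy_model(token, cache):
    """Stand-in for one transformer step: extends the KV cache with the new
    token and returns logits over a tiny vocabulary. Purely illustrative."""
    cache.append(token)
    vocab = 8
    logits = np.zeros(vocab)
    # Deterministic toy rule: next token = (sum of cached tokens) mod vocab.
    logits[sum(cache) % vocab] = 1.0
    return logits, cache

def generate(prompt_tokens, max_new_tokens):
    cache = []
    # Prefill: run the whole prompt through the model to populate the cache.
    for tok in prompt_tokens:
        logits, cache = toy_model(tok, cache)
    out = list(prompt_tokens)
    # Decode: feed back one token at a time, reusing the cache.
    for _ in range(max_new_tokens):
        next_tok = int(np.argmax(logits))  # greedy sampling
        out.append(next_tok)
        logits, cache = toy_model(next_tok, cache)
    return out
```

The real engine runs the prefill as one batched forward pass over a sharded mesh; the control flow, however, follows this same two-phase shape.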

Upstream code (from jax-llm-examples):

  • Model implementations for all 7 architectures (identical to upstream, reformatted only)
  • Custom TPU kernels (ragged attention, ragged dot for MoE)
  • Per-model download, conversion, and quantization scripts
  • Per-model unit tests with numerical validation
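The INT8 quantization applied by the per-model scripts is typically per-channel symmetric: store int8 weights plus one float scale per output row, and dequantize on the fly. A minimal numpy sketch (the scale layout is an assumption for illustration, not the upstream scripts' exact on-disk format):

```python
import numpy as np

def quantize_int8(w):
    """Per-channel symmetric INT8 quantization: one scale per output row.
    Assumes no all-zero rows (which would yield a zero scale)."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover an approximate float32 weight matrix."""
    return q.astype(np.float32) * scale
```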

Tests:

  • Unit tests for inference engine (architecture resolution, parameter validation)
  • Integration tests for all 7 model architectures (prefill/decode, cache init, quantization)
  • Verification script that tests all 7 models with random weights on CPU
  • End-to-end generation test script

Usage

from oumi.inference import JAXInferenceEngine
from oumi.core.configs import ModelParams, GenerationParams

engine = JAXInferenceEngine(
    model_params=ModelParams(model_name="path/to/jax/checkpoint"),
    generation_params=GenerationParams(max_new_tokens=64),
)

- Integrated JAX-based model support with model management, conversion utilities, and CLI tools
- Added support for Llama3/4, Qwen3, DeepSeek R1, and Kimi K2 architectures
- Implemented automatic HuggingFace-to-JAX conversion pipeline
- Added proper attribution headers for JAX Authors (jax-llm-examples) and Oumi
- Created unified CLI interface for model downloads and conversions
- Added registry and manager for JAX model platform
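At its core, the HuggingFace-to-JAX conversion step renames state-dict keys into a nested JAX parameter tree and transposes linear weights from PyTorch's (out, in) layout to the (in, out) layout JAX matmuls expect. A hedged sketch (the key names in `KEY_MAP` and the transpose rule are illustrative assumptions; the real converter differs per architecture):

```python
import numpy as np

# Hypothetical mapping from HF parameter names to JAX param-tree paths.
KEY_MAP = {
    "model.embed_tokens.weight": ("embed", "embedding"),
    "model.layers.0.self_attn.q_proj.weight": ("layer_0", "attn", "wq"),
}

def convert_hf_to_jax(hf_state_dict):
    """Rename keys and transpose 2-D projection weights (out,in) -> (in,out)."""
    params = {}
    for hf_key, value in hf_state_dict.items():
        path = KEY_MAP[hf_key]
        if hf_key.endswith("proj.weight") and value.ndim == 2:
            value = value.T  # PyTorch Linear stores (out, in)
        node = params
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = np.asarray(value)
    return params
```

The resulting nested dict is what would then be written out as an orbax checkpoint.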
