An MLIR-based compiler that fuses array expressions in Python and emits optimized RISC-V assembly targeting a custom neural processing unit (riscv-npu).
Write array math in Python. arrax traces the expression graph, fuses operations to eliminate intermediate allocations, and compiles directly to NPU instructions — producing firmware that runs on the riscv-npu emulator.
from arrax import Array, compile
def my_kernel(A: Array, B: Array, C: Array) -> Array:
return (A + B) * C
binary = compile(my_kernel, shapes={"A": (1024,), "B": (1024,), "C": (1024,)})NumPy would execute A + B (allocate temp), then temp * C (allocate another temp). arrax fuses this into a single pass — one NPU instruction sequence, zero intermediate buffers.
Python DSL (tracing)
│
▼
array dialect (fused expression DAGs on abstract arrays)
│
▼
linalg dialect (structured ops with indexing maps)
│ ── tiling pass (fit memory budget)
│ ── fusion pass (merge adjacent elementwise ops)
│
▼
linalg on memrefs (after bufferization)
│
▼
npu dialect (1:1 NPU instructions: fvmac, fvexp, ...)
│
▼
RISC-V assembly (.S with .insn directives)
│
▼
riscv32 toolchain → ELF → riscv-npu emulator
Built on xDSL, a Python-native MLIR framework.
| Category | Operations | Notes |
|---|---|---|
| Elementwise binary | add, sub, mul, div | Array-array and scalar-array |
| Elementwise unary | neg, relu, gelu, exp | Maps to NPU activation instructions |
| Reductions | sum, max | Full-array reductions |
| Dot product | dot(A, B) | 1D, maps to FVMAC + FRSTACC |
| Matrix multiply | A @ B | 2D, via tiled dot products |
| Composite | softmax(A), rmsnorm(A, gamma) | Multi-instruction lowering patterns |
All shapes are static (known at compile time).
Requires Python 3.10+ and uv.
git clone https://github.com/paladin8/arrax.git
cd arrax
uv syncThis installs runtime dependencies and the dev dependency group (pytest) by default.
End-to-end compilation requires the RISC-V cross-compiler:
# Ubuntu/Debian (the 64-bit toolchain targets rv32 with -march=rv32imf)
sudo apt install gcc-riscv64-unknown-elf
# Or build riscv-gnu-toolchain from source for rv32imfTo run compiled firmware, install the riscv-npu emulator:
cd ../riscv-npu
uv syncfrom arrax import Array, compile
# Define a kernel
def fused_relu(A: Array, B: Array) -> Array:
return (A + B).relu()
# Compile to firmware
binary = compile(fused_relu, shapes={"A": (1024,), "B": (1024,)})
# Run on emulator
binary.run()# Run tests
uv run pytest
# Run specific test file
uv run pytest tests/test_dsl.py -v
# Stop on first failure
uv run pytest -xsrc/arrax/
├── dsl/ Python tracing DSL (Array class, DAG capture)
├── dialects/ xDSL dialect definitions (array, npu)
├── lowering/ Compiler passes (array→linalg, tiling, fusion, bufferize, linalg→npu, npu→asm)
├── codegen/ Assembly emission, firmware harness, toolchain invocation
└── pipeline.py Full pass pipeline orchestration
tests/ pytest test suite
firmware/harness/ C template + linker script for riscv-npu firmware
examples/ Usage examples
See .ai/OVERALL_DESIGN.md for the full design document covering:
- Dialect definitions and lowering strategies
- Tiling and fusion pass design
- NPU instruction mapping
- Firmware integration
- LLVM backend extension (Phase 2)
- xDSL over C++ MLIR: Faster iteration, Python-native, 1:1 compatible with MLIR IR.
- No vector dialect: The NPU's memory-to-memory model (address + length) maps to linalg's destination-passing style, not to SIMD register files.
- Composite ops retained: Softmax and RMSNorm stay as single ops in the array dialect — multi-instruction lowering benefits from recognizing them as a unit rather than decomposing early.
- Assembly first:
.insndirectives are production-quality and serve as reference implementation. LLVM backend is Phase 2.