PyTensor: Python library for defining, optimizing, and evaluating mathematical expressions with multi-dimensional arrays. Focus on hackable graph analysis and manipulation. Supports C, JAX, and Numba compilation backends. ~27MB, 492 Python files, Python support as per numpy NEP 29, uses NumPy, SciPy, pytest.
Graph manipulation in Python, graph evaluation out of Python. Emulate NumPy user-facing API as much as possible.
- Lazy evaluation: Expressions are symbolic until pytensor.function() compiles or .eval() evaluates
- Pure semantics: new_x = x[idx].set(y) instead of x[idx] = y
- Immutable/hashable: PyTensor variables are hashable. a == b tests identity (a is b), not elementwise equality.
- Static shapes: Broadcasting requires a static shape of 1. Valid: pt.add(pt.vector("x", shape=(1,)), pt.vector("y")). Invalid: pt.add(pt.vector("x", shape=(None,)), pt.vector("y")), even when x has length 1 at runtime.
- Static rank and type: PyTensor functions accept variables with a specific dtype and number of dimensions. Length of each dimension can be static or dynamic.
Uses pre-commit with ruff. Code should pass pre-commit before being committed.
Performance
- Code should be performant
- Avoid expensive work in hot loops
- Avoid redundant checks. Let errors raise naturally
- In contrast, silent errors should be prevented
Comments: Should be used sparingly, only for complex logic
Testing: Should be succinct
- Prefer tests.unittest_tools.assert_equal_computations over numerical evaluation
- Test multiple inputs on one compiled function vs multiple compilations
- Minimize test conditions. Be smart, not fearful
- Integrate with similar existing tests
- .github/ (workflows), doc/ (docs)
- pyproject.toml (config), setup.py (Cython build), conftest.py (pytest config), environment.yml (conda env)
- configdefaults.py: Config system (floatX, mode)
- gradient.py: Auto-differentiation
- compile/: Function compilation
- graph/: IR and optimization (graph/rewriting/)
- link/: Backends (c/, jax/, numba/, mlx/, pytorch/)
- tensor/: Tensor ops (largest module, subdirs: random/, rewriting/, conv/)
- scalar/: Scalar ops
- scan/: Loop operations (scan_perform.pyx Cython)
- sparse/: Sparse tensors
- xtensor/: Tensor ops with dimensions (lowers to Tensor ops)
The tests/ directory mirrors the source structure. unittest_tools.py has testing utilities.
python -m pytest tests/ # All tests
python -m pytest tests/test_updates.py -v # Single file
python -m pytest tests/ --runslow # Include slow tests

Tests are run with config.mode = "FAST_COMPILE". If needed:
- Cast numerical values: test_value.astype(symbolic_var.type.dtype)
- Use a custom function mode, get_default_mode().excluding("fusion"), or skip tests in FAST_COMPILE if they are not directly relevant to the mode.
Alternative backends (JAX, PyTorch, MLX) are optional. Use pytest.importorskip to fail gracefully.
python ./scripts/run_mypy.py --verbose

PyTensor is incompatible with strict mypy. Type hints are for users and developers, not to appease mypy. Liberal type: ignore[rule] comments and file exclusions are acceptable.
python -m sphinx -b html ./doc ./html # Build docs (2-3 min)

Never commit the html directory.
Use pytensor.dprint to inspect graphs. It works on both raw variables (before optimization) and compiled functions (after optimization):
pytensor.dprint(y, print_type=True) # Before optimization
pytensor.dprint(f, print_type=True, print_memory_map=True) # After optimization

print_type=True shows the type and shape of each variable. print_memory_map=True shows memory allocation labels, useful for spotting whether intermediates share memory.
Use rewrite_graph to apply rewrites to a graph without the full pytensor.function compilation:
from pytensor.graph.rewriting.utils import rewrite_graph
y_opt = rewrite_graph(y, include=("canonicalize", "specialize"))
pytensor.dprint(y_opt, print_type=True)

Use optimizer_verbose=True to see which rewrites are applied during compilation:
with pytensor.config.change_flags(optimizer_verbose=True):
    f = pytensor.function([x], y)

This prints each rewrite that fires, showing what it replaced and with what.
- test.yml: Main suite - Several Python versions, 7 test parts + backend jobs (jax, torch)
- mypy.yml: Type checking
CI runs tests under three modes: Default ("NUMBA"), "CVM", and "FAST_COMPILE". Tests must pass in all three.
- Numba scalar outputs: Numba-compiled scalar functions return Python float/int, not NumPy scalars. Keep this in mind when writing tests that check output types.
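One way to keep such tests independent of the container type is to compare by value. The helper below is hypothetical (it is not part of tests.unittest_tools) and relies only on NumPy.

```python
import numpy as np


def assert_scalar_output(result, expected, rtol=1e-7):
    """Hypothetical helper: compare a scalar result by value, whatever its type."""
    # np.asarray accepts Python floats/ints, NumPy scalars, and 0-d arrays alike.
    value = np.asarray(result)
    assert value.ndim == 0
    np.testing.assert_allclose(value, expected, rtol=rtol)


# The same check passes for each of the scalar representations a backend may return.
for result in (2.5, np.float64(2.5), np.array(2.5)):
    assert_scalar_output(result, 2.5)
```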