# om-backend

A custom `torch.compile` backend that routes PyTorch model execution through your ONNX hardware runtime (`OMRun`).
## Architecture

```
model.generate()  [Python loop — NOT compiled]
│
├── prefill:  model.forward(full_prompt) → torch.compile → ONNX → OMRun()
├── decode 1: model.forward(token)       → torch.compile → ONNX → OMRun()
├── decode 2: model.forward(token)       → torch.compile → ONNX → OMRun()
└── ...
```
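The split above, Python loop outside and compiled forward inside, can be demonstrated with a toy counting backend (`counting_backend` is illustrative, not part of this repo): the loop runs in plain Python on every step, while the backend is invoked only when a new graph is captured.

```python
import torch

compiles = {"count": 0}

def counting_backend(gm, example_inputs):
    # Called once per graph capture, not once per forward() call.
    compiles["count"] += 1
    return gm.forward  # fall back to eager execution of the captured subgraph

step = torch.compile(lambda x: x * 2 + 1, backend=counting_backend)

for _ in range(5):             # the "generate" loop stays in Python
    out = step(torch.ones(3))  # each call reuses the cached compiled graph
```

After the loop, `compiles["count"]` is 1: the graph was captured once and replayed from the cache for every subsequent same-shape call.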
## How it works

- TorchDynamo captures the forward pass as FX subgraphs.
- The OMRun backend exports each subgraph to `.onnx` and caches it.
- At runtime, input tensors are converted to NumPy arrays, passed to `OMRun()`, and the results are converted back to tensors.
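That capture-convert-run loop can be sketched as a minimal backend; `make_numpy_backend` and `stand_in_runtime` below are illustrative names (not this repo's API), with an eager FX call standing in for a real hardware runtime, and the `.onnx` export step elided:

```python
import torch

def make_numpy_backend(run_fn):
    """Build a torch.compile backend that hands each captured FX subgraph
    to run_fn, moving tensors through NumPy at the boundary."""
    def backend(gm, example_inputs):
        # A real backend would export gm to .onnx here and cache the file.
        def runner(*args):
            np_inputs = [a.detach().cpu().numpy() for a in args]
            np_outputs = run_fn(gm, np_inputs)
            return tuple(torch.from_numpy(o) for o in np_outputs)
        return runner
    return backend

def stand_in_runtime(gm, np_inputs):
    # Stand-in for OMRun(): just execute the captured graph with torch itself.
    outs = gm(*[torch.from_numpy(x) for x in np_inputs])
    return [o.detach().cpu().numpy() for o in outs]

fn = torch.compile(lambda x, y: torch.relu(x - y),
                   backend=make_numpy_backend(stand_in_runtime))
```

Calling `fn(x, y)` now matches eager `torch.relu(x - y)`, with every tensor making a round trip through NumPy at the subgraph boundary.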
## Quick start

```shell
pip install -r requirements.txt
```

```shell
# Default prompt, greedy decoding
python gpt2_generate.py

# Custom prompt
python gpt2_generate.py --prompt "Once upon a time"

# Sampling with temperature
python gpt2_generate.py --prompt "Hello world" --max-new-tokens 100 --do-sample --temperature 0.8
```

## Plugging in your hardware runtime

Edit the `OMRun` function in `gpt2_generate.py`:
```python
def OMRun(model_path: str, inputs: dict) -> list:
    """
    Args:
        model_path: Path to .onnx file
        inputs: Dict of {name: np.ndarray}

    Returns:
        List of np.ndarray (one per output)
    """
    return my_hardware_lib.run(model_path, inputs)
```

Then register the backend and compile the model:

```python
from omrun_backend import make_omrun_backend

make_omrun_backend(om_run_fn=OMRun, device="cpu")

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.forward = torch.compile(model.forward, backend="omrun")
model.generation_config.cache_implementation = "static"

output = model.generate(**inputs, max_new_tokens=50)
```

## Project layout

```
om-backend/
├── omrun_backend.py   # The torch.compile backend (reusable library)
├── gpt2_generate.py   # GPT-2 generation example
├── requirements.txt
└── README.md
```