Add Apple MPS (Metal) backend: native Metal kernels, autograd CPU bridges#1349
dredozubov wants to merge 10 commits into k2-fsa:master
Conversation
Code Review
This pull request introduces comprehensive Apple Metal Performance Shaders (MPS) support to the k2 library, enabling GPU acceleration on Apple Silicon. Key changes include updating CMake to detect and configure MPS, adding a new mps_utils.h for MPS-accelerated ATen operations, and implementing a PytorchMpsContext for managing MPS memory and safe data transfers. Core k2 functionalities like mutual_information_recursion, GetForwardScores, IntersectDense, and IntersectDensePruned are extended with MPS-native implementations or CPU fallbacks with gradient bridging. The MPS_NATIVE_METAL_PLAN.md document details the phased rollout of native Metal kernels for performance-critical operations, with a progress tracker indicating completion of several priority items. A review comment highlights a potential data race in array_ops_inl.h where the CPU path accesses MPS memory without explicit synchronization, which could lead to stale data. Another comment points out that the existing code comment in array.h regarding unified memory is an oversimplification and should clarify the reliance on PyTorch's implicit synchronization for MPS tensors.
…dges, 65-test suite

k2 had no MPS support: every operation fell through PyTorch's CPU fallback, making GPU-accelerated training impossible on Apple Silicon. This PR adds a complete MPS backend using two complementary strategies:

1. Native Metal kernels for hot paths: MSL compute shaders dispatched via PyTorch's existing MPS command buffer infrastructure, avoiding the K2_EVAL raw-pointer reads that crash on MTLStorageModePrivate buffers.
2. CPU bridges in the Python autograd layer: for topology-dependent algorithms that rely on K2_EVAL, transparently copy the FSA to CPU, run the existing C++ algorithm, and return a gradient-connected result to MPS via a lightweight _MpsScoresBridge autograd function.

Changes:

- mutual_information_mps.mm: native Metal MI kernel (antidiagonal wavefront, matching the CUDA blocked pattern); dispatch_once pipeline cache; nil-function guards to prevent NSInvalidArgumentException on Metal compile failure
- mps_fsa_scores.mm: native Metal GetForwardScores (BFS level-by-level dispatch, log and tropical semiring CAS loops with NaN guards); Hillis-Steele O(log N) associative scan for single-FSA tropical graphs with N in [4, 128] (reduces Metal encoder calls from N to ceil(log2 N))
- mps_utils.h: int32-native ExclusiveSumMps, RowSplitsToRowIdsMps, MaxSizeMps using PyTorch >= 2.2 MPS int32 ops; zero-copy arc-batch helper
- autograd.py: MPS CPU bridges for _GetForwardScoresFunction, _GetBackwardScoresFunction, _GetArcPostFunction, _IntersectDensePrunedFunction, _IntersectDenseFunction; stride-0 tensor fix in the forward-scores backward path; dtype guard (use out_fsa_grad.dtype, not fsa_grad_cpu.dtype) for fp16/bf16
- fsa.py: _MpsScoresBridge autograd function anchoring the gradient graph for get_tot_scores; NotImplementedError guards for MPS-unsupported paths (use_double_scores, non-differentiable _get_backward_scores/_get_tot_scores)
- CMakeLists.txt: K2_WITH_MPS flag, Metal and Foundation framework linkage; separate .mm from .cu before transform() to preserve the Objective-C++ extension
- test_mps.py: 65 MPS-specific tests covering MI, forward scores, tot scores, arc post, intersect dense, the associative scan, edge cases, and numerical stress
- .github/workflows/build-mps-macos.yml: CI on macos-14 (Apple Silicon), triggered on push, PR labeled ready, or workflow_dispatch
Force-pushed from b5fe638 to cd5a81d
Re the two Gemini review comments:
…opyright

- style_check.yml: bump Python 3.9 → 3.10 and actions to v4/v5 (3.9 is no longer available on GitHub Actions runners)
- array_ops_inl.h: add torch::mps::synchronize() before all four MPS→CPU raw-pointer reads (Append, ReverseCumSum min/max, Sum); matches the pattern already used in array_ops.cu
- array.h: clarify comment: MTLStorageModeShared makes data_ptr() valid on CPU, but synchronize() is still required to flush pending Metal writes
- autograd.py, fsa.py, test_mps.py, version.py: fix all flake8 violations (E501 line length, E226 arithmetic spacing, E302 blank lines, F541 f-string)
- Copyright year: update 2024 → 2026 in all new MPS source files
@csukuangfj @pkufool Can you guys take a look, please? I'd love to see it merged anytime soon.
…ne states ARC would release id<MTLComputePipelineState> members during static/heap destructor teardown, after PyTorch had already shut down the Metal device. Using __unsafe_unretained opts those fields out of ARC so no Obj-C message is sent to deallocated Metal objects on interpreter exit.
…ters

K2_EVAL on kMps runs the lambda on CPU using data_ptr() raw pointers. On MPS, calling data_ptr() on tensors that have pending Metal writes (e.g. from torch::rand_like) and writing back results races against the Metal command queue, causing undefined behaviour (segfault, exit code 139). Add #ifdef K2_WITH_MPS guards to SwooshForward, SwooshForwardAndDeriv, and SwooshFunction::forward/backward that dispatch to numerically equivalent ATen operations (torch::logaddexp, torch::sigmoid), which run natively on MPS without CPU pointer access:

- SwooshForward: logaddexp(0, x - shift) - coeff*x - offset
- SwooshForwardAndDeriv: same y, plus sigmoid(x - shift) - coeff for the derivative
- SwooshFunction::forward: ATen forward + save x (and the Bernoulli mask for dropout) for the custom backward
- SwooshFunction::backward: recompute the gradient from the saved x via sigmoid

The 8-bit quantised gradient trick is unavailable on MPS; standard autograd handles differentiation with negligible memory overhead at the typical activation sizes where MPS is used. Fixes the segfault reported by LuxTTS testers when calling swoosh_l_forward on MPS tensors.
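The closed-form pair above can be checked in plain Python: the derivative of `logaddexp(0, x - shift)` is `sigmoid(x - shift)`, so the analytic derivative should match a finite difference. The constants below (`shift=1.0`, `coeff=0.08`, `offset≈0.3133`) are illustrative assumptions in the style of icefall's Swoosh activations, not values quoted from this PR:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def swoosh(x, shift=1.0, coeff=0.08, offset=0.313261687):
    # Assumed constants for illustration; the PR uses ATen ops for the
    # same shape of formula: y = logaddexp(0, x - shift) - coeff*x - offset.
    t = x - shift
    # Numerically stable logaddexp(0, t) = max(t, 0) + log1p(exp(-|t|)).
    y = max(t, 0.0) + math.log1p(math.exp(-abs(t))) - coeff * x - offset
    # d/dx logaddexp(0, t) = sigmoid(t), hence dy/dx = sigmoid(t) - coeff.
    dy = sigmoid(t) - coeff
    return y, dy
```

A finite-difference check at any point confirms the derivative matches, which is the property the custom backward relies on.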
Motivation
k2 is used for FSA-based speech recognition (CTC, RNN-T, MMI), and Apple Silicon Macs have become the primary development machine for many researchers. The problem: k2 had no MPS support at all. Every operation fell through PyTorch's CPU fallback, so GPU-accelerated training on macOS simply didn't work.
Why this was hard
k2's C++ core uses the `K2_EVAL` macro extensively — it reads raw device pointers from the CPU side to walk arc data. On CUDA this works because pinned memory and unified addressing let the CPU see GPU memory. On MPS, `data_ptr()` on a tensor doesn't return a valid CPU pointer, so any path through `K2_EVAL` crashes immediately.
Two things fix this. For hot paths, native Metal kernels that receive `id` handles and never touch device memory from the CPU. For topology-dependent operations where native Metal isn't worth it yet, CPU bridges in the Python autograd layer: copy the FSA to CPU, run the existing C++ algorithm, return a gradient-connected result to MPS via a lightweight `_MpsScoresBridge` autograd function.
What changed
`mutual_information_mps.mm` — native Metal MI kernel
Replaces the CPU round-trip in `mutual_information_recursion` with a Metal compute kernel using the blocked antidiagonal wavefront pattern from the CUDA version. Forward (`mi_forward`) and backward (`mi_backward`) kernels are compiled at runtime from an embedded MSL string, cached via `dispatch_once`. Expected speedup: 5–8× over CPU bridging for typical (T=300, S=50) sequences.
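A rough pure-Python model of the recurrence and its wavefront ordering (the names `mi_forward`, `px`, `py` follow the usual MI formulation, not necessarily this file's internals): every cell on the antidiagonal s + t = d depends only on the previous diagonal, which is what allows one parallel dispatch per diagonal.

```python
import math

def logaddexp(a, b):
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def mi_forward(px, py, S, T):
    # p[s][t] = log-prob of reaching (s, t); px is (S, T+1), py is (S+1, T).
    p = [[-math.inf] * (T + 1) for _ in range(S + 1)]
    p[0][0] = 0.0
    # Antidiagonal wavefront: all cells with s + t == d are independent,
    # so each diagonal maps to one parallel kernel dispatch.
    for d in range(1, S + T + 1):
        for s in range(max(0, d - T), min(S, d) + 1):
            t = d - s
            a = p[s - 1][t] + px[s - 1][t] if s > 0 else -math.inf
            b = p[s][t - 1] + py[s][t - 1] if t > 0 else -math.inf
            p[s][t] = logaddexp(a, b)
    return p
```

With all-zero px/py, exp(p[S][T]) counts the monotone lattice paths from (0, 0) to (S, T), a handy sanity check for the recursion.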
CUDA→Metal primitive mapping:
`mps_utils.h` — ATen int32 ops
PyTorch MPS supports int32 `cumsum`, `diff`, `repeat_interleave`, and `searchsorted` natively since PyTorch 2.2. `ExclusiveSumMps`, `RowSplitsToRowIdsMps`, and `MaxSizeMps` are rewritten to use int32 throughout, removing redundant int64 casts and intermediate allocations.
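The two core mappings can be sketched in plain Python (helper names here are ours): an exclusive sum is an inclusive cumsum shifted right by one, and row_splits → row_ids is a repeat_interleave of row indices by diff(row_splits):

```python
import itertools

def exclusive_sum(xs):
    # exclusive_sum(xs)[i] = sum(xs[:i]); on MPS this is torch.cumsum in
    # int32 plus a one-element shift (assumed mapping, for illustration).
    inc = list(itertools.accumulate(xs))
    return [0] + inc[:-1]

def row_splits_to_row_ids(row_splits):
    # Equivalent to repeat_interleave(arange(num_rows), diff(row_splits)):
    # row i contributes (row_splits[i+1] - row_splits[i]) copies of i.
    row_ids = []
    for row, (lo, hi) in enumerate(zip(row_splits, row_splits[1:])):
        row_ids.extend([row] * (hi - lo))
    return row_ids
```

Both operations stay entirely in int32, which is the point of the rewrite: no int64 upcast, no intermediate allocations.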
`mps_fsa_scores.mm` — native Metal `GetForwardScores`
Level-by-level BFS dispatch: each BFS layer's arcs are independent (antecedent states are already scored), so one Metal command encoder per layer is dispatched onto PyTorch's existing MPS command buffer. Supports log semiring (CAS loop with `log_add`) and tropical semiring (CAS loop with `max`). NaN guards (`if (isnan(candidate)) return;`) prevent the CAS loops from spinning forever on invalid inputs — `NaN >= x` is always false in IEEE 754, so without the guard the loop never exits.
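A serial Python model of the level-by-level log-semiring update, with the NaN guard in the same position the Metal CAS loop needs it (function and variable names are illustrative, not k2's API):

```python
import math

def _log_add(a, b):
    # log(exp(a) + exp(b)) with -inf as the additive identity.
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def forward_scores_log(num_states, arcs, levels):
    # arcs: list of (src, dst, score); levels: arc indices grouped so every
    # arc in level k has its src state fully scored by earlier levels.
    scores = [-math.inf] * num_states
    scores[0] = 0.0
    for level in levels:
        for a in level:
            src, dst, w = arcs[a]
            cand = scores[src] + w
            # NaN guard: drop the candidate instead of updating. In the
            # Metal CAS loop, "cand >= current" is always false for NaN,
            # so without this check the loop would spin forever.
            if math.isnan(cand):
                continue
            scores[dst] = _log_add(scores[dst], cand)
    return scores
```

In the serial model the guard just skips the bad arc; in the CAS loop it is what guarantees termination.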
`fsa.py` — zero-copy arc-batch computation
A full `self.to('cpu')` copies 16 bytes × num_arcs plus all tensor attributes. The BFS-level sort only needs the `dest_state` column (4 bytes × num_arcs). `_compute_bfs_arc_batches_mps` extracts only that column, sorts arc indices on CPU via `torch.argsort`, and returns `sorted_arc_ids` as an MPS tensor — 4× less data crossing the bus.
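The column-extraction idea in miniature (pure Python; the real helper works on MPS tensors and `torch.argsort`, and the function name below is ours): only the dest_state column is inspected, and only the index permutation needs to travel back.

```python
def bfs_arc_order(dest_state_column):
    # Stable argsort of arc indices by destination state. Only this 4-byte
    # column crosses the device boundary; the sorted_arc_ids permutation
    # (not the arcs themselves) is what returns to the device.
    return sorted(range(len(dest_state_column)),
                  key=lambda i: dest_state_column[i])
```

With 16-byte arcs, shipping one int32 column instead of the whole arc array is where the quoted 4x reduction in bus traffic comes from.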
`autograd.py` — MPS bridges for `IntersectDense` and `IntersectDensePruned`
`_IntersectDensePrunedFunction` and `_IntersectDenseFunction` run the forward pass on CPU (the intersection algorithm is topology-dependent and relies on `K2_EVAL`). Gradients are scattered via `_k2.index_add` on CPU and moved back to MPS. `_MpsScoresBridge` keeps the gradient graph intact through the device transfer.
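A minimal sketch of what a device-bridging autograd function can look like. This is our simplification, not the actual `_MpsScoresBridge`: it assumes the CPU algorithm already produced the result and that the incoming gradient maps one-to-one onto the device-side scores (real bridges may need a scatter, as with `_k2.index_add` above).

```python
import torch

class ScoresBridge(torch.autograd.Function):
    """Illustrative device bridge: return a CPU-computed result on the
    original device while keeping the autograd edge to device_scores."""

    @staticmethod
    def forward(ctx, device_scores, cpu_result):
        # cpu_result came from a non-differentiable CPU algorithm; detach it
        # and move it back. The graph edge to device_scores (a Function
        # input) is what keeps gradients flowing across the transfer.
        return cpu_result.detach().to(device_scores.device)

    @staticmethod
    def backward(ctx, grad_out):
        # Route the incoming gradient straight back to device_scores;
        # cpu_result gets no gradient.
        return grad_out, None
```

Usage: `ScoresBridge.apply(fsa.scores, cpu_tot_scores)` returns a tensor on the FSA's device whose `backward` deposits gradients into `fsa.scores`.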
`mps_fsa_scores.mm` — Hillis-Steele O(log N) associative scan
For single-FSA, tropical-semiring graphs with 4 ≤ N ≤ 128 states, the sequential N-step BFS dispatch is replaced by a Hillis-Steele prefix scan over N per-state transition matrices:
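An illustrative pure-Python version of that scan under the (max, +) semiring; each round of the while loop corresponds to one Metal encoder dispatch (this is the assumed structure of the algorithm, not the shipped MSL):

```python
import math

NEG_INF = -math.inf

def trop_matmul(A, B):
    # Tropical (max, +) matrix product: (A*B)[i][j] = max_k A[i][k] + B[k][j].
    n = len(A)
    return [[max(A[i][k] + B[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def hillis_steele_scan(mats):
    # Inclusive prefix products of per-state transition matrices in
    # ceil(log2 N) rounds instead of N sequential BFS steps. Each round
    # combines every element with the one `offset` positions earlier.
    out = list(mats)
    n = len(out)
    offset = 1
    while offset < n:
        new = list(out)
        for i in range(offset, n):
            new[i] = trop_matmul(out[i - offset], out[i])
        out = new
        offset *= 2
    return out
```

Because tropical matrix product is associative (though not commutative), the left-to-right combination order in the scan reproduces the sequential chain exactly.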
Additional autograd bridges
Auditing `K2_EVAL` call paths turned up three more autograd functions that would crash on MPS:
Build system
CI (`.github/workflows/build-mps-macos.yml`)
Runs on `macos-14` (Apple Silicon, MPS hardware available). Triggers on push to master (path-filtered), PR labeled `ready`, or `workflow_dispatch`. Steps:
Test coverage (65 tests, all passing)
Known limitations
Test plan