Add Apple MPS (Metal) backend: native Metal kernels, autograd CPU bridges#1349

Open
dredozubov wants to merge 10 commits into k2-fsa:master from dredozubov:feature/mps-backend
Conversation


@dredozubov dredozubov commented Mar 14, 2026

Motivation

k2 is used for FSA-based speech recognition (CTC, RNN-T, MMI), and Apple Silicon Macs have become the primary development machine for many researchers. The problem: k2 had no MPS support at all. Every operation fell through PyTorch's CPU fallback, so GPU-accelerated training on macOS simply didn't work.

Why this was hard

k2's C++ core uses the `K2_EVAL` macro extensively — it reads raw device pointers from the CPU side to walk arc data. On CUDA this works because pinned memory and unified addressing let the CPU see GPU memory. On MPS, `data_ptr()` on a tensor doesn't return a valid CPU pointer, so any path through `K2_EVAL` crashes immediately.

Two things fix this. For hot paths, native Metal kernels that receive `id<MTLBuffer>` handles and never touch device memory from the CPU. For topology-dependent operations where native Metal isn't worth it yet, CPU bridges in the Python autograd layer: copy the FSA to CPU, run the existing C++ algorithm, and return a gradient-connected result to MPS via a lightweight `_MpsScoresBridge` autograd function.

What changed

`mutual_information_mps.mm` — native Metal MI kernel

Replaces the CPU round-trip in `mutual_information_recursion` with a Metal compute kernel using the blocked antidiagonal wavefront pattern from the CUDA version. Forward (`mi_forward`) and backward (`mi_backward`) kernels are compiled at runtime from an embedded MSL string, cached via `dispatch_once`. Expected speedup: 5–8× over CPU bridging for typical (T=300, S=50) sequences.
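The wavefront's dependency structure can be sketched in pure Python (a simplified model, not the exact k2 recursion or its blocked tiling): every cell `(s, t)` on antidiagonal `k = s + t` depends only on diagonal `k - 1`, so each diagonal is one parallel Metal dispatch.

```python
import math

def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

def wavefront_diagonals(S, T):
    """Antidiagonals of the (S+1) x (T+1) recursion grid.
    Cells (s, t) with s + t == k depend only on diagonal k - 1,
    so all cells of one diagonal can be computed in parallel."""
    for k in range(S + T + 1):
        yield [(s, k - s) for s in range(max(0, k - T), min(S, k) + 1)]
```

For a (T=300, S=50) sequence this gives 351 dependent steps instead of 351 × 51 dependent cells, which is what makes a GPU kernel worthwhile.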

CUDA→Metal primitive mapping:

  • `shared float arr[N]` → `threadgroup float arr[N]`
  • `__syncthreads()` → `threadgroup_barrier(mem_flags::mem_threadgroup)`
  • `__syncwarp()` → `simdgroup_barrier(mem_flags::mem_threadgroup)`

`mps_utils.h` — ATen int32 ops

PyTorch MPS supports int32 `cumsum`, `diff`, `repeat_interleave`, and `searchsorted` natively since PyTorch 2.2. `ExclusiveSumMps`, `RowSplitsToRowIdsMps`, and `MaxSizeMps` are rewritten to use int32 throughout, removing redundant int64 casts and intermediate allocations.
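In pure-Python terms the two core primitives look like this (illustrative only; the real code calls the corresponding ATen ops — `cumsum`, `diff`, `repeat_interleave` — on MPS tensors):

```python
def exclusive_sum(xs):
    """ExclusiveSum([a, b, c]) -> [0, a, a+b]: a cumsum shifted right by one."""
    out, acc = [], 0
    for x in xs:
        out.append(acc)
        acc += x
    return out

def row_splits_to_row_ids(row_splits):
    """row_splits [0, 2, 3, 3, 5] -> row_ids [0, 0, 1, 3, 3].
    ATen equivalent: repeat_interleave(arange(n), diff(row_splits))."""
    n = len(row_splits) - 1
    sizes = [row_splits[i + 1] - row_splits[i] for i in range(n)]
    return [i for i in range(n) for _ in range(sizes[i])]
```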

`mps_fsa_scores.mm` — native Metal `GetForwardScores`

Level-by-level BFS dispatch: each BFS layer's arcs are independent (antecedent states are already scored), so one Metal command encoder per layer is dispatched onto PyTorch's existing MPS command buffer. Supports log semiring (CAS loop with `log_add`) and tropical semiring (CAS loop with `max`). NaN guards (`if (isnan(candidate)) return;`) prevent the CAS loops from spinning forever on invalid inputs — `NaN >= x` is always false in IEEE 754, so without the guard the loop never exits.
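The NaN hazard can be shown with a single-threaded stand-in for the tropical CAS loop (the exact kernel loop isn't reproduced in this PR text; this is a sketch of the failure mode). The exit test `candidate <= current` is never true when either side is NaN, so an unguarded loop retries forever:

```python
def atomic_max(scores, slot, candidate, max_iters=50):
    """Single-threaded emulation of a tropical-semiring CAS update.
    The kernel's guard -- if (isnan(candidate)) return; -- is deliberately
    omitted here to show why it is needed: with a NaN candidate the exit
    test below is never true, so the loop runs to the iteration cap
    (i.e. it would spin forever on real hardware)."""
    iters = 0
    while iters < max_iters:
        iters += 1
        cur = scores[slot]
        if candidate <= cur:       # done: our value can no longer improve it
            return iters
        scores[slot] = candidate   # stand-in for a successful atomic CAS
    return iters                   # hit the cap: non-termination
```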

`fsa.py` — zero-copy arc-batch computation

A full `self.to('cpu')` copies 16 bytes × num_arcs plus all tensor attributes. The BFS-level sort only needs the `dest_state` column (4 bytes × num_arcs). `_compute_bfs_arc_batches_mps` extracts only that column, sorts arc indices on CPU via `torch.argsort`, and returns `sorted_arc_ids` as an MPS tensor — 4× less data crossing the bus.
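A toy single-FSA version of that idea (simplified names and semantics; the real helper works on batched ragged arcs and returns an MPS tensor):

```python
def sorted_arc_ids_by_dest(dest_state):
    """Stable argsort of arc indices by destination state, mirroring the
    torch.argsort step: only the 4-byte dest_state column crosses the bus,
    not the full 16-byte-per-arc FSA."""
    return sorted(range(len(dest_state)), key=lambda i: dest_state[i])
```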

`autograd.py` — MPS bridges for `IntersectDense` and `IntersectDensePruned`

`_IntersectDensePrunedFunction` and `_IntersectDenseFunction` run the forward pass on CPU (the intersection algorithm is topology-dependent and relies on `K2_EVAL`). Gradients are scattered via `_k2.index_add` on CPU and moved back to MPS. `_MpsScoresBridge` keeps the gradient graph intact through the device transfer.
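The gradient scatter amounts to index_add over an arc map; a minimal sketch (assuming the k2 convention that an arc-map entry of -1 means "no source arc"):

```python
def scatter_arc_grads(grad_out, arc_map, num_src_arcs):
    """Accumulate output-lattice arc gradients onto source-FSA arcs
    with index_add semantics. Entries of -1 in arc_map (arcs with no
    source, assumed convention) contribute nothing."""
    grad_src = [0.0] * num_src_arcs
    for g, a in zip(grad_out, arc_map):
        if a != -1:
            grad_src[a] += g
    return grad_src
```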

`mps_fsa_scores.mm` — Hillis-Steele O(log N) associative scan

For single-FSA, tropical-semiring graphs with 4 ≤ N ≤ 128 states, the sequential N-step BFS dispatch is replaced by a Hillis-Steele prefix scan over N per-state transition matrices:

  • Each state `d` has an N×N transfer matrix `M[d]`: diagonal entries are 0, `M[d][d][src]` = max arc weight from `src` to `d`, all else −∞.
  • `M[0]` = identity (start state, no entering arcs).
  • Prefix product `P[s] = M[s] ⊗ M[s-1] ⊗ … ⊗ M[0]` (tropical matrix-multiply).
  • Forward score of state `s` = `P[s][s][0]` (best path from state 0 to state s).
  • Hillis-Steele doubling reduces Metal encoder calls from N to ⌈log₂N⌉.
  • Power-of-2 padding uses identity matrices (neutral element for tropical matmul).
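The scheme above can be checked with a pure-Python model (hypothetical function names; the kernel operates on padded power-of-2 batches of matrices, while this sketch simply skips out-of-range partners, which is equivalent to padding with identities):

```python
import math
NEG_INF = -math.inf

def trop_matmul(A, B):
    """Tropical matmul: (A ⊗ B)[i][j] = max_k (A[i][k] + B[k][j])."""
    n = len(A)
    return [[max(A[i][k] + B[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def forward_scores_assoc_scan(arcs, num_states):
    """arcs: list of (src, dest, weight) in a topologically sorted FSA.
    Builds per-state transfer matrices and runs a Hillis-Steele
    inclusive scan, so P[s] = M[s] ⊗ M[s-1] ⊗ ... ⊗ M[0]."""
    def identity(n):
        return [[0.0 if i == j else NEG_INF for j in range(n)] for i in range(n)]
    # M[d]: diagonal 0 (pass-through); row d, column src = max arc weight src -> d.
    M = [identity(num_states) for _ in range(num_states)]
    for src, dest, w in arcs:
        M[dest][dest][src] = max(M[dest][dest][src], w)
    P = list(M)
    offset = 1
    while offset < num_states:          # ceil(log2 N) rounds
        P = [trop_matmul(P[i], P[i - offset]) if i >= offset else P[i]
             for i in range(num_states)]
        offset *= 2
    # Forward score of s = best path 0 -> s = row s, column 0 of P[s].
    return [P[s][s][0] for s in range(num_states)]
```

On a diamond topology (0→1, 0→2, 1→3, 2→3) this reproduces the sequential BFS forward scores while needing only ⌈log₂N⌉ scan rounds.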

Additional autograd bridges

Auditing `K2_EVAL` call paths turned up three more autograd functions that would crash on MPS:

  • `_GetForwardScoresFunction.backward` — bridges to CPU; also fixes a stride-0 tensor issue in the `.expand_as().contiguous()` chain
  • `_GetBackwardScoresFunction` (forward + backward) — makes `get_backward_scores` and `get_arc_post` work on MPS
  • `_GetArcPostFunction` (forward + backward) — full MPS support for `get_arc_post`

Build system

  • `CMakeLists.txt`: detect MPS via Python/PyTorch probe; add `K2_WITH_MPS` option; link `Metal.framework` + `Foundation.framework`
  • `k2/python/csrc/CMakeLists.txt`: separate `.mm` files from `.cu` before `transform()` — the transform renames source extensions, and `.mm` → `.cc` breaks Objective-C++ compilation
  • `k2/python/csrc/torch/CMakeLists.txt`: conditionally add `.mm` sources under `K2_WITH_MPS`

CI (`.github/workflows/build-mps-macos.yml`)

Runs on `macos-14` (Apple Silicon, MPS hardware available). Triggers on push to master (path-filtered), PR labeled `ready`, or `workflow_dispatch`. Steps:

  1. Build wheel with `K2_WITH_MPS=ON`
  2. Verify `k2.with_mps == True`
  3. Run `test_mps.py` (65 MPS-specific tests)
  4. Run all other tests as CPU sanity check
  5. Build + ctest C++ tests

Test coverage (65 tests, all passing)

| Class | Tests | Coverage |
| --- | --- | --- |
| `TestMpsTensor` | 6 | FSA/tensor transfer, flag, basic ops |
| `TestMpsMutualInformation` | 4 | Forward parity, backward, boundary tensor |
| `TestMpsTotScores` | 5 | Log/tropical backward, double, non-unit grad |
| `TestMpsForwardScores` | 4 | Log/tropical parity, differentiable, multi-FSA |
| `TestMpsIntersectDense` | 5 | Forward/backward parity, seqframe attr (pruned + non-pruned) |
| `TestMpsAssocScan` | 5 | Linear/branching parity, tot_scores, large fallback |
| `TestMpsEdgeCases` | 6 | Empty FSA, unreachable state, NotImplementedError guards, NaN-free grad |
| `TestMpsAssocScanBoundaries` | 8 | N=4/5/7/16/128/129, diamond topology, multi-arc dest |
| `TestMpsArcPost` | 4 | Tropical/log parity + gradient |
| `TestMpsGetForwardScoresDifferentiable` | 3 | Tropical/log/non-unit gradient |
| `TestMpsIntersectDenseExtended` | 5 | 2-utterance batch, seqframe parity (fixed seed) |
| `TestMpsMutualInformationExtended` | 4 | Varied sizes, gradient, batch, boundary |
| `TestMpsForwardScoresNumericalStress` | 3 | 50-state random DAG, NaN-free gradient |

Known limitations

  • `_get_forward_scores(use_double_scores=True)` on MPS raises `NotImplementedError` — MPS has no float64. Use `get_forward_scores` (differentiable, float32) instead.
  • `_get_backward_scores` and `_get_tot_scores` (non-differentiable) on MPS raise `NotImplementedError`. Use the differentiable `get_backward_scores` / `get_tot_scores` instead.
  • The O(log N) assoc-scan path only kicks in for single-FSA, tropical semiring, N in [4, 128]. Anything outside that falls back silently to the sequential Metal path.
  • `IntersectDense` / `IntersectDensePruned` forward+backward run on CPU. Only the score tensors are MPS-resident.

Test plan

  • 65/65 MPS tests pass locally on M-series Mac
  • CPU tests unmodified and unaffected
  • CI run on `macos-14` via `build-mps-macos.yml` (add `ready` label to trigger)

@gemini-code-assist

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances k2's compatibility and performance on Apple Silicon Macs by introducing a robust Metal Performance Shaders (MPS) backend. It addresses the core challenge of PyTorch's MTLStorageModePrivate buffers by implementing a hybrid approach that combines native Metal kernels for critical computations with intelligent CPU bridging for operations that are not yet Metal-optimized. This integration ensures that k2 can leverage Apple's GPU capabilities, providing substantial speedups for key algorithms while maintaining correctness and a seamless developer experience.

Highlights

  • Native Metal Kernels for Core Operations: Implemented native Metal compute kernels for performance-critical operations like mutual_information and GetForwardScores, replacing previous CPU fallbacks. This includes a blocked antidiagonal wavefront pattern for mutual_information and a level-by-level BFS dispatch for GetForwardScores.
  • Autograd CPU Bridging for MPS: Introduced a hybrid strategy where complex C++ algorithms that rely on raw CPU pointer access (e.g., IntersectDense, GetBackwardScores, GetArcPost) are transparently bridged to CPU, executed, and their gradients are reconnected to MPS via a new _MpsScoresBridge autograd function.
  • Optimized ATen Operations for MPS: Rewrote several utility functions (ExclusiveSumMps, RowSplitsToRowIdsMps, MaxSizeMps, GetCountsMps) to leverage native PyTorch MPS ATen operations, eliminating redundant int64 casts and intermediate allocations for improved performance.
  • Zero-Copy Arc-Batch Computation: Developed a zero-copy mechanism for _compute_bfs_arc_batches_mps that transfers only the dest_state column to CPU for sorting, significantly reducing data transfer across the bus compared to full FSA copies.
  • Hillis-Steele Associative Scan for Forward Scores: Integrated an O(log N) Hillis-Steele prefix scan for single-FSA, tropical-semiring graphs within a specific state range (4 ≤ N ≤ 128), reducing Metal encoder calls for GetForwardScores.
  • Comprehensive Build System and CI Integration: Updated the CMake build system to detect MPS availability, conditionally compile Objective-C++ (.mm) files, and link necessary Metal frameworks. A new CI workflow (build-mps-macos.yml) was added to ensure continuous validation on macOS with MPS hardware.
  • Extensive Test Suite: Added a comprehensive suite of 65 MPS-specific tests covering context, training, forward scores, intersect dense, associative scan, edge cases, arc post, and numerical stress, all of which are currently passing.
Changelog
  • .gitignore
    • Added '.claude/' to the ignore list.
  • CMakeLists.txt
    • Implemented detection for MPS availability on macOS using a Python/PyTorch probe.
    • Added a K2_WITH_MPS CMake option and definition to control MPS-specific compilation.
    • Ensured Python3 Development component is found before pybind11 to support python3_add_library.
  • MPS_NATIVE_METAL_PLAN.md
    • Added a new markdown document outlining the detailed plan for native Metal acceleration, including kernel designs and integration points.
  • MPS_PLAN.md
    • Added a new markdown document detailing the implementation plan for Apple Metal/MPS support, covering context plumbing, CPU fallback, Python layer integration, and ATen-op promotion.
  • k2/csrc/CMakeLists.txt
    • Added linking for Metal.framework and Foundation.framework on Apple platforms when MPS is enabled.
    • Included a linker flag (-undefined dynamic_lookup) for macOS to resolve undefined symbols from libtorch_python.dylib at runtime.
  • k2/csrc/array.h
    • Modified Data() accessor to allow CPU-side access for MPS tensors, leveraging Apple Silicon's unified memory architecture.
  • k2/csrc/array_ops.cu
    • Included mps_utils.h for MPS-specific utility functions.
    • Updated SpliceRowSplits, CatWithOffsets, GetCounts, and SizesToMergeMap to include MPS-specific logic, often synchronizing the Metal command queue before CPU operations.
  • k2/csrc/array_ops_inl.h
    • Extended CPU fallback logic in Transpose, Cat, MonotonicLowerBound, MonotonicDecreasingUpperBound, and Sum to also apply to MPS devices.
  • k2/csrc/context.cu
    • Updated ParallelRunnerActive::NewStream to return kCudaStreamInvalid for MPS contexts, ensuring CPU fallback for parallel runners.
  • k2/csrc/context.h
    • Added kMps to the DeviceType enum and its operator<< overload for debugging.
    • Declared GetMpsContext() for retrieving an MPS context.
  • k2/csrc/device_guard.h
    • Added a comment clarifying that CPU and MPS contexts do not have a CUDA device to guard.
  • k2/csrc/macros.h
    • Modified K2_EVAL, K2_EVAL2, and K2_TRANS_EXCSUM macros to include kMps in their CPU fallback paths.
  • k2/csrc/mps_utils.h
    • Added a new header file defining MPS-specific utility functions (AsMpsTensor, ExclusiveSumMps, InclusiveSumMps, RowSplitsToRowIdsMps, RowIdsToRowSplitsMps, MaxSizeMps, GetCountsMps) that utilize PyTorch's ATen ops for Metal acceleration.
  • k2/csrc/pinned_context.cu
    • Added a case for kMps in GetContextForTransfer() to return a CPU context, leveraging unified memory for staging transfers.
  • k2/csrc/pytorch_context.cu
    • Introduced a global g_mps_registry and MpsRegistryView to track MPS tensors allocated by k2, enabling Metal-safe copy_() operations.
    • Implemented the PytorchMpsContext class to manage MPS memory allocation, deallocation, and data transfers.
    • Updated PytorchCpuContext::CopyDataTo to handle CPU to MPS copies using the global registry for safety.
    • Modified NewRegion to correctly identify and create contexts for MPS tensors.
  • k2/csrc/pytorch_context.h
    • Declared GetMpsContext() and MpsRegistryView() for MPS context and tensor management.
  • k2/csrc/ragged.cu
    • Integrated mps_ops::MaxSizeMps for MPS devices in RaggedShape::MaxSize.
  • k2/csrc/torch_util.cu
    • Updated ToTorchDeviceType, FromTorchDeviceType, ToTorch(Array1<Arc>), and GetContext to properly handle and dispatch for MPS devices.
  • k2/csrc/torch_util.h
    • Modified ToTorch(Array1<T>) to use a proper view of ManagedTensor for MPS, avoiding issues with torch::from_blob().
  • k2/csrc/utils.cu
    • Included mps_utils.h and added MPS-specific implementations for RowSplitsToRowIds and RowIdsToRowSplits.
  • k2/csrc/utils_inl.h
    • Included mps_utils.h and added MPS-specific implementations for ExclusiveSum, InclusiveSum, and MaxValue using ATen ops.
  • k2/csrc/version.h.in
    • Added a K2_WITH_MPS definition to track MPS compilation status.
  • k2/python/csrc/CMakeLists.txt
    • Modified the transform() command to correctly handle Objective-C++ (.mm) files, preventing their renaming to .cc.
    • Added conditional linking of Metal and Foundation frameworks for MPS builds on Apple.
  • k2/python/csrc/torch.h
    • Updated the To template function to support MPS devices, allowing objects to be moved to 'mps' device type.
  • k2/python/csrc/torch/CMakeLists.txt
    • Conditionally added mutual_information_mps.mm and mps_fsa_scores.mm to the build when K2_WITH_MPS is enabled.
  • k2/python/csrc/torch/arc.cu
    • Modified ToTorch for Arc to use tensor.view(scalar_type) instead of torch::from_blob for MPS compatibility, ensuring proper storage sharing.
  • k2/python/csrc/torch/fsa.cu
    • Included mps_fsa_scores.h and added Python bindings for MPS-native forward score computations: get_forward_scores_mps, get_forward_scores_mps_native, and get_forward_scores_mps_assoc_scan.
  • k2/python/csrc/torch/mps_fsa_scores.h
    • Added a new header file declaring MPS-accelerated GetForwardScoresMps, GetForwardScoresMpsNative, and GetForwardScoresMpsAssocScan functions.
  • k2/python/csrc/torch/mps_fsa_scores.mm
    • Added a new Objective-C++ file containing embedded Metal Shading Language (MSL) kernels for fsa_forward_log, fsa_forward_tropical, and associative scan operations, along with their C++ wrappers and pipeline management.
  • k2/python/csrc/torch/mutual_information.cu
    • Implemented MPS-specific dispatch for MutualInformationCpu and MutualInformationBackwardCpu, including CPU fallback for double-precision floating-point numbers not natively supported by MPS.
  • k2/python/csrc/torch/mutual_information.h
    • Declared MutualInformationMps and MutualInformationBackwardMps functions for MPS-accelerated mutual information calculations.
  • k2/python/csrc/torch/pruned_ranges_to_lattice.cu
    • Added conditional MPS support to PrunedRangesToLattice to leverage MPS contexts.
  • k2/python/csrc/torch/v2/ragged_any.cu
    • Extended RaggedAny::To method to support moving RaggedAny objects to MPS devices.
  • k2/python/csrc/torch/v2/ragged_shape.cu
    • Updated PybindRaggedShape to include support for moving RaggedShape objects to MPS devices.
  • k2/python/csrc/version.cu
    • Added with_mps attribute to the k2 version information.
  • k2/python/k2/init.py
    • Imported with_mps from _k2.version.
    • Updated __dev_version__.
  • k2/python/k2/autograd.py
    • Introduced _MpsScoresBridge as a new torch.autograd.Function to facilitate gradient-connected data transfer between MPS and CPU.
    • Modified _GetTotScoresFunction, _GetForwardScoresFunction, _GetBackwardScoresFunction, _GetArcPostFunction, _IntersectDensePrunedFunction, and _IntersectDenseFunction to incorporate CPU bridging for MPS devices, ensuring correct gradient flow.
  • k2/python/k2/dense_fsa_vec.py
    • Updated the to() method to explicitly support 'mps' as a target device type.
  • k2/python/k2/fsa.py
    • Added _compute_bfs_arc_batches_mps for efficient, zero-copy BFS arc batch computation on MPS.
    • Updated _get_forward_scores to dispatch to MPS-native or associative scan paths when applicable, and to raise NotImplementedError for double-precision on MPS.
    • Added NotImplementedError for direct calls to _get_tot_scores and _get_backward_scores on MPS, guiding users to differentiable alternatives.
    • Modified get_tot_scores to use _MpsScoresBridge for MPS devices, enabling CPU fallback with proper gradient handling.
    • Updated the to() method to explicitly support 'mps' as a target device type.
  • k2/python/k2/version/version.py
    • Added with_mps to the printed version information.
  • k2/python/tests/test_mps.py
    • Added a new comprehensive test file (test_mps.py) containing 65 MPS-specific tests to validate the new backend's functionality, correctness, and performance across various k2 operations.
Ignored Files
  • Ignored by pattern: .github/workflows/** (1)
    • .github/workflows/build-mps-macos.yml
Activity
  • The pull request introduces a complete Apple MPS backend for k2, enabling GPU-accelerated training on macOS.
  • Initial implementation focused on CPU bridging for complex operations, followed by native Metal kernel development for hot paths.
  • A new CI workflow (.github/workflows/build-mps-macos.yml) has been set up to build and test the MPS backend on macos-14 runners.
  • Extensive test coverage has been added, with 65 new MPS-specific tests, all of which are currently passing, ensuring the correctness and reliability of the new backend.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces comprehensive Apple Metal Performance Shaders (MPS) support to the k2 library, enabling GPU acceleration on Apple Silicon. Key changes include updating CMake to detect and configure MPS, adding a new mps_utils.h for MPS-accelerated ATen operations, and implementing a PytorchMpsContext for managing MPS memory and safe data transfers. Core k2 functionalities like mutual_information_recursion, GetForwardScores, IntersectDense, and IntersectDensePruned are extended with MPS-native implementations or CPU fallbacks with gradient bridging. The MPS_NATIVE_METAL_PLAN.md document details the phased rollout of native Metal kernels for performance-critical operations, with a progress tracker indicating completion of several priority items. A review comment highlights a potential data race in array_ops_inl.h where the CPU path accesses MPS memory without explicit synchronization, which could lead to stale data. Another comment points out that the existing code comment in array.h regarding unified memory is an oversimplification and should clarify the reliance on PyTorch's implicit synchronization for MPS tensors.

Add Apple MPS (Metal) backend: native Metal kernels, autograd CPU bridges, 65-test suite

k2 had no MPS support: every operation fell through PyTorch's CPU fallback,
making GPU-accelerated training impossible on Apple Silicon. This PR adds a
complete MPS backend using two complementary strategies:

1. Native Metal kernels for hot paths — MSL compute shaders dispatched via
   PyTorch's existing MPS command buffer infrastructure, avoiding the K2_EVAL
   raw-pointer reads that crash on MTLStorageModePrivate buffers.

2. CPU bridges in the Python autograd layer — for topology-dependent algorithms
   that rely on K2_EVAL, transparently copy the FSA to CPU, run the existing
   C++ algorithm, and return a gradient-connected result to MPS via a lightweight
   _MpsScoresBridge autograd function.

Changes:
- mutual_information_mps.mm: native Metal MI kernel (antidiagonal wavefront,
  matching CUDA blocked pattern); dispatch_once pipeline cache; nil-function
  guards to prevent NSInvalidArgumentException on Metal compile failure
- mps_fsa_scores.mm: native Metal GetForwardScores (BFS level-by-level dispatch,
  log + tropical semiring CAS loops with NaN guards); Hillis-Steele O(log N)
  associative scan for single-FSA tropical graphs with N in [4,128] (reduces
  Metal encoder calls from N to ceil(log2 N))
- mps_utils.h: int32-native ExclusiveSumMps, RowSplitsToRowIdsMps, MaxSizeMps
  using PyTorch >= 2.2 MPS int32 ops; zero-copy arc-batch helper
- autograd.py: MPS CPU bridges for _GetForwardScoresFunction,
  _GetBackwardScoresFunction, _GetArcPostFunction, _IntersectDensePrunedFunction,
  _IntersectDenseFunction; stride-0 tensor fix in forward-scores backward path;
  dtype guard (use out_fsa_grad.dtype not fsa_grad_cpu.dtype) for fp16/bf16
- fsa.py: _MpsScoresBridge autograd function anchoring gradient graph for
  get_tot_scores; NotImplementedError guards for MPS-unsupported paths
  (use_double_scores, non-differentiable _get_backward_scores/_get_tot_scores)
- CMakeLists.txt: K2_WITH_MPS flag, Metal+Foundation framework linkage,
  separate .mm from .cu before transform() to preserve Objective-C++ extension
- test_mps.py: 65 MPS-specific tests covering MI, forward scores, tot scores,
  arc post, intersect dense, assoc scan, edge cases, numerical stress
- .github/workflows/build-mps-macos.yml: CI on macos-14 (Apple Silicon),
  triggered on push/PR-labeled-ready/workflow_dispatch
@dredozubov dredozubov force-pushed the feature/mps-backend branch from b5fe638 to cd5a81d on March 14, 2026 at 20:57
@dredozubov
Author

Re the two Gemini review comments:

array_ops_inl.h sync (HIGH) — agreed, fixing. Adding torch::mps::synchronize() before each of the four MPS→CPU read paths in array_ops_inl.h (lines 346, 743, 782, 1041), matching the pattern already in array_ops.cu.

array.h comment (MEDIUM) — partial disagreement on the framing. PyTorch's MPS allocator uses MTLStorageModeShared (not MTLStorageModePrivate) for standard tensors, so data_ptr() does return a valid CPU-dereferenceable pointer — the "unified memory" statement is accurate. The real risk is stale data from uncommitted Metal commands, which is a synchronization issue, not a pointer-validity issue. Updating the comment to clarify both points.

…opyright

- style_check.yml: bump Python 3.9→3.10 and actions to v4/v5 (3.9 no longer
  available on GitHub Actions runners)
- array_ops_inl.h: add torch::mps::synchronize() before all four MPS→CPU
  raw-pointer reads (Append, ReverseCumSum min/max, Sum); matches pattern
  already used in array_ops.cu
- array.h: clarify comment — MTLStorageModeShared makes data_ptr() valid on
  CPU, but synchronize() is still required to flush pending Metal writes
- autograd.py, fsa.py, test_mps.py, version.py: fix all flake8 violations
  (E501 line length, E226 arithmetic spacing, E302 blank lines, F541 f-string)
- Copyright year: update 2024→2026 in all new MPS source files
@dredozubov dredozubov changed the title Add Apple MPS (Metal) backend: native Metal kernels, autograd CPU bridges, 65-test suite Add Apple MPS (Metal) backend: native Metal kernels, autograd CPU bridges Mar 14, 2026
@dredozubov
Author

@csukuangfj @pkufool Can you guys take a look, please? I'd love to see it merged anytime soon.

…ne states

ARC would release id<MTLComputePipelineState> members during static/heap
destructor teardown, after PyTorch had already shut down the Metal device.
Using __unsafe_unretained opts those fields out of ARC so no Obj-C message
is sent to deallocated Metal objects on interpreter exit.
…ters

K2_EVAL on kMps runs the lambda on CPU using data_ptr() raw pointers.
On MPS, calling data_ptr() on tensors that have pending Metal writes (e.g.
from torch::rand_like) and writing back results races against the Metal
command queue, causing undefined behaviour (a segfault, exit code 139).

Add #ifdef K2_WITH_MPS guards to SwooshForward, SwooshForwardAndDeriv,
and SwooshFunction::forward/backward that dispatch to numerically-equivalent
ATen operations (torch::logaddexp, torch::sigmoid) which run natively on MPS
without CPU pointer access:

- SwooshForward: logaddexp(0, x-shift) - coeff*x - offset
- SwooshForwardAndDeriv: same y, plus sigmoid(x-shift) - coeff for deriv
- SwooshFunction/forward: ATen forward + save x (and Bernoulli mask for
  dropout) for the custom backward
- SwooshFunction/backward: recompute gradient from saved x via sigmoid

The 8-bit quantised gradient trick is unavailable on MPS; standard
autograd handles differentiation with negligible memory overhead at the
typical activation sizes where MPS is used.

Fixes segfault reported by LuxTTS testers when calling swoosh_l_forward
on MPS tensors.
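The substituted ATen formulation above can be checked numerically in plain Python. The `shift`/`coeff`/`offset` values below are placeholders for illustration, not necessarily k2's exact Swoosh constants:

```python
import math

def swoosh(x, shift=4.0, coeff=0.08, offset=0.035):
    """logaddexp(0, x - shift) - coeff*x - offset, per the commit message.
    Constants are illustrative placeholders."""
    t = x - shift
    # stable log(1 + exp(t)) == logaddexp(0, t)
    y = max(t, 0.0) + math.log1p(math.exp(-abs(t)))
    return y - coeff * x - offset

def swoosh_deriv(x, shift=4.0, coeff=0.08):
    """sigmoid(x - shift) - coeff, as used for the backward pass."""
    return 1.0 / (1.0 + math.exp(-(x - shift))) - coeff

# Finite-difference check that the backward formula matches the forward one.
x, h = 1.3, 1e-6
fd = (swoosh(x + h) - swoosh(x - h)) / (2 * h)
assert abs(fd - swoosh_deriv(x)) < 1e-5
```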