diff --git a/.github/workflows/build-mps-macos.yml b/.github/workflows/build-mps-macos.yml new file mode 100644 index 000000000..cda2ba2f7 --- /dev/null +++ b/.github/workflows/build-mps-macos.yml @@ -0,0 +1,130 @@ +# Copyright 2024 k2 contributors + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: build-mps-macos + +on: + push: + branches: + - master + paths: + - '.github/workflows/build-mps-macos.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'k2/csrc/**' + - 'k2/python/**' + pull_request: + types: [labeled] + paths: + - '.github/workflows/build-mps-macos.yml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'k2/csrc/**' + - 'k2/python/**' + workflow_dispatch: + +concurrency: + group: build-mps-macos-${{ github.ref }} + cancel-in-progress: true + +env: + BUILD_TYPE: Release + +jobs: + build-mps-macos: + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'workflow_dispatch' + name: Build (macOS, CPU+MPS) + runs-on: macos-14 # Apple Silicon M1/M2 — MPS is available + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Display clang version + run: clang --version + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Display Python version + run: python -c "import sys; print(sys.version)" + + - name: Install PyTorch (macOS arm64 with MPS) + shell: bash + run: | + python3 -m pip install -qq 
--upgrade pip + python3 -m pip install -qq wheel twine + python3 -m pip install -qq torch --index-url https://download.pytorch.org/whl/cpu + python3 -c "import torch; print('torch version:', torch.__version__); print('MPS available:', torch.backends.mps.is_available())" + + - name: Build wheel + shell: bash + run: | + export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF -DK2_WITH_MPS=ON" + export K2_MAKE_ARGS="-j3" + python3 setup.py bdist_wheel + ls -lh dist/ + ls -lh build/* + + - name: Upload wheel + uses: actions/upload-artifact@v4 + with: + name: k2-macos14-mps-python3.11 + path: dist/*.whl + + - name: Install wheel and verify MPS support + shell: bash + run: | + pip install dist/*.whl + python3 -m k2.version + python3 -c " + import k2 + assert k2.with_mps, 'k2 was not built with MPS support' + print('k2.with_cuda:', k2.with_cuda) + print('k2.with_mps:', k2.with_mps) + print('MPS build verified.') + " + + - name: Run MPS tests + shell: bash + env: + PYTORCH_ENABLE_MPS_FALLBACK: "1" + run: | + pip install pytest + python3 -m pytest k2/python/tests/test_mps.py -v --tb=short + + - name: Run CPU tests (sanity check) + shell: bash + run: | + python3 -m pytest k2/python/tests/ -v --ignore=k2/python/tests/test_mps.py -x --tb=short -q + + - name: Build k2 (cmake + C++ tests) + shell: bash + run: | + mkdir -p build_cmake + cd build_cmake + cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF -DK2_WITH_MPS=ON .. 
+ cat k2/csrc/version.h + make VERBOSE=1 -j3 + + - name: Run C++ tests (ctest) + shell: bash + run: | + cd build_cmake + ctest --output-on-failure diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 1dd6a8de6..9277f5ee9 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -42,23 +42,23 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: ["3.10"] fail-fast: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 2 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install Python dependencies run: | python3 -m pip install --upgrade pip typing_extensions - python3 -m pip install --upgrade flake8==3.8.3 + python3 -m pip install --upgrade flake8 - name: Run flake8 shell: bash diff --git a/.gitignore b/.gitignore index 497b167b0..940061cfa 100644 --- a/.gitignore +++ b/.gitignore @@ -587,3 +587,4 @@ dkms.conf !.github/** !k2/torch/bin *-bak +.claude/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d2eb41ca..f3596b79f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,34 @@ if(APPLE OR (DEFINED K2_WITH_CUDA AND NOT K2_WITH_CUDA)) endif() endif() +# Propagate PYTHON_EXECUTABLE → Python3_EXECUTABLE so all find_package(Python3) +# calls use the same interpreter (e.g. the active venv) rather than the system +# Python located by CMake's automatic search. 
+if(DEFINED PYTHON_EXECUTABLE AND NOT DEFINED Python3_EXECUTABLE) + set(Python3_EXECUTABLE "${PYTHON_EXECUTABLE}" CACHE FILEPATH + "Path to the Python 3 interpreter" FORCE) +endif() + +# Detect MPS (Metal Performance Shaders) availability on macOS +set(_K2_WITH_MPS OFF) +if(APPLE) + find_package(Python3 QUIET COMPONENTS Interpreter) + if(Python3_FOUND) + execute_process( + COMMAND ${Python3_EXECUTABLE} -c + "import torch; print(torch.backends.mps.is_available())" + OUTPUT_VARIABLE _K2_MPS_AVAILABLE + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET) + if(_K2_MPS_AVAILABLE STREQUAL "True") + message(STATUS "MPS is available -- enabling K2_WITH_MPS") + set(_K2_WITH_MPS ON) + else() + message(STATUS "MPS is NOT available on this machine") + endif() + endif() +endif() + if(_K2_WITH_CUDA) set(languages ${languages} CUDA) endif() @@ -81,6 +109,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared or static lib" ON) option(K2_USE_PYTORCH "Whether to build with PyTorch" ON) option(K2_ENABLE_BENCHMARK "Whether to enable benchmark" ON) option(K2_WITH_CUDA "Whether to build k2 with CUDA" ${_K2_WITH_CUDA}) +option(K2_WITH_MPS "Whether to build k2 with MPS (Apple Metal)" ${_K2_WITH_MPS}) option(K2_ENABLE_NVTX "Whether to build k2 with the NVTX library" ON) option(K2_ENABLE_TESTS "Whether to build tests" ON) @@ -334,6 +363,11 @@ endif() list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) +# Ensure Python3 Development component is found before pybind11, so that +# python3_add_library is available (required by pybind11 NewTools with CMake 4+). 
+if(NOT Python3_Development_FOUND AND NOT Python3_Development.Module_FOUND) + find_package(Python3 REQUIRED COMPONENTS Interpreter Development) +endif() include(pybind11) if(K2_USE_PYTORCH) @@ -350,6 +384,10 @@ if(K2_WITH_CUDA) add_definitions(-DK2_WITH_CUDA) endif() +if(K2_WITH_MPS) + add_definitions(-DK2_WITH_MPS) +endif() + if(WIN32) add_definitions(-DNOMINMAX) # Otherwise, std::max() and std::min() won't work endif() diff --git a/README.md b/README.md index 1fd10b4b8..4e6f548f9 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,29 @@ LF-MMI training. This won't give a direct advantage in terms of Word Error Rate compared with existing technology; but the point is to do this in a much more general and extensible framework to allow further development of ASR technology. +## Apple Silicon (MPS) + +k2 supports Apple Silicon (M-series) via PyTorch's Metal Performance Shaders +(MPS) backend. To build with MPS enabled: + +```bash +git clone https://github.com/k2-fsa/k2.git +cd k2 +export K2_CMAKE_ARGS="-DK2_WITH_MPS=ON" +python3 setup.py install +``` + +You can verify MPS support is active with: + +```python +import k2 +print(k2.with_mps) # True +``` + +Hot paths (mutual information recursion, forward scores, associative scan) use +native Metal kernels. Topology-dependent operations run on CPU and return +MPS-resident tensors with a gradient-connected result. + ## Implementation A few key points on our implementation strategy. diff --git a/docs/source/index.rst b/docs/source/index.rst index e77ce7105..4580f690d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -19,8 +19,8 @@ k2 k2 is able to seamlessly integrate Finite State Automaton (FSA) and Finite State Transducer (FST) algorithms into autograd-based machine learning toolkits like PyTorch [#f1]_. -k2 supports CPU as well as CUDA. It can process a batch of FSTs -at the same time. +k2 supports CPU, CUDA, and Apple Silicon via Metal Performance Shaders (MPS). 
+It can process a batch of FSTs at the same time. .. [#f1] Support for TensorFlow will be added in the future. diff --git a/docs/source/installation/from_source.rst b/docs/source/installation/from_source.rst index 75af6bd58..044afedce 100644 --- a/docs/source/installation/from_source.rst +++ b/docs/source/installation/from_source.rst @@ -12,7 +12,7 @@ Install from source .. hint:: - It supports Linux (CPU + CUDA), macOS (CPU), and Windows (CPU + CUDA). + It supports Linux (CPU + CUDA), macOS (CPU + MPS), and Windows (CPU + CUDA). .. hint:: @@ -51,6 +51,37 @@ After setting up the environment, we are ready to build k2: That is all you need to run. +.. _install k2 with mps: + +Building with Apple Silicon (MPS) support +------------------------------------------ + +On macOS with an M-series chip, you can enable the Metal Performance Shaders +backend by passing ``-DK2_WITH_MPS=ON`` to CMake: + +.. code-block:: bash + + git clone https://github.com/k2-fsa/k2.git + cd k2 + export K2_CMAKE_ARGS="-DK2_WITH_MPS=ON" + export K2_MAKE_ARGS="-j6" + python3 setup.py install + +To verify that MPS support is active: + +.. code-block:: python + + import k2 + print(k2.with_mps) # True on Apple Silicon with MPS build + +.. hint:: + + PyTorch >= 2.2 is required for MPS support. Float64 operations and a small + number of non-differentiable internal functions are not available on MPS; + the differentiable public API (``get_forward_scores``, ``get_tot_scores``, + ``get_arc_post``, ``intersect_dense``, ``mutual_information_recursion``) + works fully on MPS. + .. 
hint:: We use ``export K2_MAKE_ARGS="-j6"`` to pass ``-j6`` to ``make`` diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 5b0442a47..0b7713ce1 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -145,6 +145,14 @@ if(K2_USE_PYTORCH) if(DEFINED ENV{CONDA_PREFIX} AND APPLE) target_link_libraries(context PUBLIC "-L $ENV{CONDA_PREFIX}/lib") endif() + + if(APPLE) + # On macOS with venv-style PyTorch installs, libtorch_python.dylib + # (which provides autograd metadata symbols) is loaded at Python runtime, + # not at link time. Allow undefined symbols so intermediate shared + # libraries (libk2context.dylib etc.) link successfully. + target_link_libraries(context PUBLIC "-undefined dynamic_lookup") + endif() endif() target_include_directories(context PUBLIC ${PYTHON_INCLUDE_DIRS}) diff --git a/k2/csrc/array.h b/k2/csrc/array.h index d67f10d5a..c762edf53 100644 --- a/k2/csrc/array.h +++ b/k2/csrc/array.h @@ -374,7 +374,9 @@ ToType(int64_t, Long) K2_CHECK_LT(i, Dim()); const T *data = Data() + i; DeviceType type = Context()->GetDeviceType(); - if (type == kCpu) { + if (type == kCpu || type == kMps) { + // MPS uses MTLStorageModeShared so data_ptr() is CPU-dereferenceable, + // but callers must ensure pending Metal writes are flushed (synchronize). 
return *data; } else { K2_CHECK_EQ(type, kCuda); diff --git a/k2/csrc/array_ops.cu b/k2/csrc/array_ops.cu index 551bbda13..b044cf800 100644 --- a/k2/csrc/array_ops.cu +++ b/k2/csrc/array_ops.cu @@ -23,6 +23,9 @@ #include "k2/csrc/array_ops.h" #include "k2/csrc/macros.h" #include "k2/csrc/nvtx.h" +#ifdef K2_WITH_MPS +#include "k2/csrc/mps_utils.h" +#endif namespace k2 { @@ -69,7 +72,11 @@ Array1 SpliceRowSplits(int32_t num_arrays, ExclusiveSumDeref(last_elems_ptrs, &data_offsets); int32_t *data_offsets_data = data_offsets.Data(); - if (c->GetDeviceType() == kCpu) { + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { +#ifdef K2_WITH_MPS + // Synchronize Metal command queue so the CPU loop reads committed data. + if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif // a simple loop is faster, although the other branches should still work on // CPU. for (int32_t i = 0; i < num_arrays; i++) { @@ -128,7 +135,11 @@ Array1 CatWithOffsets(const Array1 &offsets, int32_t *ans_data = ans.Data(); const int32_t *offsets_data = offsets.Data(); - if (c->GetDeviceType() == kCpu) { + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { +#ifdef K2_WITH_MPS + // Synchronize Metal command queue so the CPU loop reads committed data. 
+ if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif for (int32_t i = 0; i != num_arrays; ++i) { int32_t this_dim = src[i]->Dim(); const int32_t *this_src_data = src[i]->Data(); @@ -305,6 +316,12 @@ Array1 GetCounts(ContextPtr c, const int32_t *src_data, } DeviceType d = c->GetDeviceType(); +#ifdef K2_WITH_MPS + if (d == kMps) { + mps_ops::GetCountsMps(src_data, src_dim, ans_data, n); + return ans; + } +#endif if (d == kCpu) { for (int32_t i = 0; i < src_dim; ++i) { ++ans_data[src_data[i]]; @@ -465,7 +482,11 @@ Array1 SizesToMergeMap(ContextPtr c, if (tot_size == 0) return ans; uint32_t *ans_data = ans.Data(); - if (c->GetDeviceType() == kCpu) { + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { +#ifdef K2_WITH_MPS + // Synchronize Metal command queue before writing directly to the buffer. + if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif int32_t cur = 0; for (int32_t src = 0; src != num_srcs; ++src) { int32_t begin = cur, // i.e. the previous end. diff --git a/k2/csrc/array_ops_inl.h b/k2/csrc/array_ops_inl.h index 94607ec2b..683cef607 100644 --- a/k2/csrc/array_ops_inl.h +++ b/k2/csrc/array_ops_inl.h @@ -243,7 +243,7 @@ void Transpose(ContextPtr &c, const Array2 &src, Array2 *dest) { const T *src_data = src.Data(); T *dest_data = dest->Data(); DeviceType d = c->GetDeviceType(); - if (d == kCpu) { + if (d == kCpu || d == kMps) { for (int32_t i = 0; i < cols; ++i) { for (int32_t j = 0; j < rows; ++j) { dest_data[i * dest_elem_stride0 + j] = @@ -343,9 +343,13 @@ Array1 Cat(ContextPtr c, int32_t num_arrays, const Array1 **src) { if (ans_size == 0) return ans; T *ans_data = ans.Data(); - if (c->GetDeviceType() == kCpu) { + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { // a simple loop is faster, although the other branches should still work on - // CPU. + // CPU. MPS uses MTLStorageModeShared so data_ptr() is CPU-accessible, but + // pending Metal writes must be flushed first via synchronize(). 
+#ifdef K2_WITH_MPS + if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif int64_t elem_size = src[0]->ElementSize(); for (int32_t i = 0; i < num_arrays; ++i) { int32_t this_dim = src[i]->Dim(); @@ -740,7 +744,10 @@ void MonotonicLowerBound(const Array1 &src, Array1 *dest) { const S *src_data = src.Data(); T *dest_data = dest->Data(); - if (c->GetDeviceType() == kCpu) { + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { +#ifdef K2_WITH_MPS + if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif S min_value = std::numeric_limits::max(); for (int32_t i = dim - 1; i >= 0; --i) { min_value = std::min(src_data[i], min_value); @@ -779,7 +786,10 @@ void MonotonicDecreasingUpperBound(const Array1 &src, Array1 *dest) { const S *src_data = src.Data(); T *dest_data = dest->Data(); - if (c->GetDeviceType() == kCpu) { + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { +#ifdef K2_WITH_MPS + if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif S max_value = std::numeric_limits::min(); for (int32_t i = dim - 1; i >= 0; --i) { max_value = std::max(src_data[i], max_value); @@ -1038,7 +1048,12 @@ T Sum(ContextPtr c, const T *src, int32_t dim) { NVTX_RANGE(K2_FUNC); if (dim == 0) return 0; - if (c->GetDeviceType() == kCpu) return std::accumulate(src, src + dim, T(0)); + if (c->GetDeviceType() == kCpu || c->GetDeviceType() == kMps) { +#ifdef K2_WITH_MPS + if (c->GetDeviceType() == kMps) torch::mps::synchronize(); +#endif + return std::accumulate(src, src + dim, T(0)); + } K2_CHECK_EQ(c->GetDeviceType(), kCuda); diff --git a/k2/csrc/context.cu b/k2/csrc/context.cu index 978b92db5..3670110d4 100644 --- a/k2/csrc/context.cu +++ b/k2/csrc/context.cu @@ -66,7 +66,7 @@ ParallelRunnerActive::ParallelRunnerActive(ContextPtr c) : c_(c) { cudaStream_t ParallelRunnerActive::NewStream( std::size_t num_work_items /*=0*/) { DeviceType d = c_->GetDeviceType(); - if (d == kCpu) { + if (d == kCpu || d == kMps) { return 
kCudaStreamInvalid; } else { K2_CHECK_EQ(d, kCuda); diff --git a/k2/csrc/context.h b/k2/csrc/context.h index 75a34480f..4c5959abe 100644 --- a/k2/csrc/context.h +++ b/k2/csrc/context.h @@ -55,11 +55,13 @@ enum class DeviceType { kUnk, kCuda, kCpu, + kMps, }; constexpr DeviceType kUnk = DeviceType::kUnk; constexpr DeviceType kCuda = DeviceType::kCuda; constexpr DeviceType kCpu = DeviceType::kCpu; +constexpr DeviceType kMps = DeviceType::kMps; // Intended for use in debugging inline std::ostream &operator<<(std::ostream &stream, const DeviceType type) { @@ -73,6 +75,9 @@ inline std::ostream &operator<<(std::ostream &stream, const DeviceType type) { case kCpu: stream << "kCpu"; break; + case kMps: + stream << "kMps"; + break; default: K2_LOG(FATAL) << "Unreachable code!"; } @@ -336,6 +341,11 @@ ContextPtr GetCpuContext(); // CAUTION: If there are no CUDA capable GPUs, it returns a CPU context! ContextPtr GetCudaContext(int32_t gpu_id = -1); +// Return a Context object suitable for work with MPS (Metal Performance +// Shaders) on Apple Silicon. +// CAUTION: If MPS is not available, it returns a CPU context! +ContextPtr GetMpsContext(); + /* Returns a (CPU) context that will allocate pinned memory. (This is CPU memory that's pinned for faster GPU memory transfers). May or may not return the same value as ::k2::GetCpuContext()... 
this is so, for instance, diff --git a/k2/csrc/device_guard.h b/k2/csrc/device_guard.h index 2ef0efc82..02ffcd232 100644 --- a/k2/csrc/device_guard.h +++ b/k2/csrc/device_guard.h @@ -35,7 +35,7 @@ class DeviceGuard { new_device_ = c->GetDeviceId(); if (old_device_ != new_device_) SetDevice(new_device_); } - // else do nothing + // else do nothing (CPU and MPS contexts have no CUDA device to guard) } explicit DeviceGuard(int32_t new_device) : new_device_(new_device) { diff --git a/k2/csrc/macros.h b/k2/csrc/macros.h index 9d9c39084..bb5415704 100644 --- a/k2/csrc/macros.h +++ b/k2/csrc/macros.h @@ -61,7 +61,8 @@ You can replace the above code with `K2_EVAL` by the following code: #define K2_EVAL(context, dim, lambda_name, ...) \ do { \ - if (context->GetDeviceType() == kCpu) { \ + if (context->GetDeviceType() == kCpu || \ + context->GetDeviceType() == kMps) { \ auto lambda_name = [=] __VA_ARGS__; \ int32_t lambda_name##_dim = dim; \ for (int32_t i = 0; i != lambda_name##_dim; ++i) lambda_name(i); \ @@ -110,7 +111,8 @@ You can replace the above code with `K2_EVAL2` by the following code: */ #define K2_EVAL2(context, m, n, lambda_name, ...) \ do { \ - if (context->GetDeviceType() == kCpu) { \ + if (context->GetDeviceType() == kCpu || \ + context->GetDeviceType() == kMps) { \ auto lambda_name = [=] __VA_ARGS__; \ int32_t lambda_name##_m = m; \ int32_t lambda_name##_n = n; \ @@ -143,7 +145,8 @@ Here is an example: */ #define K2_TRANS_EXCSUM(context, dim, ans_data, lambda_name, ...) 
\ do { \ - if (context->GetDeviceType() == kCpu) { \ + if (context->GetDeviceType() == kCpu || \ + context->GetDeviceType() == kMps) { \ auto lambda_name = [=] __VA_ARGS__; \ int32_t lambda_name##_dim = dim; \ ans_data[0] = 0; \ diff --git a/k2/csrc/mps_utils.h b/k2/csrc/mps_utils.h new file mode 100644 index 000000000..5acfc3737 --- /dev/null +++ b/k2/csrc/mps_utils.h @@ -0,0 +1,142 @@ +/** + * Copyright 2026 k2-fsa Authors + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MPS-accelerated implementations of k2 utility operations using ATen ops. +// Only compiled when K2_WITH_MPS is defined. +// +// Priority-2 optimisations (PyTorch >= 2.2 MPS): +// • All intermediate int64 casts removed — MPS now supports int32 cumsum, +// diff, repeat_interleave, searchsorted, and bincount natively. +// • ExclusiveSumMps no longer allocates a zeroed intermediate tensor; +// it uses constant_pad_nd to prepend the leading 0 in one fused Metal op. +// • InclusiveSumMps reduced from 3 ATen ops to 1 (direct int32 cumsum). +// • RowSplitsToRowIdsMps diff cast removed; repeat_interleave uses int32 +// counts. +// • MaxSizeMps diff + max now operate in int32 throughout. +#pragma once + +#ifdef K2_WITH_MPS + +#include "torch/torch.h" +#include "k2/csrc/pytorch_context.h" + +namespace k2 { +namespace mps_ops { + +// Returns a Metal-safe 1-D view of `numel` elements of type `dtype` +// starting at `ptr`. 
Uses the global MPS registry (via MpsRegistryView) +// to find the PyTorch-owned base tensor and create a proper narrow() view. +inline torch::Tensor AsMpsTensor(void *ptr, int64_t numel, + torch::ScalarType dtype = torch::kInt32) { + return MpsRegistryView(ptr, numel, dtype); +} +inline torch::Tensor AsMpsTensor(const void *ptr, int64_t numel, + torch::ScalarType dtype = torch::kInt32) { + return MpsRegistryView(ptr, numel, dtype); +} + +// ExclusiveSumMps: dest[i] = sum_{j= 2.2). +inline void InclusiveSumMps(int32_t n, const int32_t *src, int32_t *dest) { + if (n == 0) return; + auto src_t = AsMpsTensor(src, (int64_t)n); + auto dst_t = AsMpsTensor(dest, (int64_t)n); + dst_t.copy_(src_t.cumsum(0)); +} + +// RowSplitsToRowIdsMps: row_ids[i] = j where +// row_splits[j] <= i < row_splits[j+1]. +// +// int32 diff and repeat_interleave(int32 counts) are both supported on MPS. +inline void RowSplitsToRowIdsMps(int32_t num_rows, const int32_t *row_splits, + int32_t num_elems, int32_t *row_ids) { + if (num_rows <= 0 || num_elems <= 0) return; + auto row_splits_t = AsMpsTensor(row_splits, (int64_t)(num_rows + 1)); + auto row_ids_t = AsMpsTensor(row_ids, (int64_t)num_elems); + auto counts = torch::diff(row_splits_t); // int32, length num_rows + auto arange = torch::arange((int64_t)num_rows, + torch::TensorOptions().dtype(torch::kInt32).device(torch::kMPS)); + row_ids_t.copy_(torch::repeat_interleave(arange, counts)); +} + +// RowIdsToRowSplitsMps: row_splits[j] = number of elements strictly before +// row j. +// +// int32 searchsorted is supported on MPS. 
+inline void RowIdsToRowSplitsMps(int32_t num_elems, const int32_t *row_ids, + int32_t num_rows, int32_t *row_splits) { + auto mps_i32 = torch::TensorOptions() + .dtype(torch::kInt32).device(torch::kMPS); + auto row_ids_t = AsMpsTensor(row_ids, (int64_t)num_elems); + auto row_splits_t = AsMpsTensor(row_splits, (int64_t)(num_rows + 1)); + auto boundaries = torch::arange((int64_t)num_rows, mps_i32); + // searchsorted returns int64; the copy_ into int32 row_splits_t will cast. + auto result = torch::searchsorted(row_ids_t.contiguous(), boundaries); + row_splits_t.slice(0, 0, num_rows).copy_(result); + row_splits_t.slice(0, num_rows).fill_(num_elems); +} + +// MaxSizeMps: max over (row_splits[i+1] - row_splits[i]) +// for i in [0, num_rows). +// +// int32 diff + max avoids the previous int64 round-trip. +inline int32_t MaxSizeMps(int32_t num_rows, const int32_t *row_splits) { + if (num_rows == 0) return 0; + auto row_splits_t = AsMpsTensor(row_splits, (int64_t)(num_rows + 1)); + return torch::diff(row_splits_t).max().item(); +} + +// GetCountsMps: ans[v] = number of times v appears in src[0..src_dim). +// +// bincount requires int64 input by design (PyTorch API constraint). 
+inline void GetCountsMps(const int32_t *src_data, int32_t src_dim, + int32_t *ans_data, int32_t n) { + if (n == 0) return; + auto src_t = AsMpsTensor(src_data, (int64_t)src_dim); + auto ans_t = AsMpsTensor(ans_data, (int64_t)n); + ans_t.copy_(torch::bincount(src_t.to(torch::kInt64), {}, (int64_t)n) + .to(torch::kInt32)); +} + +} // namespace mps_ops +} // namespace k2 + +#endif // K2_WITH_MPS diff --git a/k2/csrc/pinned_context.cu b/k2/csrc/pinned_context.cu index b46270a87..a55a6bdd0 100644 --- a/k2/csrc/pinned_context.cu +++ b/k2/csrc/pinned_context.cu @@ -360,6 +360,10 @@ ContextPtr GetContextForTransfer(DeviceType device_type) { return GetCpuContext(); case kCuda: return GetPinnedContext(); + case kMps: + // MPS uses unified memory on Apple Silicon; plain CPU context suffices + // for staging transfers. + return GetCpuContext(); default: K2_LOG(FATAL) << "Unsupported device type: " << device_type; return nullptr; diff --git a/k2/csrc/pytorch_context.cu b/k2/csrc/pytorch_context.cu index f732d9794..cfbac80ba 100644 --- a/k2/csrc/pytorch_context.cu +++ b/k2/csrc/pytorch_context.cu @@ -16,6 +16,7 @@ * limitations under the License. */ +#include #include #include // NOLINT @@ -25,12 +26,54 @@ #include "torch/cuda.h" #endif +#ifdef K2_WITH_MPS +#include "torch/mps.h" +#endif + #include "k2/csrc/context.h" #include "k2/csrc/device_guard.h" #include "k2/csrc/log.h" #include "k2/csrc/pytorch_context.h" namespace k2 { + +#ifdef K2_WITH_MPS +// Global registry mapping MPS data pointers to their base tensors. +// When k2 allocates MPS memory via PytorchMpsContext::Allocate(), the +// resulting tensor is registered here so that CPU→MPS copies can use +// PyTorch's Metal-safe copy_() instead of a raw memcpy. +// A raw memcpy to a Metal buffer bypasses PyTorch's hazard tracking and +// causes subsequent Metal operations (e.g. .to('cpu')) to crash. 
+static std::mutex g_mps_registry_mutex; +static std::map g_mps_registry; + +torch::Tensor MpsRegistryView(const void *ptr, int64_t n, + torch::ScalarType dtype) { + std::lock_guard lock(g_mps_registry_mutex); + const char *p = reinterpret_cast(ptr); + // Use upper_bound for O(log n) range lookup: find the first entry whose + // base pointer is strictly greater than ptr, then step back one to get + // the candidate allocation that may contain ptr. + auto it = g_mps_registry.upper_bound(const_cast(ptr)); + if (it != g_mps_registry.begin()) { + --it; + const char *base = reinterpret_cast(it->first); + int64_t base_bytes = it->second.nbytes(); + if (p >= base && p < base + base_bytes) { + ptrdiff_t byte_off = p - base; + int64_t elem_size = torch::elementSize(dtype); + K2_CHECK_EQ(byte_off % elem_size, 0) + << "Unaligned MPS pointer: byte_off=" << byte_off + << " elem_size=" << elem_size; + return it->second.view(dtype).narrow(0, byte_off / elem_size, n); + } + } + K2_LOG(FATAL) << "MPS pointer " << ptr + << " not found in registry — was it allocated by " + "PytorchMpsContext::Allocate()?"; + return {}; // unreachable +} +#endif // CAUTION: This is a workaround to free the CUDA memory // correctly if `PYTORCH_NO_CUDA_MEMORY_CACHING` is set. // @@ -130,6 +173,41 @@ class PytorchCpuContext : public Context { pinned_context->CopyDataTo(num_bytes, region->data, dst_context, dst); break; } +#ifdef K2_WITH_MPS + case kMps: { + // CPU -> MPS: a raw memcpy bypasses PyTorch's Metal hazard tracking, + // causing subsequent Metal operations (e.g. .to('cpu')) to crash. + // Use PyTorch's Metal-safe copy_() via the global registry instead. + torch::Tensor dst_base; + bool found = false; + { + std::lock_guard lock(g_mps_registry_mutex); + auto it = g_mps_registry.find(dst); + if (it != g_mps_registry.end()) { + dst_base = it->second; + found = true; + } + } + if (found) { + // Wrap the CPU source in a from_blob tensor (safe for CPU). 
+ auto src_cpu = torch::from_blob( + const_cast(src), {static_cast(num_bytes)}, + torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU)); + // Compute byte offset of dst within dst_base (0 for base pointers). + int64_t byte_off = static_cast( + static_cast(dst) - + static_cast(dst_base.data_ptr())); + // Use Metal copy_() to write into the MPS buffer. + dst_base.narrow(0, byte_off, static_cast(num_bytes)) + .copy_(src_cpu); + } else { + // Fallback for MPS regions not created by Allocate() (e.g. wrapped + // from a Python tensor via NewRegion(torch::Tensor)). + memcpy(dst, src, num_bytes); + } + break; + } +#endif default: K2_LOG(FATAL) << "Unsupported device type: " << device_type; break; @@ -153,7 +231,8 @@ class PytorchCudaContext : public Context { // so it is fine to invoke lazyInitCUDA() multiple times. // The call will be inlined since it is defined in the header // aten/src/ATen/Context.h -#if K2_TORCH_VERSION_MAJOR > 2 || (K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 6) +#if K2_TORCH_VERSION_MAJOR > 2 || \ + (K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 6) at::globalContext().lazyInitDevice(torch::kCUDA); #else at::globalContext().lazyInitCUDA(); @@ -249,6 +328,140 @@ class PytorchCudaContext : public Context { int32_t gpu_id_; }; +#ifdef K2_WITH_MPS +class PytorchMpsContext : public Context { + public: + PytorchMpsContext() {} + + DeviceType GetDeviceType() const override { return kMps; } + + // Apple Silicon has a single MPS device. + int32_t GetDeviceId() const override { return 0; } + + // MPS uses Metal command queues, not CUDA streams; return the sentinel so + // that Eval() in eval.h routes MPS through the CPU sequential loop. + cudaStream_t GetCudaStream() const override { return kCudaStreamInvalid; } + + void *Allocate(std::size_t bytes, void **deleter_context) override { + // Allocate via torch::Tensor so PyTorch's MPS memory manager owns the + // Metal buffer. 
The tensor is kept alive via ManagedTensor stored in + // deleter_context; Deallocate() drops it to release the buffer. + auto tensor = torch::empty({static_cast(bytes)}, + torch::TensorOptions() + .dtype(torch::kByte) + .device(torch::kMPS)); + void *p = tensor.data_ptr(); + if (deleter_context != nullptr) { + *deleter_context = new ManagedTensor(tensor); + // Register in the global registry so that CPU→MPS copies (in + // PytorchCpuContext::CopyDataTo) can use Metal-safe copy_() to write + // into this buffer rather than a raw memcpy. + std::lock_guard lock(g_mps_registry_mutex); + g_mps_registry[p] = tensor; + } else { + // Caller opted out of tracking; memory will be released when the + // tensor goes out of scope here. This should not happen in practice + // because k2 always passes a non-null deleter_context via NewRegion(). + K2_LOG(FATAL) << "PytorchMpsContext::Allocate called with null " + "deleter_context — MPS memory would be immediately " + "freed. This is a k2 bug."; + } + return p; + } + + void Deallocate(void *data, void *deleter_context) override { + if (deleter_context != nullptr) { + // Unregister from the global registry before freeing the tensor. + { + std::lock_guard lock(g_mps_registry_mutex); + g_mps_registry.erase(data); + } + // deleter_context holds a ManagedTensor; dropping it releases the MPS + // buffer back to PyTorch's allocator. + delete reinterpret_cast(deleter_context); + } else { + // Should not happen: every Allocate() stores a ManagedTensor. 
+ K2_LOG(FATAL) << "PytorchMpsContext::Deallocate called with null " + "deleter_context — cannot free MPS memory."; + } + } + + bool IsCompatible(const Context &other) const override { + return other.GetDeviceType() == kMps; + } + + void Sync() const override { torch::mps::synchronize(); } + + void CopyDataTo(size_t num_bytes, const void *src, ContextPtr dst_context, + void *dst) override { + DeviceType device_type = dst_context->GetDeviceType(); + switch (device_type) { + case kCpu: { + // MPS -> CPU: PyTorch's MPS backend stores tensor data in Metal + // buffers that are NOT at data_ptr() on the CPU side. We must use + // Metal-aware copy_() rather than a raw memcpy. + torch::Tensor src_base; + { + std::lock_guard lock(g_mps_registry_mutex); + auto it = g_mps_registry.find(const_cast(src)); + if (it != g_mps_registry.end()) src_base = it->second; + } + if (src_base.defined()) { + int64_t byte_off = static_cast( + static_cast(src) - + static_cast(src_base.data_ptr())); + auto src_view = + src_base.narrow(0, byte_off, static_cast(num_bytes)); + // Create a CPU tensor view of dst and copy via Metal. + auto dst_cpu = torch::from_blob( + dst, {static_cast(num_bytes)}, + torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU)); + dst_cpu.copy_(src_view); + } else { + // Fallback for regions not created by Allocate() (e.g. wrapped from + // a Python tensor). Data may be stale if copy_() was used for + // the CPU->MPS write. + torch::mps::synchronize(); + memcpy(dst, src, num_bytes); + } + break; + } + case kMps: { + // MPS -> MPS: use Metal copy_() to stay within Metal's hazard + // tracking system. 
+ torch::Tensor src_base, dst_base; + { + std::lock_guard<std::mutex> lock(g_mps_registry_mutex); + auto sit = g_mps_registry.find(const_cast<void *>(src)); + if (sit != g_mps_registry.end()) src_base = sit->second; + auto dit = g_mps_registry.find(dst); + if (dit != g_mps_registry.end()) dst_base = dit->second; + } + if (src_base.defined() && dst_base.defined()) { + int64_t src_off = static_cast<int64_t>( + static_cast<const char *>(src) - + static_cast<const char *>(src_base.data_ptr())); + int64_t dst_off = static_cast<int64_t>( + static_cast<char *>(dst) - + static_cast<char *>(dst_base.data_ptr())); + auto src_view = + src_base.narrow(0, src_off, static_cast<int64_t>(num_bytes)); + dst_base.narrow(0, dst_off, static_cast<int64_t>(num_bytes)) + .copy_(src_view); + } else { + // Fallback: memcpy for regions not in the registry. + memcpy(dst, src, num_bytes); + } + break; + } + default: + K2_LOG(FATAL) << "Unsupported device type: " << device_type; + break; + } + } +}; +#endif // K2_WITH_MPS + ContextPtr GetCpuContext() { return std::make_shared<PytorchCpuContext>(); } ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) { @@ -268,15 +481,35 @@ ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) { return GetCpuContext(); } +ContextPtr GetMpsContext() { +#ifdef K2_WITH_MPS + if (torch::mps::is_available()) { + // Trigger lazy MPS backend initialization so the allocator is registered + // before PytorchMpsContext tries to fetch it. + torch::empty({0}, torch::TensorOptions().device(torch::kMPS)); + return std::make_shared<PytorchMpsContext>(); + } + K2_LOG(WARNING) << "MPS is not available. Falling back to CPU context."; +#else + K2_LOG(WARNING) << "k2 was not compiled with MPS support.
" "Falling back to CPU context."; +#endif + return GetCpuContext(); +} + RegionPtr NewRegion(torch::Tensor tensor) { auto ans = std::make_shared<Region>(); if (tensor.device().type() == torch::kCPU) { ans->context = GetCpuContext(); } else if (tensor.is_cuda()) { ans->context = GetCudaContext(tensor.device().index()); +#ifdef K2_WITH_MPS + } else if (tensor.device().type() == torch::kMPS) { + ans->context = GetMpsContext(); +#endif } else { K2_LOG(FATAL) << "Unsupported device: " << tensor.device() - << "\nOnly CPU and CUDA are supported"; + << "\nOnly CPU, CUDA, and MPS are supported"; } // NOTE: the tensor is passed from Python and we have diff --git a/k2/csrc/pytorch_context.h b/k2/csrc/pytorch_context.h index 3ae34841c..bd03523a0 100644 --- a/k2/csrc/pytorch_context.h +++ b/k2/csrc/pytorch_context.h @@ -31,6 +31,10 @@ class ManagedTensor { public: explicit ManagedTensor(torch::Tensor tensor) : handle_(tensor) {} + // Return the underlying tensor. Used by ToTorch() for MPS arrays to + // create a proper MPS tensor view rather than an unsafe from_blob(). + const torch::Tensor &tensor() const { return handle_; } + private: torch::Tensor handle_; // retain a copy of the tensor passed from Python }; @@ -41,6 +45,21 @@ class ManagedTensor { // the given tensor. RegionPtr NewRegion(torch::Tensor tensor); +// Returns a context for the MPS (Metal Performance Shaders) device. +// On non-Apple platforms or when MPS is unavailable, falls back to CPU. +ContextPtr GetMpsContext(); + +#ifdef K2_WITH_MPS +// Finds the k2-allocated MPS base tensor whose byte range contains `ptr` +// and returns a Metal-safe view of `n` elements of type `dtype`. +// The returned tensor is a proper PyTorch MPS view (not from_blob), so +// ATen ops dispatched on it correctly use the Metal command queue. +// Aborts if `ptr` is not found in the registry (i.e. was not allocated +// by PytorchMpsContext::Allocate).
+torch::Tensor MpsRegistryView(const void *ptr, int64_t n, + torch::ScalarType dtype); +#endif // K2_WITH_MPS + } // namespace k2 #endif // K2_CSRC_PYTORCH_CONTEXT_H_ diff --git a/k2/csrc/ragged.cu b/k2/csrc/ragged.cu index f76220d53..eb5e0a93e 100644 --- a/k2/csrc/ragged.cu +++ b/k2/csrc/ragged.cu @@ -23,6 +23,9 @@ #include "k2/csrc/macros.h" #include "k2/csrc/math.h" #include "k2/csrc/ragged.h" +#ifdef K2_WITH_MPS +#include "k2/csrc/mps_utils.h" +#endif namespace { @@ -142,6 +145,11 @@ int32_t RaggedShape::MaxSize(int32_t axis) { if (num_rows == 0) return 0; const int32_t *row_splits_data = row_splits.Data(); ContextPtr c = Context(); +#ifdef K2_WITH_MPS + if (c->GetDeviceType() == kMps) { + return mps_ops::MaxSizeMps(num_rows, row_splits_data); + } +#endif if (c->GetDeviceType() == kCpu) { int32_t max_value = 0; for (int32_t i = 0; i < num_rows; ++i) { diff --git a/k2/csrc/torch_util.cu b/k2/csrc/torch_util.cu index 71d3ce4d3..916911208 100644 --- a/k2/csrc/torch_util.cu +++ b/k2/csrc/torch_util.cu @@ -30,6 +30,10 @@ torch::DeviceType ToTorchDeviceType(DeviceType type) { return torch::kCUDA; case kCpu: return torch::kCPU; +#ifdef K2_WITH_MPS + case kMps: + return torch::kMPS; +#endif case kUnk: // fall-through default: K2_LOG(FATAL) << "kUnk is not supported!"; @@ -43,9 +47,18 @@ DeviceType FromTorchDeviceType(const torch::DeviceType &type) { return kCuda; case torch::kCPU: return kCpu; +#ifdef K2_WITH_MPS + case torch::kMPS: + return kMps; +#endif default: K2_LOG(FATAL) << "Unsupported device type: " << type +#ifdef K2_WITH_MPS + << ". Only torch::kCUDA, torch::kCPU, and torch::kMPS " + "are supported"; +#else << ". 
Only torch::kCUDA and torch::kCPU are supported"; +#endif return kUnk; // unreachable code } } @@ -97,6 +110,24 @@ torch::Tensor ToTorch(Array1<Arc> &array) { auto options = torch::device(device).dtype(scalar_type); if (array.Dim() == 0) return torch::empty({0, 4}, options); +#ifdef K2_WITH_MPS + if (device_type == torch::kMPS) { + // Use a proper view of ManagedTensor rather than from_blob (see comment + // in the Array1 specialization in torch_util.h for details). + auto region = array.GetRegion(); + K2_CHECK(region && region->deleter_context != nullptr) + << "MPS Array1 region has no ManagedTensor"; + auto *mt = reinterpret_cast<ManagedTensor *>(region->deleter_context); + const torch::Tensor &base = mt->tensor(); + ptrdiff_t byte_off = reinterpret_cast<const char *>(array.Data()) - + reinterpret_cast<const char *>(region->data); + auto int32_base = base.view(scalar_type); // [bytes/4] kInt32 + int64_t elem_off = static_cast<int64_t>(byte_off) / sizeof(int32_t); + // Narrow to [elem_off .. elem_off + dim*4], then reshape to [dim, 4]. + return int32_base.narrow(0, elem_off, array.Dim() * 4).view(sizes); + } +#endif + // NOTE: we keep a copy of `Region` inside the lambda // so that the returned tensor outlives the input array. return torch::from_blob( @@ -156,6 +187,9 @@ torch::Tensor ToTorch(Tensor &tensor) { ContextPtr GetContext(torch::Device device) { if (device.type() == torch::kCPU) return GetCpuContext(); +#ifdef K2_WITH_MPS + if (device.type() == torch::kMPS) return GetMpsContext(); +#endif K2_CHECK_EQ(device.type(), torch::kCUDA); return GetCudaContext(device.index()); diff --git a/k2/csrc/torch_util.h b/k2/csrc/torch_util.h index 1e48f9173..8e3dfbde0 100644 --- a/k2/csrc/torch_util.h +++ b/k2/csrc/torch_util.h @@ -88,14 +88,28 @@ torch::Tensor ToTorch(Array1<T> &array) { auto device = torch::Device(device_type, device_id); auto scalar_type = ToScalarType<T>::value; auto options = torch::device(device).dtype(scalar_type); - // We will call torch::from_blob below.
However, if we - // call it with an empty Array1, we'll get error: - // RuntimeError: CUDA error: invalid argument Exception raised from - // getDeviceFromPtr at /pytorch/aten/src/ATen/cuda/CUDADevice.h - // Definitely we need look into this, but let's just return an empty tensor - // when the input Array1 is empty for now. if (array.Dim() == 0) return torch::empty(0, options); +#ifdef K2_WITH_MPS + if (device_type == torch::kMPS) { + // torch::from_blob() with a custom deleter bypasses PyTorch's MPS Metal + // buffer tracking, causing crashes on subsequent MPS→CPU transfers. + // Return a proper view of the ManagedTensor instead. + auto region = array.GetRegion(); + K2_CHECK(region && region->deleter_context != nullptr) + << "MPS Array1 region has no ManagedTensor — cannot create safe view"; + auto *mt = reinterpret_cast(region->deleter_context); + const torch::Tensor &base = mt->tensor(); + // Byte offset from the region base to the first array element. + ptrdiff_t byte_off = reinterpret_cast(array.Data()) - + reinterpret_cast(region->data); + // Reinterpret the base tensor as scalar_type and narrow to the slice. + auto t_view = base.view(scalar_type); + int64_t elem_off = static_cast(byte_off) / sizeof(T); + return t_view.narrow(0, elem_off, array.Dim()); + } +#endif + // NOTE: we keep a copy of `Region` inside the lambda // so that `torch::Tensor` always accesses valid memory. 
return torch::from_blob( diff --git a/k2/csrc/utils.cu b/k2/csrc/utils.cu index 66aa9861b..0a4ced9fa 100644 --- a/k2/csrc/utils.cu +++ b/k2/csrc/utils.cu @@ -25,6 +25,9 @@ #include "k2/csrc/moderngpu_allocator.h" #include "k2/csrc/nvtx.h" #include "k2/csrc/utils.h" +#ifdef K2_WITH_MPS +#include "k2/csrc/mps_utils.h" +#endif namespace k2 { @@ -78,6 +81,12 @@ void RowSplitsToRowIds(ContextPtr c, int32_t num_rows, NVTX_RANGE(K2_FUNC); if (num_rows <= 0 || num_elems <= 0) return; DeviceType d = c->GetDeviceType(); +#ifdef K2_WITH_MPS + if (d == kMps) { + mps_ops::RowSplitsToRowIdsMps(num_rows, row_splits, num_elems, row_ids); + return; + } +#endif if (d == kCpu) { int32_t cur_row_start = row_splits[0]; K2_CHECK_EQ(cur_row_start, 0); @@ -168,6 +177,12 @@ void RowIdsToRowSplits(ContextPtr c, int32_t num_elems, const int32_t *row_ids, return; } DeviceType d = c->GetDeviceType(); +#ifdef K2_WITH_MPS + if (d == kMps) { + mps_ops::RowIdsToRowSplitsMps(num_elems, row_ids, num_rows, row_splits); + return; + } +#endif if (d == kCpu) { int32_t cur_row = -1; for (int32_t i = 0; i < num_elems; i++) { diff --git a/k2/csrc/utils_inl.h b/k2/csrc/utils_inl.h index a8a4b3eb6..9ea32a49e 100644 --- a/k2/csrc/utils_inl.h +++ b/k2/csrc/utils_inl.h @@ -20,10 +20,14 @@ #ifndef K2_CSRC_UTILS_INL_H_ #define K2_CSRC_UTILS_INL_H_ +#include #include #include "k2/csrc/array.h" #include "k2/csrc/cub.h" +#ifdef K2_WITH_MPS +#include "k2/csrc/mps_utils.h" +#endif namespace k2 { template @@ -31,6 +35,29 @@ void ExclusiveSum(ContextPtr c, int32_t n, const SrcPtr src, DestPtr dest) { K2_CHECK_GE(n, 0); DeviceType d = c->GetDeviceType(); using SumType = typename std::decay::type; +#ifdef K2_WITH_MPS + if (d == kMps) { + // Dispatch to Metal-safe ATen cumsum. Only int32_t raw pointers are + // supported on MPS (all k2 row_splits / row_ids use int32_t). 
+ using RawSrc = std::decay_t<SrcPtr>; + using RawDest = std::decay_t<DestPtr>; + if constexpr (std::is_pointer_v<RawSrc> && std::is_pointer_v<RawDest> && + std::is_same_v< + std::remove_cv_t<std::remove_pointer_t<RawSrc>>, + int32_t> && + std::is_same_v< + std::remove_cv_t<std::remove_pointer_t<RawDest>>, + int32_t>) { + mps_ops::ExclusiveSumMps(n, + reinterpret_cast<const int32_t *>(src), + reinterpret_cast<int32_t *>(dest)); + } else { + K2_LOG(FATAL) + << "ExclusiveSum on MPS only supports int32_t raw pointers"; + } + return; + } +#endif if (d == kCpu) { SumType sum = 0; for (int32_t i = 0; i != n; ++i) { @@ -65,6 +92,27 @@ void InclusiveSum(ContextPtr c, int32_t n, const SrcPtr src, DestPtr dest) { K2_CHECK_GE(n, 0); DeviceType d = c->GetDeviceType(); using SumType = typename std::decay<decltype(dest[0])>::type; +#ifdef K2_WITH_MPS + if (d == kMps) { + using RawSrc = std::decay_t<SrcPtr>; + using RawDest = std::decay_t<DestPtr>; + if constexpr (std::is_pointer_v<RawSrc> && std::is_pointer_v<RawDest> && + std::is_same_v< + std::remove_cv_t<std::remove_pointer_t<RawSrc>>, + int32_t> && + std::is_same_v< + std::remove_cv_t<std::remove_pointer_t<RawDest>>, + int32_t>) { + mps_ops::InclusiveSumMps(n, + reinterpret_cast<const int32_t *>(src), + reinterpret_cast<int32_t *>(dest)); + } else { + K2_LOG(FATAL) + << "InclusiveSum on MPS only supports int32_t raw pointers"; + } + return; + } +#endif if (d == kCpu) { SumType sum = 0; for (int32_t i = 0; i != n; ++i) { @@ -89,6 +137,20 @@ template <typename T> T MaxValue(ContextPtr c, int32_t nelems, const T *t) { DeviceType d = c->GetDeviceType(); +#ifdef K2_WITH_MPS + if (d == kMps) { + // Use ATen reduction for Metal-safe access. In k2, MaxValue on MPS is + // always called with int32_t (row_splits); other types are not supported.
+ if constexpr (std::is_same_v<T, int32_t>) { + return static_cast<T>( + mps_ops::AsMpsTensor(t, static_cast<int64_t>(nelems)) + .max().template item<T>()); + } else { + K2_LOG(FATAL) << "MaxValue on MPS only supports int32_t"; + return T(0); // unreachable + } + } +#endif if (d == kCpu) { // note the return value is initialized with T(0) T result = T(0); diff --git a/k2/csrc/version.h.in b/k2/csrc/version.h.in index cfffaccf8..1d97cc5a9 100644 --- a/k2/csrc/version.h.in +++ b/k2/csrc/version.h.in @@ -76,6 +76,16 @@ static constexpr const char *kTorchCudaVersion = "@TORCH_CUDA_VERSION@"; static constexpr bool kWithCuda = false; #endif +#ifndef K2_WITH_MPS +#cmakedefine K2_WITH_MPS +#endif + +#ifdef K2_WITH_MPS + static constexpr bool kWithMps = true; +#else + static constexpr bool kWithMps = false; +#endif + // Indicate whether NVTX is enabled or not #ifndef K2_ENABLE_NVTX #cmakedefine K2_ENABLE_NVTX diff --git a/k2/python/csrc/CMakeLists.txt b/k2/python/csrc/CMakeLists.txt index 19e34fc0c..8204100ba 100644 --- a/k2/python/csrc/CMakeLists.txt +++ b/k2/python/csrc/CMakeLists.txt @@ -17,7 +17,19 @@ else() endif() if(NOT K2_WITH_CUDA) - transform(OUTPUT_VARIABLE k2_srcs SRCS ${k2_srcs}) + # Separate .mm (Objective-C++) files before transform() so they are not + # renamed to .cc (which would cause Objective-C syntax errors).
+ set(_mm_srcs) + set(_non_mm_srcs) + foreach(_src IN LISTS k2_srcs) + if(_src MATCHES "\\.mm$") + list(APPEND _mm_srcs ${_src}) + else() + list(APPEND _non_mm_srcs ${_src}) + endif() + endforeach() + transform(OUTPUT_VARIABLE _non_mm_srcs SRCS ${_non_mm_srcs}) + set(k2_srcs ${_non_mm_srcs} ${_mm_srcs}) endif() if(WIN32) @@ -33,6 +45,10 @@ pybind11_add_module(_k2 ${k2_srcs}) target_link_libraries(_k2 PRIVATE context) target_link_libraries(_k2 PRIVATE fsa) +if(APPLE AND K2_WITH_MPS) + target_link_libraries(_k2 PRIVATE "-framework Metal" "-framework Foundation") +endif() + if(APPLE) # To fix the following error: # ImportError: /xxx/lib/_k2.cpython-38-x86_64-linux-gnu.so: undefined symbol: THPDtypeType diff --git a/k2/python/csrc/torch.h b/k2/python/csrc/torch.h index 396a82bd6..f1095de2b 100644 --- a/k2/python/csrc/torch.h +++ b/k2/python/csrc/torch.h @@ -96,19 +96,32 @@ namespace k2 { template PyClass To(PyClass &pyclass, py::object device) { std::string device_type = static_cast(device.attr("type")); +#ifdef K2_WITH_MPS + K2_CHECK(device_type == "cpu" || device_type == "cuda" || + device_type == "mps") + << "Unsupported device type: " << device_type; +#else K2_CHECK(device_type == "cpu" || device_type == "cuda") << "Unsupported device type: " << device_type; +#endif ContextPtr &context = pyclass.Context(); if (device_type == "cpu") { // CPU to CPU if (context->GetDeviceType() == kCpu) return pyclass; - // CUDA to CPU + // CUDA/MPS to CPU DeviceGuard guard(context); return pyclass.To(GetCpuContext()); } +#ifdef K2_WITH_MPS + if (device_type == "mps") { + if (context->GetDeviceType() == kMps) return pyclass; + return pyclass.To(GetMpsContext()); + } +#endif + auto index_attr = static_cast(device.attr("index")); int32_t device_index = 0; if (!index_attr.is_none()) device_index = static_cast(index_attr); @@ -118,7 +131,7 @@ PyClass To(PyClass &pyclass, py::object device) { // CUDA to CUDA return pyclass; - // CPU to CUDA + // CPU/MPS to CUDA DeviceGuard 
guard(device_index); return pyclass.To(GetCudaContext(device_index)); } diff --git a/k2/python/csrc/torch/CMakeLists.txt b/k2/python/csrc/torch/CMakeLists.txt index 8c6803f8a..3a00507f3 100644 --- a/k2/python/csrc/torch/CMakeLists.txt +++ b/k2/python/csrc/torch/CMakeLists.txt @@ -25,6 +25,11 @@ if (K2_WITH_CUDA) list(APPEND torch_srcs mutual_information_cuda.cu) endif() +if (K2_WITH_MPS) + list(APPEND torch_srcs mutual_information_mps.mm) + list(APPEND torch_srcs mps_fsa_scores.mm) +endif() + set(torch_srcs_with_prefix) foreach(src IN LISTS torch_srcs) list(APPEND torch_srcs_with_prefix "torch/${src}") diff --git a/k2/python/csrc/torch/arc.cu b/k2/python/csrc/torch/arc.cu index e264cafbc..f9b6d387b 100644 --- a/k2/python/csrc/torch/arc.cu +++ b/k2/python/csrc/torch/arc.cu @@ -80,9 +80,10 @@ static void PybindArcImpl(py::module &m) { if (tensor.numel() == 0) return torch::empty(tensor.sizes(), tensor.options().dtype(scalar_type)); - return torch::from_blob( - tensor.data_ptr(), tensor.sizes(), tensor.strides(), - [tensor](void *p) {}, tensor.options().dtype(scalar_type)); + // Use view() rather than from_blob so the result is a proper tensor + // view that shares storage with the input. from_blob with a custom + // no-op deleter breaks on MPS because Metal buffer metadata is lost. 
+ return tensor.view(scalar_type); }, py::arg("tensor")); @@ -94,9 +95,7 @@ static void PybindArcImpl(py::module &m) { if (tensor.numel() == 0) return torch::empty(tensor.sizes(), tensor.options().dtype(scalar_type)); - return torch::from_blob( - tensor.data_ptr(), tensor.sizes(), tensor.strides(), - [tensor](void *p) {}, tensor.options().dtype(scalar_type)); + return tensor.view(scalar_type); }, py::arg("tensor")); } diff --git a/k2/python/csrc/torch/fsa.cu b/k2/python/csrc/torch/fsa.cu index 1c8bf078f..89b71dcd0 100644 --- a/k2/python/csrc/torch/fsa.cu +++ b/k2/python/csrc/torch/fsa.cu @@ -37,6 +37,9 @@ #include "k2/csrc/torch_util.h" #include "k2/python/csrc/torch/fsa.h" #include "k2/python/csrc/torch/v2/ragged_any.h" +#ifdef K2_WITH_MPS +#include "k2/python/csrc/torch/mps_fsa_scores.h" +#endif namespace k2 { @@ -276,9 +279,9 @@ static void PybindGetForwardScores(py::module &m, const char *name) { DeviceGuard guard(fsas.Context()); Array1 entering_arcs; Array1 scores = GetForwardScores( - fsas, state_batches.any.Specialize(), - entering_arc_batches.any.Specialize(), log_semiring, - log_semiring ? nullptr : &entering_arcs); + fsas, state_batches.any.Specialize(), + entering_arc_batches.any.Specialize(), log_semiring, + log_semiring ? nullptr : &entering_arcs); torch::optional entering_arcs_tensor; if (!log_semiring) entering_arcs_tensor = ToTorch(entering_arcs); @@ -289,6 +292,70 @@ static void PybindGetForwardScores(py::module &m, const char *name) { py::arg("entering_arc_batches"), py::arg("log_semiring")); } +#ifdef K2_WITH_MPS +// PybindGetForwardScoresMps: MPS-native forward scores via native Metal kernel. +// +// Called from Python when fsas is on MPS. The caller computes +// entering_arc_batches on a CPU copy of the FSA and passes it here; +// this function copies the arc IDs to MPS and dispatches Metal compute +// kernels for each BFS batch. 
+static void PybindGetForwardScoresMps(py::module &m) { + m.def( + "get_forward_scores_mps", + [](FsaVec &fsas, RaggedAny &entering_arc_batches_cpu, + bool log_semiring) -> torch::Tensor { + DeviceGuard guard(fsas.Context()); + Array1 scores = mps_ops::GetForwardScoresMps( + fsas, entering_arc_batches_cpu.any.Specialize(), + log_semiring); + return ToTorch(scores); + }, + py::arg("fsas"), py::arg("entering_arc_batches_cpu"), + py::arg("log_semiring")); +} + +// PybindGetForwardScoresMpsNative: zero-copy MPS-native forward scores. +// +// Called from Python when fsas is on MPS. The caller supplies arc indices +// already sorted by BFS level (sorted_arc_ids, MPS int32 tensor) and a +// Python list of per-level arc counts (batch_sizes). This avoids the full +// FSA CPU copy required by get_forward_scores_mps. +static void PybindGetForwardScoresMpsNative(py::module &m) { + m.def( + "get_forward_scores_mps_native", + [](FsaVec &fsas, torch::Tensor sorted_arc_ids, + std::vector batch_sizes, + bool log_semiring) -> torch::Tensor { + DeviceGuard guard(fsas.Context()); + Array1 scores = mps_ops::GetForwardScoresMpsNative( + fsas, sorted_arc_ids, batch_sizes, log_semiring); + return ToTorch(scores); + }, + py::arg("fsas"), py::arg("sorted_arc_ids"), py::arg("batch_sizes"), + py::arg("log_semiring")); +} + +// PybindGetForwardScoresMpsAssocScan: O(log N) associative-scan forward scores. +// +// Uses a Hillis-Steele inclusive prefix scan over per-state transition matrices +// (tropical semiring). Falls back internally to the native sequential path +// when conditions (single FSA, 4 ≤ N ≤ 128, tropical semiring) are not met. 
+static void PybindGetForwardScoresMpsAssocScan(py::module &m) { + m.def( + "get_forward_scores_mps_assoc_scan", + [](FsaVec &fsas, torch::Tensor sorted_arc_ids, + std::vector batch_sizes, + bool log_semiring) -> torch::Tensor { + DeviceGuard guard(fsas.Context()); + Array1 scores = mps_ops::GetForwardScoresMpsAssocScan( + fsas, sorted_arc_ids, batch_sizes, log_semiring); + return ToTorch(scores); + }, + py::arg("fsas"), py::arg("sorted_arc_ids"), py::arg("batch_sizes"), + py::arg("log_semiring")); +} +#endif // K2_WITH_MPS + template static void PybindBackpropGetForwardScores(py::module &m, const char *name) { // entering_arcs is not empty only if log_semiring is false @@ -687,6 +754,11 @@ void PybindFsa(py::module &m) { k2::PybindFsaBasicProperties(m); k2::PybindGetForwardScores(m, "get_forward_scores_float"); k2::PybindGetForwardScores(m, "get_forward_scores_double"); +#ifdef K2_WITH_MPS + k2::PybindGetForwardScoresMps(m); + k2::PybindGetForwardScoresMpsNative(m); + k2::PybindGetForwardScoresMpsAssocScan(m); +#endif k2::PybindBackpropGetForwardScores( m, "backprop_get_forward_scores_float"); k2::PybindBackpropGetForwardScores( diff --git a/k2/python/csrc/torch/mps_fsa_scores.h b/k2/python/csrc/torch/mps_fsa_scores.h new file mode 100644 index 000000000..8b7d56be7 --- /dev/null +++ b/k2/python/csrc/torch/mps_fsa_scores.h @@ -0,0 +1,95 @@ +/** + * Copyright 2026 k2-fsa Authors + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// MPS-accelerated GetForwardScores for float32 / log semiring. +// Only compiled when K2_WITH_MPS is defined. +// +// The Metal kernel dispatches one thread per entering arc per BFS batch. +// Each thread atomically updates state_scores[dst] via a CAS logadd loop, +// avoiding the intermediate entering_arc_batch_scores array used by the +// CPU sequential path. Sequential BFS ordering is maintained by the +// Metal command queue's in-order execution — no CPU barriers needed. +#pragma once + +#ifdef K2_WITH_MPS + +// k2 headers must be included before Metal/MPS headers to prevent +// TORCH_ASSERT_ONLY_METHOD_OPERATORS conflicts with aten_interned_strings.h. +#include +#include "k2/csrc/array.h" +#include "k2/csrc/fsa.h" +#include "k2/csrc/ragged.h" +#include // NOLINT(build/include_order) + +namespace k2 { +namespace mps_ops { + +// GetForwardScoresMps: MPS-native forward pass for float32 (log or tropical). +// +// Dispatches a Metal compute kernel for each BFS batch. The kernel +// evaluates state_scores[dst] = logsumexp(state_scores[src] + arc.score) +// (log semiring) or max(state_scores[dst], state_scores[src] + arc.score) +// (tropical semiring) over all entering arcs in parallel using atomic CAS. +// +// `entering_arc_batches_cpu` must have CPU context — compute it on a CPU copy +// of the FSA. `fsas` must have MPS context. The arc IDs are copied to MPS +// once before the kernel loop. +Array1 GetForwardScoresMps(FsaVec &fsas, + Ragged &entering_arc_batches_cpu, + bool log_semiring); + +// GetForwardScoresMpsNative: zero-copy MPS-native forward pass. +// +// Like GetForwardScoresMps but accepts pre-sorted arc IDs already on MPS and +// a CPU-side batch_sizes vector. 
This avoids the full FSA CPU copy: the +// caller only transfers arc.dest_state (4 bytes × num_arcs) to CPU for +// sorting, then moves the sorted indices (4 bytes × num_arcs) back to MPS. +// +// `sorted_arc_ids` — int32 MPS tensor: arc indices sorted by BFS level. +// `batch_sizes` — number of arcs per BFS level (may include zeros). +// `fsas` — must have MPS context. +Array1 GetForwardScoresMpsNative( + FsaVec &fsas, + torch::Tensor sorted_arc_ids, + const std::vector &batch_sizes, + bool log_semiring); + +// GetForwardScoresMpsAssocScan: O(log N) associative-scan forward pass. +// +// Uses a Hillis-Steele inclusive prefix scan over N per-state transition +// matrices (N×N each, tropical semiring) to compute all state forward scores +// in ⌈log₂N⌉ Metal encoder calls instead of N sequential calls. +// +// Falls back to GetForwardScoresMpsNative when: +// • log_semiring is true (log semiring not yet implemented) +// • num_fsas != 1 (multi-FSA not supported — would need independent scans) +// • num_states < 4 or > 128 (outside the beneficial range) +// +// `sorted_arc_ids` — int32 MPS tensor: arc indices sorted by dest_state_local. +// `batch_sizes` — number of arcs per dest_state_local value (one per state). +// `fsas` — must have MPS context, exactly 1 FSA. +Array1 GetForwardScoresMpsAssocScan( + FsaVec &fsas, + torch::Tensor sorted_arc_ids, + const std::vector &batch_sizes, + bool log_semiring); + +} // namespace mps_ops +} // namespace k2 + +#endif // K2_WITH_MPS diff --git a/k2/python/csrc/torch/mps_fsa_scores.mm b/k2/python/csrc/torch/mps_fsa_scores.mm new file mode 100644 index 000000000..0c9a76f6e --- /dev/null +++ b/k2/python/csrc/torch/mps_fsa_scores.mm @@ -0,0 +1,793 @@ +/** + * Copyright 2026 k2-fsa Authors + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Objective-C++ — must be compiled with clang as .mm on macOS. + +// k2 headers first: prevents TORCH_ASSERT_ONLY_METHOD_OPERATORS conflict +// with ATen/native/mps/OperationUtils.h (same fix as mutual_information_mps.mm). +#include "k2/python/csrc/torch/mps_fsa_scores.h" +#include "k2/csrc/mps_utils.h" // AsMpsTensor / MpsRegistryView +#include "k2/csrc/ragged_ops.h" // RaggedAxis0Splitter +#include "k2/csrc/pytorch_context.h" // kMps + +#import +#import +#include +#include // getMTLBufferStorage + +#include + +// --------------------------------------------------------------------------- +// Embedded Metal Shading Language kernel source +// --------------------------------------------------------------------------- +// Two kernels: log semiring (logadd) and tropical semiring (max). +// Each thread handles one entering arc in the current BFS batch. +// +// Layout note: FsaArc must match k2's Arc struct exactly (16 bytes): +// { int src_state, int dest_state, int label, float score } +// +// Atomic float CAS pattern: we store float bits in device atomic_int; +// as_type/as_type perform bitwise reinterpretation without +// any value conversion (equivalent to memcpy / union in C). +static const char *kFsaKernelSrc = R"MSL( +#include +using namespace metal; + +// Must match k2::Arc layout: 4 × 4 bytes = 16 bytes, no padding. +struct FsaArc { + int src_state; + int dest_state; + int label; + float score; +}; + +// Numerically stable log(exp(a) + exp(b)) without log1p (not in Metal stdlib). 
+inline float log_add(float a, float b) { + if (isinf(a) && a < 0.0f) return b; + if (isinf(b) && b < 0.0f) return a; + float hi = max(a, b); + float lo = min(a, b); + // lo - hi in (-inf, 0], so exp(...) in [0, 1], log(1+...) in [0, log2]. + return hi + log(1.0f + exp(lo - hi)); +} + +// --------------------------------------------------------------------------- +// fsa_forward_log: log semiring (logsumexp) forward pass. +// +// For each entering arc in the current BFS batch (one thread per arc): +// candidate = state_scores[src] + arc.score +// state_scores[dst] = log(exp(state_scores[dst]) + exp(candidate)) +// +// state_scores_in and state_scores_out alias the same buffer; reads are from +// previously-committed states (BFS layers guarantee disjoint src/dst sets). +// --------------------------------------------------------------------------- +kernel void fsa_forward_log( + device const int* entering_arc_ids [[buffer(0)]], // batch arc indices + device const FsaArc* arcs [[buffer(1)]], // all FSA arcs + device const int* arc_to_src [[buffer(2)]], // fsas row_ids2 (global src state) + device const float* scores_in [[buffer(3)]], // state_scores (read) + device float* scores_out [[buffer(4)]], // state_scores (write) + constant int& n_arcs [[buffer(5)]], + uint gid [[thread_position_in_grid]]) +{ + if ((int)gid >= n_arcs) return; + + int arc_idx = entering_arc_ids[gid]; + int src = arc_to_src[arc_idx]; // global source state + // dest_state in Arc is LOCAL; convert to global via the FSA's state offset. 
+ // offset = global_src - local_src = arc_to_src[arc_idx] - arcs[arc_idx].src_state + int dst = arcs[arc_idx].dest_state + (src - arcs[arc_idx].src_state); + float arc_w = arcs[arc_idx].score; + float src_s = scores_in[src]; + + if (isinf(src_s) && src_s < 0.0f) return; // -inf source: no contribution + + float candidate = src_s + arc_w; + if (isnan(candidate)) return; // guard: NaN arc scores must not spin the CAS + + // Atomic logadd: CAS loop over bit pattern of the float. + device atomic_int* slot = + reinterpret_cast(scores_out + dst); + int old_bits = atomic_load_explicit(slot, memory_order_relaxed); + int new_bits; + do { + float old_val = as_type(old_bits); + float new_val = log_add(old_val, candidate); + new_bits = as_type(new_val); + if (old_bits == new_bits) return; // no change (convergence guard) + } while (!atomic_compare_exchange_weak_explicit( + slot, &old_bits, new_bits, + memory_order_relaxed, memory_order_relaxed)); +} + +// --------------------------------------------------------------------------- +// fsa_forward_tropical: tropical semiring (max) forward pass. +// +// candidate = state_scores[src] + arc.score +// state_scores[dst] = max(state_scores[dst], candidate) +// --------------------------------------------------------------------------- +kernel void fsa_forward_tropical( + device const int* entering_arc_ids [[buffer(0)]], + device const FsaArc* arcs [[buffer(1)]], + device const int* arc_to_src [[buffer(2)]], // global src state + device const float* scores_in [[buffer(3)]], + device float* scores_out [[buffer(4)]], + constant int& n_arcs [[buffer(5)]], + uint gid [[thread_position_in_grid]]) +{ + if ((int)gid >= n_arcs) return; + + int arc_idx = entering_arc_ids[gid]; + int src = arc_to_src[arc_idx]; + // dest_state in Arc is LOCAL; convert to global via FSA offset. 
+ int dst = arcs[arc_idx].dest_state + (src - arcs[arc_idx].src_state); + float arc_w = arcs[arc_idx].score; + float src_s = scores_in[src]; + + if (isinf(src_s) && src_s < 0.0f) return; + + float candidate = src_s + arc_w; + if (isnan(candidate)) return; // guard: NaN arc scores must not spin the CAS + + // Atomic max via CAS (float comparison, not int comparison). + device atomic_int* slot = + reinterpret_cast(scores_out + dst); + int cand_bits = as_type(candidate); + int old_bits = atomic_load_explicit(slot, memory_order_relaxed); + while (true) { + float old_val = as_type(old_bits); + if (old_val >= candidate) return; // already at least as large + if (atomic_compare_exchange_weak_explicit( + slot, &old_bits, cand_bits, + memory_order_relaxed, memory_order_relaxed)) + return; // CAS succeeded; old_bits updated on failure → retry + } +} +// =========================================================================== +// Priority 6 — Associative-scan (Hillis-Steele) prefix kernels +// =========================================================================== +// For a single-FSA FsaVec with N states, the BFS-level batches have the +// structure: batch_sizes[d] = #arcs entering dest_state_local == d. Each +// state d corresponds to exactly one matrix M[d] of size N×N in the tropical +// semiring (max-plus). A Hillis-Steele inclusive prefix scan over M[0..N-1] +// gives all prefix products P[s] = M[s] ⊗ … ⊗ M[0] in ⌈log₂N⌉ steps rather +// than N sequential encoder calls. State score: α[s] = P[s][s][0]. +// +// M[d] semantics (tropical): +// row == col && col != d → 0.0 (pass-through for other states) +// row == d → max arc weight for arcs src→d (or -inf if none) +// else → -INFINITY +// M[0] (start state, never entered): identity (row == col → 0, else -inf). + +// assoc_scan_init: initialize T_pow2 matrices of size N×N. +// mat < t_actual: actual level d — pass-through rows except row==d (all -inf) +// Special case d==0: identity. 
+//   mat >= t_actual: padding — identity.
+kernel void assoc_scan_init(
+    device float* M        [[buffer(0)]],  // [T_pow2 × N × N]
+    constant int& N        [[buffer(1)]],
+    constant int& t_pow2   [[buffer(2)]],
+    constant int& t_actual [[buffer(3)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+  int mat = (int)gid.z;
+  int row = (int)gid.y;
+  int col = (int)gid.x;
+  if (mat >= t_pow2 || row >= N || col >= N) return;
+  float val;
+  if (mat >= t_actual || mat == 0) {
+    // Padding or start state: identity
+    val = (row == col) ? 0.0f : -INFINITY;
+  } else {
+    // Actual level d: pass-through for states != d, -inf for state d row
+    val = (row == col && row != mat) ? 0.0f : -INFINITY;
+  }
+  M[mat * N * N + row * N + col] = val;
+}
+
+// assoc_scan_build_level: write arc weights into matrix M_level (= M[d]).
+// Uses atomic CAS max so concurrent arcs entering the same (dst,src) pair
+// correctly keep the largest weight.
+kernel void assoc_scan_build_level(
+    device const int*    arc_ids [[buffer(0)]],
+    device const FsaArc* arcs    [[buffer(1)]],
+    device float*        M_level [[buffer(2)]],  // [N × N] slice for level d
+    constant int&        n_arcs  [[buffer(3)]],
+    constant int&        N       [[buffer(4)]],
+    uint gid [[thread_position_in_grid]])
+{
+  if ((int)gid >= n_arcs) return;
+  int arc_idx = arc_ids[gid];
+  // Single-FSA path: local src/dest indices are global here.
+  int src = arcs[arc_idx].src_state;
+  int dst = arcs[arc_idx].dest_state;
+  float w = arcs[arc_idx].score;
+  if (isnan(w)) return;  // guard: NaN arc scores must not spin the CAS
+  device atomic_int* slot =
+      reinterpret_cast<device atomic_int*>(M_level + dst * N + src);
+  int w_bits = as_type<int>(w);
+  int old_bits = atomic_load_explicit(slot, memory_order_relaxed);
+  while (true) {
+    if (as_type<float>(old_bits) >= w) return;
+    if (atomic_compare_exchange_weak_explicit(
+            slot, &old_bits, w_bits,
+            memory_order_relaxed, memory_order_relaxed))
+      return;
+  }
+}
+
+// assoc_scan_prefix_step: one Hillis-Steele step (tropical semiring).
+//   For mat >= step_d: buf_out[mat] = buf_in[mat] ⊗ buf_in[mat - step_d]
+//   For mat <  step_d: buf_out[mat] = buf_in[mat]  (copy)
+// Grid: (N, N, T_pow2)
+kernel void assoc_scan_prefix_step(
+    device const float* buf_in  [[buffer(0)]],
+    device float*       buf_out [[buffer(1)]],
+    constant int&       N       [[buffer(2)]],
+    constant int&       t_pow2  [[buffer(3)]],
+    constant int&       step_d  [[buffer(4)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+  int mat = (int)gid.z;
+  int row = (int)gid.y;
+  int col = (int)gid.x;
+  if (mat >= t_pow2 || row >= N || col >= N) return;
+  int base = mat * N * N + row * N + col;
+  if (mat < step_d) {
+    buf_out[base] = buf_in[base];
+    return;
+  }
+  // B[mat][row][col] = max_k A[mat][row][k] + A[mat-step_d][k][col]
+  float best = -INFINITY;
+  for (int k = 0; k < N; k++) {
+    float a = buf_in[mat * N * N + row * N + k];
+    float b = buf_in[(mat - step_d) * N * N + k * N + col];
+    if (!isinf(a) && !isinf(b))
+      best = max(best, a + b);
+  }
+  buf_out[base] = best;
+}
+
+// assoc_scan_extract: alpha[s] = prefix_buf[s][s][0] for s = 0 .. N-1.
+kernel void assoc_scan_extract(
+    device const float* prefix_buf [[buffer(0)]],  // [T_pow2 × N × N]
+    device float*       alpha      [[buffer(1)]],  // [N]
+    constant int&       N          [[buffer(2)]],
+    uint gid [[thread_position_in_grid]])
+{
+  int s = (int)gid;
+  if (s >= N) return;
+  alpha[s] = prefix_buf[s * N * N + s * N + 0];
+}
+
+)MSL";
+
+// ---------------------------------------------------------------------------
+// Pipeline cache
+// ---------------------------------------------------------------------------
+struct FsaPipelines {
+  // __unsafe_unretained: skip ARC release on static destructor so we don't
+  // send messages to Metal objects after PyTorch has torn down the device.
+  __unsafe_unretained id<MTLComputePipelineState> log_fwd = nil;
+  __unsafe_unretained id<MTLComputePipelineState> tropical_fwd = nil;
+  __unsafe_unretained id<MTLComputePipelineState> assoc_init = nil;
+  __unsafe_unretained id<MTLComputePipelineState> assoc_build_level = nil;
+  __unsafe_unretained id<MTLComputePipelineState> assoc_prefix_step = nil;
+  __unsafe_unretained id<MTLComputePipelineState> assoc_extract = nil;
+};
+
+// Compile kFsaKernelSrc once per process and cache the six pipeline states.
+static FsaPipelines GetOrBuildFsaPipelines(id<MTLDevice> device) {
+  static FsaPipelines cache;
+  static dispatch_once_t once;
+  dispatch_once(&once, ^{
+    NSError *err = nil;
+    NSString *src = [NSString stringWithUTF8String:kFsaKernelSrc];
+    id<MTLLibrary> lib =
+        [device newLibraryWithSource:src options:nil error:&err];
+    K2_CHECK(lib != nil) << "Metal FSA library compile error: "
+                         << [[err localizedDescription] UTF8String];
+
+    auto make_pipeline = [&](const char *name) -> id<MTLComputePipelineState> {
+      id<MTLFunction> fn =
+          [lib newFunctionWithName:[NSString stringWithUTF8String:name]];
+      K2_CHECK(fn != nil) << "Metal function not found: " << name;
+      NSError *e2 = nil;
+      id<MTLComputePipelineState> ps =
+          [device newComputePipelineStateWithFunction:fn error:&e2];
+      K2_CHECK(ps != nil) << "Pipeline state error for " << name << ": "
+                          << [[e2 localizedDescription] UTF8String];
+      return ps;
+    };
+
+    cache.log_fwd = make_pipeline("fsa_forward_log");
+    cache.tropical_fwd = make_pipeline("fsa_forward_tropical");
+    cache.assoc_init = make_pipeline("assoc_scan_init");
+    cache.assoc_build_level = make_pipeline("assoc_scan_build_level");
+    cache.assoc_prefix_step = make_pipeline("assoc_scan_prefix_step");
+    cache.assoc_extract = make_pipeline("assoc_scan_extract");
+  });
+  return cache;
+}
+
+// ---------------------------------------------------------------------------
+// Helper: encode one batch kernel onto the current command buffer.
+// ---------------------------------------------------------------------------
+static void EncodeFsaForwardBatch(
+    id<MTLCommandBuffer> cmd_buf,
+    id<MTLComputePipelineState> pipeline,
+    // MTL buffers + byte offsets
+    id<MTLBuffer> buf_arc_ids, NSUInteger off_arc_ids,
+    id<MTLBuffer> buf_arcs,    NSUInteger off_arcs,
+    id<MTLBuffer> buf_src_ids, NSUInteger off_src_ids,
+    id<MTLBuffer> buf_scores,  NSUInteger off_scores,
+    int32_t n_arcs)
+{
+  id<MTLComputeCommandEncoder> enc = [cmd_buf computeCommandEncoder];
+  K2_CHECK(enc != nil) << "Failed to create MTLComputeCommandEncoder";
+
+  [enc setComputePipelineState:pipeline];
+  [enc setBuffer:buf_arc_ids offset:off_arc_ids atIndex:0];
+  [enc setBuffer:buf_arcs offset:off_arcs atIndex:1];
+  [enc setBuffer:buf_src_ids offset:off_src_ids atIndex:2];
+  [enc setBuffer:buf_scores offset:off_scores atIndex:3];  // read
+  [enc setBuffer:buf_scores offset:off_scores atIndex:4];  // write (same)
+  [enc setBytes:&n_arcs length:sizeof(int32_t) atIndex:5];
+
+  // 256 threads per threadgroup — works well for arc-level parallelism.
+  static const NSUInteger kThreadsPerGroup = 256;
+  NSUInteger num_groups =
+      ((NSUInteger)n_arcs + kThreadsPerGroup - 1) / kThreadsPerGroup;
+
+  [enc dispatchThreadgroups:MTLSizeMake(num_groups, 1, 1)
+      threadsPerThreadgroup:MTLSizeMake(kThreadsPerGroup, 1, 1)];
+  [enc endEncoding];
+}
+
+// ---------------------------------------------------------------------------
+// GetForwardScoresMps — public entry point
+// ---------------------------------------------------------------------------
+namespace k2 {
+namespace mps_ops {
+
+Array1<float> GetForwardScoresMps(FsaVec &fsas,
+                                  Ragged<int32_t> &entering_arc_batches_cpu,
+                                  bool log_semiring) {
+  ContextPtr &c = fsas.Context();
+  K2_CHECK_EQ(c->GetDeviceType(), kMps);
+
+  int32_t num_fsas = fsas.Dim0();
+  int32_t num_states = fsas.TotSize(1);
+  int32_t num_arcs = fsas.TotSize(2);
+  int32_t num_batches = entering_arc_batches_cpu.Dim0();
+  int32_t total_arc_ids = entering_arc_batches_cpu.NumElements();
+
+  const int32_t *fsa_row_splits1 = fsas.RowSplits(1).Data();  // MPS ptr
+  const int32_t *fsas_row_ids2 = fsas.RowIds(2).Data();       // MPS ptr
+  const Arc *arcs_ptr = fsas.values.Data();                   // MPS ptr
+
+  // ------------------------------------------------------------------
+  // 1. Allocate state_scores on MPS and initialise via ATen ops.
+  //    Array1(c, n, val) uses K2_EVAL → CPU sequential for MPS;
+  //    we skip that and do the fill directly through ATen.
+  // ------------------------------------------------------------------
+  Array1<float> state_scores(c, num_states);  // uninitialized allocation
+
+  auto scores_t = AsMpsTensor(state_scores.Data(),
+                              (int64_t)num_states, torch::kFloat);
+  scores_t.fill_(-std::numeric_limits<float>::infinity());
+
+  // Set start state of each non-empty FSA to 0.
+  // row_splits1[i] is the global state index of FSA i's start state.
+  if (num_fsas > 0) {
+    auto row_splits_t = AsMpsTensor(fsa_row_splits1,
+                                    (int64_t)(num_fsas + 1));  // int32
+    auto starts = row_splits_t.slice(0, 0, num_fsas);
+    auto ends = row_splits_t.slice(0, 1, num_fsas + 1);
+    auto nonempty = starts.ne(ends);
+    auto valid = starts.masked_select(nonempty).to(torch::kInt64);
+    scores_t.index_put_({valid}, 0.0f);
+  }
+
+  if (num_arcs == 0 || num_batches == 0 || total_arc_ids == 0)
+    return state_scores;
+
+  // ------------------------------------------------------------------
+  // 2. Obtain Metal device, build (or reuse) pipelines.
+  // ------------------------------------------------------------------
+  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+  K2_CHECK(device != nil) << "No Metal device available";
+  FsaPipelines pipelines = GetOrBuildFsaPipelines(device);
+
+  // ------------------------------------------------------------------
+  // 3. Get persistent MTL buffers for the MPS-resident arrays.
+  // ------------------------------------------------------------------
+  // Arc structs: view as bytes so offset is in bytes directly.
+  auto arcs_t = AsMpsTensor(arcs_ptr,
+                            (int64_t)num_arcs * (int64_t)sizeof(Arc),
+                            torch::kByte);
+  id<MTLBuffer> buf_arcs = at::native::mps::getMTLBufferStorage(arcs_t);
+  NSUInteger off_arcs = (NSUInteger)(arcs_t.storage_offset());
+
+  auto src_t = AsMpsTensor(fsas_row_ids2, (int64_t)num_arcs);  // int32
+  id<MTLBuffer> buf_src = at::native::mps::getMTLBufferStorage(src_t);
+  NSUInteger off_src = (NSUInteger)(src_t.storage_offset() * 4);
+
+  id<MTLBuffer> buf_scores = at::native::mps::getMTLBufferStorage(scores_t);
+  NSUInteger off_scores = (NSUInteger)(scores_t.storage_offset() * 4);
+
+  // ------------------------------------------------------------------
+  // 4. Copy all CPU entering-arc IDs to MPS once.
+  //    RaggedAxis0Splitter (CPU context) gives per-batch arc_begin offsets
+  //    into this flat array, so we index into it by byte offset.
+  // ------------------------------------------------------------------
+  const int32_t *cpu_arc_ids_ptr = entering_arc_batches_cpu.values.Data();
+  torch::Tensor arc_ids_mps =
+      torch::from_blob((void *)cpu_arc_ids_ptr, {(int64_t)total_arc_ids},
+                       torch::TensorOptions()
+                           .dtype(torch::kInt32)
+                           .device(torch::kCPU))
+          .to(at::Device(at::kMPS));
+  id<MTLBuffer> buf_all_arc_ids =
+      at::native::mps::getMTLBufferStorage(arc_ids_mps);
+  NSUInteger base_off_arc_ids =
+      (NSUInteger)(arc_ids_mps.storage_offset() * 4);
+
+  id<MTLComputePipelineState> pipeline =
+      log_semiring ? pipelines.log_fwd : pipelines.tropical_fwd;
+
+  // ------------------------------------------------------------------
+  // 5. Flush any pending PyTorch MPS encoder, then encode all batch
+  //    kernels onto a single command buffer.  Metal executes them in
+  //    submission order, which enforces BFS layer dependencies.
+  // ------------------------------------------------------------------
+  auto *stream = at::mps::getCurrentMPSStream();
+  stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
+
+  id<MTLCommandBuffer> cmd_buf = stream->commandBuffer();
+  K2_CHECK(cmd_buf != nil) << "Failed to get MPS command buffer";
+
+  // entering_arc_batches_cpu has CPU context: RaggedAxis0Splitter works.
+  RaggedAxis0Splitter<int32_t> splitter(entering_arc_batches_cpu);
+
+  for (int32_t i = 0; i < num_batches; ++i) {
+    int32_t arc_begin;
+    Ragged<int32_t> batch = splitter.GetElement(i, &arc_begin);
+    int32_t n_arcs_batch = batch.NumElements();
+    if (n_arcs_batch == 0) continue;
+
+    // Byte offset into arc_ids_mps for this batch.
+    NSUInteger off_arc_ids = base_off_arc_ids + (NSUInteger)(arc_begin * 4);
+
+    EncodeFsaForwardBatch(
+        cmd_buf, pipeline,
+        buf_all_arc_ids, off_arc_ids,
+        buf_arcs, off_arcs,
+        buf_src, off_src,
+        buf_scores, off_scores,
+        n_arcs_batch);
+  }
+
+  // Commit asynchronously; PyTorch will sync when the result is needed.
+  stream->synchronize(at::mps::SyncType::NONE);
+
+  return state_scores;
+}
+
+// ---------------------------------------------------------------------------
+// GetForwardScoresMpsNative — zero-copy variant.
+//
+// Accepts sorted_arc_ids already resident on MPS: avoids the full FSA
+// CPU copy required by GetForwardScoresMps.  The caller provides a CPU-side
+// batch_sizes vector (one int per BFS level) derived cheaply from the
+// arc.dest_state_local column, exploiting the k2 invariant that FSAs are
+// topologically sorted (src_state_local < dest_state_local for every arc).
+// ---------------------------------------------------------------------------
+Array1<float> GetForwardScoresMpsNative(
+    FsaVec &fsas,
+    torch::Tensor sorted_arc_ids,
+    const std::vector<int32_t> &batch_sizes,
+    bool log_semiring) {
+  ContextPtr &c = fsas.Context();
+  K2_CHECK_EQ(c->GetDeviceType(), kMps);
+
+  int32_t num_fsas = fsas.Dim0();
+  int32_t num_states = fsas.TotSize(1);
+  int32_t num_arcs = fsas.TotSize(2);
+
+  const int32_t *fsa_row_splits1 = fsas.RowSplits(1).Data();  // MPS ptr
+  const int32_t *fsas_row_ids2 = fsas.RowIds(2).Data();       // MPS ptr
+  const Arc *arcs_ptr = fsas.values.Data();                   // MPS ptr
+
+  // ------------------------------------------------------------------
+  // 1. Allocate state_scores on MPS and initialise via ATen ops.
+  // ------------------------------------------------------------------
+  Array1<float> state_scores(c, num_states);
+  auto scores_t = AsMpsTensor(state_scores.Data(),
+                              (int64_t)num_states, torch::kFloat);
+  scores_t.fill_(-std::numeric_limits<float>::infinity());
+
+  if (num_fsas > 0) {
+    auto row_splits_t = AsMpsTensor(fsa_row_splits1,
+                                    (int64_t)(num_fsas + 1));
+    auto starts = row_splits_t.slice(0, 0, num_fsas);
+    auto ends = row_splits_t.slice(0, 1, num_fsas + 1);
+    auto nonempty = starts.ne(ends);
+    auto valid = starts.masked_select(nonempty).to(torch::kInt64);
+    scores_t.index_put_({valid}, 0.0f);
+  }
+
+  int32_t total_arc_ids = (int32_t)sorted_arc_ids.numel();
+  if (total_arc_ids == 0 || batch_sizes.empty())
+    return state_scores;
+
+  // ------------------------------------------------------------------
+  // 2. Build (or reuse) Metal pipelines.
+  // ------------------------------------------------------------------
+  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+  K2_CHECK(device != nil) << "No Metal device available";
+  FsaPipelines pipelines = GetOrBuildFsaPipelines(device);
+
+  // ------------------------------------------------------------------
+  // 3. Get persistent MTL buffers for the MPS-resident arrays.
+  // ------------------------------------------------------------------
+  auto arcs_t = AsMpsTensor(arcs_ptr,
+                            (int64_t)num_arcs * (int64_t)sizeof(Arc),
+                            torch::kByte);
+  id<MTLBuffer> buf_arcs = at::native::mps::getMTLBufferStorage(arcs_t);
+  NSUInteger off_arcs = (NSUInteger)(arcs_t.storage_offset());
+
+  auto src_t = AsMpsTensor(fsas_row_ids2, (int64_t)num_arcs);
+  id<MTLBuffer> buf_src = at::native::mps::getMTLBufferStorage(src_t);
+  NSUInteger off_src = (NSUInteger)(src_t.storage_offset() * 4);
+
+  id<MTLBuffer> buf_scores = at::native::mps::getMTLBufferStorage(scores_t);
+  NSUInteger off_scores = (NSUInteger)(scores_t.storage_offset() * 4);
+
+  // ------------------------------------------------------------------
+  // 4. sorted_arc_ids is already on MPS — obtain buffer directly.
+  //    No CPU→MPS copy needed (this is the key gain over GetForwardScoresMps).
+  // ------------------------------------------------------------------
+  id<MTLBuffer> buf_arc_ids =
+      at::native::mps::getMTLBufferStorage(sorted_arc_ids);
+  NSUInteger base_off_arc_ids =
+      (NSUInteger)(sorted_arc_ids.storage_offset() * 4);
+
+  id<MTLComputePipelineState> pipeline =
+      log_semiring ? pipelines.log_fwd : pipelines.tropical_fwd;
+
+  // ------------------------------------------------------------------
+  // 5. Flush any pending PyTorch MPS encoder, then encode all batch
+  //    kernels onto a single command buffer.
+  // ------------------------------------------------------------------
+  auto *stream = at::mps::getCurrentMPSStream();
+  stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
+
+  id<MTLCommandBuffer> cmd_buf = stream->commandBuffer();
+  K2_CHECK(cmd_buf != nil) << "Failed to get MPS command buffer";
+
+  int32_t arc_cursor = 0;
+  for (int32_t n_arcs_batch : batch_sizes) {
+    if (n_arcs_batch > 0) {
+      NSUInteger off_arc_ids =
+          base_off_arc_ids + (NSUInteger)(arc_cursor * 4);
+      EncodeFsaForwardBatch(
+          cmd_buf, pipeline,
+          buf_arc_ids, off_arc_ids,
+          buf_arcs, off_arcs,
+          buf_src, off_src,
+          buf_scores, off_scores,
+          n_arcs_batch);
+    }
+    arc_cursor += n_arcs_batch;
+  }
+
+  stream->synchronize(at::mps::SyncType::NONE);
+  return state_scores;
+}
+
+// ---------------------------------------------------------------------------
+// GetForwardScoresMpsAssocScan — O(log N) associative-scan forward pass.
+//
+// Applies a Hillis-Steele inclusive prefix scan over N per-state transition
+// matrices (one N×N matrix per dest_state_local), reducing the number of
+// Metal compute encoder calls from N to ⌈log₂(N)⌉.
+//
+// Conditions for using this path (falls back to GetForwardScoresMpsNative):
+//   • num_fsas == 1 (single FSA — multi-FSA requires independent scans)
+//   • 4 ≤ num_states ≤ 128 (dense matrix fits GPU cache; below 4 sequential wins)
+//   • !log_semiring (tropical / Viterbi only for now)
+//
+// Memory: two ping-pong float buffers of shape [T_pow2 × N × N] on MPS,
+// where T_pow2 = next power of 2 ≥ N.  Max: 2 × 128 × 128² × 4 = 16 MB.
+// ---------------------------------------------------------------------------
+Array1<float> GetForwardScoresMpsAssocScan(
+    FsaVec &fsas,
+    torch::Tensor sorted_arc_ids,
+    const std::vector<int32_t> &batch_sizes,
+    bool log_semiring)
+{
+  ContextPtr &c = fsas.Context();
+  K2_CHECK_EQ(c->GetDeviceType(), kMps);
+
+  int32_t num_fsas = fsas.Dim0();
+  int32_t num_states = fsas.TotSize(1);
+  int32_t num_arcs = fsas.TotSize(2);
+
+  // --- threshold check — fall back to native sequential path ---
+  if (log_semiring || num_fsas != 1 || num_states < 4 || num_states > 128) {
+    return GetForwardScoresMpsNative(
+        fsas, sorted_arc_ids, batch_sizes, log_semiring);
+  }
+
+  const int32_t N = num_states;
+
+  // T_pow2: next power of 2 >= N (Hillis-Steele requires power-of-2 array).
+  int32_t T_pow2 = 1;
+  while (T_pow2 < N) T_pow2 <<= 1;
+
+  const int32_t *fsa_row_splits1 = fsas.RowSplits(1).Data();  // MPS ptr
+  const Arc *arcs_ptr = fsas.values.Data();                   // MPS ptr
+
+  // ------------------------------------------------------------------
+  // 1. Allocate two ping-pong buffers on MPS (T_pow2 × N × N each).
+  // ------------------------------------------------------------------
+  int64_t mat_elems = (int64_t)T_pow2 * N * N;
+  torch::Tensor buf_a = torch::empty(
+      {mat_elems}, torch::TensorOptions().dtype(torch::kFloat).device(at::kMPS));
+  torch::Tensor buf_b = torch::empty_like(buf_a);
+
+  auto arcs_t = AsMpsTensor(arcs_ptr,
+                            (int64_t)num_arcs * (int64_t)sizeof(Arc),
+                            torch::kByte);
+
+  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+  K2_CHECK(device != nil) << "No Metal device available";
+  FsaPipelines pipes = GetOrBuildFsaPipelines(device);
+
+  // ------------------------------------------------------------------
+  // 2. Flush any pending ATen encoder, then begin encoding.
+  // ------------------------------------------------------------------
+  auto *stream = at::mps::getCurrentMPSStream();
+  stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE);
+  id<MTLCommandBuffer> cmd = stream->commandBuffer();
+  K2_CHECK(cmd != nil) << "Failed to get MPS command buffer";
+
+  id<MTLBuffer> mtl_a = at::native::mps::getMTLBufferStorage(buf_a);
+  id<MTLBuffer> mtl_b = at::native::mps::getMTLBufferStorage(buf_b);
+  id<MTLBuffer> buf_arcs = at::native::mps::getMTLBufferStorage(arcs_t);
+  NSUInteger off_arcs = (NSUInteger)arcs_t.storage_offset();  // byte offset
+
+  id<MTLBuffer> buf_arc_ids =
+      at::native::mps::getMTLBufferStorage(sorted_arc_ids);
+  NSUInteger base_off_arc_ids =
+      (NSUInteger)(sorted_arc_ids.storage_offset() * 4);
+
+  // Helper: dispatch a 3-D grid (X × Y × Z) threadgroups of 1 thread each.
+  // For small N this is fine; threadgroup size = 1 avoids occupancy issues.
+  auto dispatch3 = [&](id<MTLComputePipelineState> pso,
+                       NSUInteger X, NSUInteger Y, NSUInteger Z,
+                       /* buffer bindings set by caller */ int) {
+    // caller sets buffers before calling this lambda — not feasible as lambda.
+    // Instead, inline the dispatch pattern at each call site below.
+    (void)pso; (void)X; (void)Y; (void)Z;
+  };
+  (void)dispatch3;  // suppress unused warning; we inline below.
+
+  // ------------------------------------------------------------------
+  // 3. assoc_scan_init: fill buf_a (both buffers, but only buf_a matters
+  //    since buf_b gets fully overwritten by the first prefix step).
+  // ------------------------------------------------------------------
+  {
+    id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
+    [enc setComputePipelineState:pipes.assoc_init];
+    [enc setBuffer:mtl_a offset:(NSUInteger)(buf_a.storage_offset() * 4) atIndex:0];
+    int32_t n_val = N, tp_val = T_pow2, ta_val = N;  // t_actual = N states
+    [enc setBytes:&n_val length:4 atIndex:1];
+    [enc setBytes:&tp_val length:4 atIndex:2];
+    [enc setBytes:&ta_val length:4 atIndex:3];
+    // Grid: (N, N, T_pow2) with threadgroup size (1,1,1).
+    [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)N, (NSUInteger)N, (NSUInteger)T_pow2)
+        threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+    [enc endEncoding];
+  }
+
+  // ------------------------------------------------------------------
+  // 4. assoc_scan_build_level: fill arc weights into each matrix M[d].
+  //    sorted_arc_ids is already on MPS (zero-copy from Priority 4).
+  // ------------------------------------------------------------------
+  int32_t arc_cursor = 0;
+  for (int32_t d = 0; d < N; ++d) {
+    int32_t n_arcs_d = (d < (int32_t)batch_sizes.size()) ? batch_sizes[d] : 0;
+    if (n_arcs_d > 0) {
+      NSUInteger off_arc_ids_d = base_off_arc_ids + (NSUInteger)(arc_cursor * 4);
+      // M[d] starts at buf_a offset: d * N * N floats
+      NSUInteger off_m_d = (NSUInteger)(buf_a.storage_offset() * 4) +
+                           (NSUInteger)(d * N * N * 4);
+
+      id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
+      [enc setComputePipelineState:pipes.assoc_build_level];
+      [enc setBuffer:buf_arc_ids offset:off_arc_ids_d atIndex:0];
+      [enc setBuffer:buf_arcs offset:off_arcs atIndex:1];
+      [enc setBuffer:mtl_a offset:off_m_d atIndex:2];
+      int32_t n_val = N;
+      [enc setBytes:&n_arcs_d length:4 atIndex:3];
+      [enc setBytes:&n_val length:4 atIndex:4];
+      NSUInteger tg = (NSUInteger)n_arcs_d;
+      [enc dispatchThreadgroups:MTLSizeMake(tg, 1, 1)
+          threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+      [enc endEncoding];
+    }
+    arc_cursor += n_arcs_d;
+  }
+
+  // ------------------------------------------------------------------
+  // 5. Hillis-Steele prefix scan: log2(T_pow2) steps.
+  //    Ping-pong between buf_a (read) and buf_b (write), then swap.
+  // ------------------------------------------------------------------
+  id<MTLBuffer> cur_in = mtl_a;
+  id<MTLBuffer> cur_out = mtl_b;
+  NSUInteger in_base = (NSUInteger)(buf_a.storage_offset() * 4);
+  NSUInteger out_base = (NSUInteger)(buf_b.storage_offset() * 4);
+
+  for (int32_t step_d = 1; step_d < T_pow2; step_d <<= 1) {
+    id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
+    [enc setComputePipelineState:pipes.assoc_prefix_step];
+    [enc setBuffer:cur_in offset:in_base atIndex:0];
+    [enc setBuffer:cur_out offset:out_base atIndex:1];
+    int32_t n_val = N, tp_val = T_pow2, sd_val = step_d;
+    [enc setBytes:&n_val length:4 atIndex:2];
+    [enc setBytes:&tp_val length:4 atIndex:3];
+    [enc setBytes:&sd_val length:4 atIndex:4];
+    // Grid: (N, N, T_pow2) — each thread handles one cell of one matrix.
+    [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)N,
+                                          (NSUInteger)N,
+                                          (NSUInteger)T_pow2)
+        threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+    [enc endEncoding];
+    // Swap ping-pong.
+    std::swap(cur_in, cur_out);
+    std::swap(in_base, out_base);
+  }
+  // After the loop, cur_in holds the final prefix products.
+
+  // ------------------------------------------------------------------
+  // 6. Extract state scores: alpha[s] = prefix[s][s][0].
+  // ------------------------------------------------------------------
+  Array1<float> state_scores(c, num_states);
+  auto scores_t = AsMpsTensor(state_scores.Data(),
+                              (int64_t)num_states, torch::kFloat);
+  id<MTLBuffer> buf_scores = at::native::mps::getMTLBufferStorage(scores_t);
+  NSUInteger off_scores = (NSUInteger)(scores_t.storage_offset() * 4);
+
+  {
+    id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
+    [enc setComputePipelineState:pipes.assoc_extract];
+    [enc setBuffer:cur_in offset:in_base atIndex:0];
+    [enc setBuffer:buf_scores offset:off_scores atIndex:1];
+    int32_t n_val = N;
+    [enc setBytes:&n_val length:4 atIndex:2];
+    [enc dispatchThreadgroups:MTLSizeMake((NSUInteger)N, 1, 1)
+        threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+    [enc endEncoding];
+  }
+
+  stream->synchronize(at::mps::SyncType::NONE);
+  return state_scores;
+}
+
+}  // namespace mps_ops
+}  // namespace k2
diff --git a/k2/python/csrc/torch/mutual_information.cu b/k2/python/csrc/torch/mutual_information.cu
index 23e290610..3da5697a9 100644
--- a/k2/python/csrc/torch/mutual_information.cu
+++ b/k2/python/csrc/torch/mutual_information.cu
@@ -29,8 +29,24 @@ void PybindMutualInformation(py::module &m) {
            torch::optional<torch::Tensor> boundary,
            torch::Tensor p) -> torch::Tensor {
           k2::DeviceGuard guard(k2::GetContext(px));
+          auto orig_device = px.device();
           if (px.device().is_cpu()) {
             return k2::MutualInformationCpu(px, py, boundary, p);
+          } else if (px.device().type() == torch::kMPS) {
+#ifdef K2_WITH_MPS
+            // Only float32 is supported natively; fall back to CPU for double.
+            if (px.scalar_type() == torch::kFloat) {
+              return k2::MutualInformationMps(px, py, boundary, p);
+            }
+#endif
+            // CPU fallback for MPS double (or no-MPS build)
+            auto px_cpu = px.cpu(), py_cpu = py.cpu(), p_cpu = p.cpu();
+            torch::optional<torch::Tensor> boundary_cpu;
+            if (boundary.has_value()) boundary_cpu = boundary.value().cpu();
+            auto result = k2::MutualInformationCpu(px_cpu, py_cpu, boundary_cpu,
+                                                   p_cpu);
+            p.copy_(p_cpu.to(orig_device));
+            return result.to(orig_device);
           } else {
 #ifdef K2_WITH_CUDA
             return k2::MutualInformationCuda(px, py, boundary, p);
@@ -49,9 +65,27 @@ void PybindMutualInformation(py::module &m) {
            torch::optional<torch::Tensor> boundary,
            torch::Tensor p,
            torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
           k2::DeviceGuard guard(k2::GetContext(px));
+          auto orig_device = px.device();
           if (px.device().is_cpu()) {
             return k2::MutualInformationBackwardCpu(px, py, boundary, p,
-                                                    ans_grad);
+                                                    ans_grad);
+          } else if (px.device().type() == torch::kMPS) {
+#ifdef K2_WITH_MPS
+            if (px.scalar_type() == torch::kFloat) {
+              return k2::MutualInformationBackwardMps(px, py, boundary, p,
+                                                      ans_grad, false);
+            }
+#endif
+            // CPU fallback for MPS double (or no-MPS build)
+            auto px_cpu = px.cpu(), py_cpu = py.cpu(), p_cpu = p.cpu();
+            auto ans_grad_cpu = ans_grad.cpu();
+            torch::optional<torch::Tensor> boundary_cpu;
+            if (boundary.has_value()) boundary_cpu = boundary.value().cpu();
+            auto grads = k2::MutualInformationBackwardCpu(px_cpu, py_cpu,
+                                                          boundary_cpu, p_cpu,
+                                                          ans_grad_cpu);
+            for (auto &g : grads) g = g.to(orig_device);
+            return grads;
           } else {
 #ifdef K2_WITH_CUDA
             return k2::MutualInformationBackwardCuda(px, py, boundary, p,
diff --git a/k2/python/csrc/torch/mutual_information.h b/k2/python/csrc/torch/mutual_information.h
index b01a9d4f5..ef11de871 100644
--- a/k2/python/csrc/torch/mutual_information.h
+++ b/k2/python/csrc/torch/mutual_information.h
@@ -105,6 +105,18 @@ std::vector<torch::Tensor> MutualInformationBackwardCuda(
     torch::Tensor px, torch::Tensor py,
     torch::optional<torch::Tensor> boundary, torch::Tensor p,
     torch::Tensor ans_grad, bool overwrite_ans_grad);
 
+#ifdef K2_WITH_MPS
+torch::Tensor MutualInformationMps(
+    torch::Tensor px, torch::Tensor py,
+    torch::optional<torch::Tensor> boundary,
+    torch::Tensor p);
+
+std::vector<torch::Tensor> MutualInformationBackwardMps(
+    torch::Tensor px, torch::Tensor py,
+    torch::optional<torch::Tensor> boundary,
+    torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);
+#endif  // K2_WITH_MPS
+
 }  // namespace k2
 
 void PybindMutualInformation(py::module &m);
diff --git a/k2/python/csrc/torch/mutual_information_mps.mm b/k2/python/csrc/torch/mutual_information_mps.mm
new file mode 100644
index 000000000..2a43e33f1
--- /dev/null
+++ b/k2/python/csrc/torch/mutual_information_mps.mm
@@ -0,0 +1,710 @@
+/**
+ * Copyright      2026  k2-fsa Authors
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Native Metal implementation of mutual_information forward and backward.
+// Mirrors the blocked antidiagonal wavefront pattern of mutual_information_cuda.cu.
+// Only compiled when K2_WITH_MPS is defined (Apple + MPS build).
+
+#ifdef K2_WITH_MPS
+
+// mutual_information.h pulls in torch/extension.h which includes
+// aten_interned_strings.h.  OperationUtils.h defines
+// TORCH_ASSERT_ONLY_METHOD_OPERATORS which would trigger a #error in
+// aten_interned_strings.h if included first.  So mutual_information.h
+// must come before the MPS headers.
+#include "k2/python/csrc/torch/mutual_information.h"
+
+#import <Metal/Metal.h>
+#import <Foundation/Foundation.h>
+
+// NOTE(review): the original include targets were lost in extraction;
+// <cmath>/<vector> match the visible code's needs — confirm against upstream.
+#include <cmath>
+#include <vector>
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Metal Shading Language kernel source (embedded, runtime-compiled once)
+// ─────────────────────────────────────────────────────────────────────────────
+//
+// Design mirrors mutual_information_cuda.cu:
+//   • BLOCK_SIZE 32 tiles on the (s,t) grid
+//   • Outer diagonal loop driven from C++ (one Metal dispatch per `iter`)
+//   • 128 threads per threadgroup; cooperative load / scatter
+//   • Inner antidiagonal computed by the first SIMD group (threads 0–31)
+//   • threadgroup_barrier / simdgroup_barrier replace __syncthreads / __syncwarp
+// ─────────────────────────────────────────────────────────────────────────────
+static const char kMIKernelSrc[] = R"MSL(
+#include <metal_stdlib>
+using namespace metal;
+
+// ── helpers ──────────────────────────────────────────────────────────────────
+inline float log_add(float a, float b) {
+  if (a == -INFINITY) return b;
+  if (b == -INFINITY) return a;
+  float hi = max(a, b);
+  // exp(lo - hi) is in (0, 1] so log(1 + exp(...)) is numerically stable
+  return hi + log(1.0f + exp(min(a, b) - hi));
+}
+
+inline float safe_exp_mps(float x) {
+  if (isnan(x) || isinf(x)) return 0.0f;
+  float r = exp(x);
+  return (isnan(r) || isinf(r)) ? 0.0f : r;
+}
+
+// ── BLOCK_SIZE must match the C++ constant (32) ───────────────────────────────
+constant int BSIZE = 32;
+
+// ── Parameter structs (must exactly match the C++ side) ──────────────────────
+struct MIFwdParams {
+  int B, S, T;
+  int px_stride_b, px_stride_s;  // px strides; t stride == 1
+  int py_stride_b, py_stride_s;  // py strides; t stride == 1
+  int p_stride_b, p_stride_s;    // p strides; t stride == 1
+  int iter;
+  int num_blocks_this_iter;
+  int t_offset;                  // 0 if !modified, -1 if modified
+};
+
+struct MIBwdParams {
+  int B, S, T;
+  int px_stride_b, px_stride_s;
+  int py_stride_b, py_stride_s;
+  int p_stride_b, p_stride_s;
+  int iter;
+  int num_blocks_this_iter;
+  int neg_t_offset;              // 0 if !modified, 1 if modified
+  int has_boundary;
+};
+
+// ── Forward kernel ────────────────────────────────────────────────────────────
+// Each threadgroup handles multiple (batch, block) pairs via the outer loop.
+// Thread layout: 128 threads / threadgroup.
+// Inner antidiagonal uses only threads 0..BSIZE-1 (first SIMD group, size 32).
+kernel void mi_forward(
+    device const float* px       [[buffer(0)]],  // [B][S][T+1 or T]
+    device const float* py       [[buffer(1)]],  // [B][S+1][T]
+    device float*       p        [[buffer(2)]],  // [B][S+1][T+1] (in/out)
+    device const int*   boundary [[buffer(3)]],  // [B][4] int32
+    device float*       ans      [[buffer(4)]],  // [B]
+    constant MIFwdParams& params [[buffer(5)]],
+    uint tg_id    [[threadgroup_position_in_grid]],
+    uint tg_count [[threadgroups_per_grid]],
+    uint tid      [[thread_index_in_threadgroup]])
+{
+  // Threadgroup-resident buffers
+  threadgroup float px_buf[BSIZE][BSIZE];         // 4 KB
+  threadgroup float py_buf[BSIZE][BSIZE];         // 4 KB
+  threadgroup float p_buf[BSIZE + 1][BSIZE + 1];  // ~4.4 KB
+  threadgroup int bnd[4];
+
+  const int B = params.B, S = params.S, T = params.T;
+  const int t_o = params.t_offset;
+  const int iter = params.iter;
+  const int nblk = params.num_blocks_this_iter;
+
+  // Grid-stride over all (block, batch) pairs of this outer diagonal.
+  for (int bbi = (int)tg_id; bbi < B * nblk; bbi += (int)tg_count) {
+    int blk = bbi / B, b = bbi % B;
+    int s_bb = blk * BSIZE;           // s_block_begin (before adding s_begin)
+    int t_bb = (iter - blk) * BSIZE;  // t_block_begin (before adding t_begin)
+
+    // ── Load boundary for batch element b ─────────────────────────────
+    // NOTE(review): the defaults written by tid 0 are unconditionally
+    // overwritten by the boundary load below (there is no has_boundary
+    // flag in MIFwdParams) — presumably the host always binds a filled
+    // [B][4] boundary buffer on the forward path; confirm against caller.
+    if (tid == 0) { bnd[0] = 0; bnd[1] = 0; bnd[2] = S; bnd[3] = T; }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (tid < 4) bnd[tid] = boundary[b * 4 + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    int s_begin = bnd[0], t_begin = bnd[1];
+    int s_end = bnd[2], t_end = bnd[3];
+    s_bb += s_begin;
+    t_bb += t_begin;
+
+    int block_S = min(BSIZE, s_end + 1 - s_bb);
+    int block_T = min(BSIZE, t_end + 1 - t_bb);
+    if (block_S <= 0 || block_T <= 0) {
+      // Keep all threads converged on the same barrier count before
+      // moving to the next (block, batch) pair.
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      continue;
+    }
+    bool is_origin = (s_bb == s_begin && t_bb == t_begin);
+
+    // ── Cooperative load of px_buf and py_buf ─────────────────────────
+    // 128 threads cover 1024 = BSIZE*BSIZE elements in 8 passes
+    for (int i = (int)tid; i < BSIZE * BSIZE; i += 128) {
+      int si = i / BSIZE, ti = i % BSIZE;
+      int sg = si + s_bb, tg_g = ti + t_bb;
+      int t_off = tg_g + t_o;
+
+      float pxv = -INFINITY;
+      if (sg > s_begin && sg <= s_end && t_off >= t_begin && tg_g <= t_end)
+        pxv = px[b * params.px_stride_b + (sg - 1) * params.px_stride_s + t_off];
+      px_buf[si][ti] = pxv;
+
+      float pyv = -INFINITY;
+      if (tg_g > t_begin && tg_g <= t_end && sg <= s_end)
+        pyv = py[b * params.py_stride_b + sg * params.py_stride_s + (tg_g - 1)];
+      py_buf[si][ti] = pyv;
+    }
+
+    // ── Load border of p_buf (first column + first row) ───────────────
+    // First column (s_in_p = 0..BSIZE, t_in_p = 0): threads 0..BSIZE
+    if (tid <= (uint)BSIZE) {
+      int si = (int)tid;
+      int s = si + s_bb - 1, t = t_bb - 1;
+      float pv = -INFINITY;
+      if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
+        pv = p[b * params.p_stride_b + s * params.p_stride_s + t];
+      p_buf[si][0] = pv;
+    }
+    // First row (s_in_p = 0, t_in_p = 0..BSIZE): threads 64..64+BSIZE
+    // Unsigned cast trick mirrors CUDA: tests both >= 0 and <= BSIZE
+    {
+      uint u = tid - 64u;  // wraps to large value for tid < 64
+      if (u <= (uint)BSIZE) {
+        int ti = (int)u;
+        int s = s_bb - 1, t = ti + t_bb - 1;
+        float pv = -INFINITY;
+        if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
+          pv = p[b * params.p_stride_b + s * params.p_stride_s + t];
+        p_buf[0][ti] = pv;
+      }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // ── Thread 0: initialize p_buf[1][1] ─────────────────────────────
+    if (tid == 0) {
+      p_buf[1][1] = is_origin ? 0.0f :
+          log_add(p_buf[0][1 + t_o] + px_buf[0][0],
+                  p_buf[1][0] + py_buf[0][0]);
+    }
+
+    // ── Inner antidiagonal sweep (threads 0..BSIZE-1 only) ────────────
+    // simdgroup_barrier syncs the first SIMD group (threads 0..31)
+    // which is the only group active in this loop.
+    int s = (int)tid;
+    for (int i = 1; i < block_S + block_T - 1; ++i) {
+      simdgroup_barrier(mem_flags::mem_threadgroup);
+      int t = i - s;
+      if (s < block_S && t >= 0 && t < block_T) {
+        p_buf[s + 1][t + 1] = log_add(
+            p_buf[s][t + 1 + t_o] + px_buf[s][t],
+            p_buf[s + 1][t] + py_buf[s][t]);
+      }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // ── Write p_buf results back to global p ──────────────────────────
+    for (int i = (int)tid; i < BSIZE * BSIZE; i += 128) {
+      int si = i / BSIZE, ti = i % BSIZE;
+      if (si < block_S && ti < block_T) {
+        int sg = si + s_bb, tg_g = ti + t_bb;
+        p[b * params.p_stride_b + sg * params.p_stride_s + tg_g] =
+            p_buf[si + 1][ti + 1];
+      }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // ── Thread 0: write ans if this is the final (top-right) block ────
+    if (tid == 0 &&
+        s_bb + block_S - 1 == s_end &&
+        t_bb + block_T - 1 == t_end) {
+      ans[b] = p_buf[block_S][block_T];
+    }
+  }
+}
+
+// ── Backward kernel ───────────────────────────────────────────────────────────
+// Mirrors mutual_information_backward_kernel in mutual_information_cuda.cu.
+// Inputs: px, py, p (forward output), ans_grad +// Outputs: px_grad, py_grad (p_grad computed internally and accumulated) +kernel void mi_backward( + device const float* px [[buffer(0)]], // [B][S][T+1 or T] + device const float* py [[buffer(1)]], // [B][S+1][T] + device const float* p [[buffer(2)]], // [B][S+1][T+1] + device float* ans_grad [[buffer(3)]], // [B] + device float* p_grad [[buffer(4)]], // [B][S+1][T+1] + device float* px_grad [[buffer(5)]], // [B][S][T+1 or T] + device float* py_grad [[buffer(6)]], // [B][S+1][T] + device const int* boundary [[buffer(7)]], // [B][4] int32 + constant MIBwdParams& params [[buffer(8)]], + uint tg_id [[threadgroup_position_in_grid]], + uint tg_count [[threadgroups_per_grid]], + uint tid [[thread_index_in_threadgroup]]) +{ + // px_buf / py_buf: initially px/py values, then overwritten with term1/term2 + threadgroup float px_buf[BSIZE][BSIZE]; + threadgroup float py_buf[BSIZE][BSIZE]; + // p_buf: (BSIZE+1)×(BSIZE+1), first used for p values, then repurposed for p_grad. + // Unlike forward, indexing is NOT offset by 1; context is on TOP and RIGHT. 
+ threadgroup float p_buf[BSIZE + 1][BSIZE + 1]; + threadgroup int bnd[4]; + + const int B = params.B, S = params.S, T = params.T; + const int neg_t_o = params.neg_t_offset; // 0 if !modified, 1 if modified + const int iter = params.iter; + const int nblk = params.num_blocks_this_iter; + + for (int bbi = (int)tg_id; bbi < B * nblk; bbi += (int)tg_count) { + int blk = bbi / B, b = bbi % B; + int s_bb = blk * BSIZE; + int t_bb = (iter - blk) * BSIZE; + + // ── Boundary ────────────────────────────────────────────────────── + if (tid == 0) { bnd[0] = 0; bnd[1] = 0; bnd[2] = S; bnd[3] = T; } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (tid < 4) bnd[tid] = boundary[b * 4 + tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + int s_begin = bnd[0], t_begin = bnd[1]; + int s_end = bnd[2], t_end = bnd[3]; + s_bb += s_begin; + t_bb += t_begin; + + int block_S = min(BSIZE, s_end + 1 - s_bb); + int block_T = min(BSIZE, t_end + 1 - t_bb); + if (block_S <= 0 || block_T <= 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + continue; + } + + // ── Load px_buf and py_buf ──────────────────────────────────────── + for (int i = (int)tid; i < BSIZE * BSIZE; i += 128) { + int si = i / BSIZE, ti = i % BSIZE; + int sg = si + s_bb, tg_g = ti + t_bb; + + float pxv = -INFINITY; + if (sg < s_end && tg_g <= t_end) + pxv = px[b * params.px_stride_b + sg * params.px_stride_s + tg_g]; + px_buf[si][ti] = pxv; + + float pyv = -INFINITY; + if (sg <= s_end && tg_g < t_end) + pyv = py[b * params.py_stride_b + sg * params.py_stride_s + tg_g]; + py_buf[si][ti] = pyv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ── Load p_buf from global p (size (BSIZE+1)×(BSIZE+1)) ───────── + for (int i = (int)tid; i < (BSIZE + 1) * (BSIZE + 1); i += 128) { + int si = i / (BSIZE + 1), ti = i % (BSIZE + 1); + int sg = si + s_bb, tg_g = ti + t_bb; + float pv = 0.0f; + if (sg <= s_end && tg_g <= t_end) { + pv = p[b * params.p_stride_b + sg * params.p_stride_s + tg_g]; + if 
(pv < -1.0e+30f) pv = -1.0e+30f; + } + p_buf[si][ti] = pv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ── Compute term1 (xderiv) and term2 (yderiv) in-place ─────────── + // term1[s][t] = safe_exp(p[s][t] + px[s][t] - p[s+1][t-t_offset]) + // = safe_exp(p_buf[s][t] + px_buf[s][t] - p_buf[s+1][t+neg_t_o]) + // term2[s][t] = safe_exp(p[s][t] + py[s][t] - p[s][t+1]) + for (int i = (int)tid; i < BSIZE * BSIZE; i += 128) { + int si = i / BSIZE, ti = i % BSIZE; + float xd = safe_exp_mps(p_buf[si][ti] + px_buf[si][ti] + - p_buf[si + 1][ti + neg_t_o]); + float yd = safe_exp_mps(p_buf[si][ti] + py_buf[si][ti] + - p_buf[si][ti + 1]); + px_buf[si][ti] = xd; + py_buf[si][ti] = yd; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ── Load p_grad for top+right border of this block ──────────────── + // p_buf[s][block_T] for s in [0..block_S]: threads 0..block_S + if (tid <= (uint)block_S) { + int si = (int)tid, sg = si + s_bb, tg_g = block_T + t_bb; + p_buf[si][block_T] = (sg <= s_end && tg_g <= t_end) + ? p_grad[b * params.p_stride_b + sg * params.p_stride_s + tg_g] + : 0.0f; + } + // p_buf[block_S][t] for t in [0..block_T-1]: use unsigned trick for threads 64..64+block_T-1 + { + uint u = tid - 64u; + if (u < (uint)block_T) { + int ti = (int)u, sg = block_S + s_bb, tg_g = ti + t_bb; + p_buf[block_S][ti] = (sg <= s_end && tg_g <= t_end) + ? 
p_grad[b * params.p_stride_b + sg * params.p_stride_s + tg_g] + : 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ── Handle final block: seed p_grad[s_end][t_end] with ans_grad ── + bool is_final = (s_bb + block_S == s_end + 1 && t_bb + block_T == t_end + 1); + int first_iter = block_S + block_T - 2; + if (is_final) { + if (tid == 0) p_buf[block_S - 1][block_T - 1] = ans_grad[b]; + --first_iter; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // ── Inner reverse antidiagonal sweep ───────────────────────────── + int s = (int)tid; + for (int i = first_iter; i >= 0; --i) { + simdgroup_barrier(mem_flags::mem_threadgroup); + int t = i - s; + if (s < block_S && t >= 0 && t < block_T) { + // p_grad[s,t] = p_grad[s+1,t-t_offset]*term1[s,t] + p_grad[s,t+1]*term2[s,t] + // = p_buf[s+1][t+neg_t_o] * px_buf[s][t] + p_buf[s][t+1] * py_buf[s][t] + p_buf[s][t] = p_buf[s + 1][t + neg_t_o] * px_buf[s][t] + + p_buf[s][t + 1] * py_buf[s][t]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ── Write p_grad, px_grad, py_grad ─────────────────────────────── + for (int i = (int)tid; i < BSIZE * BSIZE; i += 128) { + int si = i / BSIZE, ti = i % BSIZE; + int sg = si + s_bb, tg_g = ti + t_bb; + if (tg_g <= t_end && sg <= s_end) { + p_grad[b * params.p_stride_b + sg * params.p_stride_s + tg_g] = + p_buf[si][ti]; + + // px_grad: shape [B][S][T+1] if !modified, [B][S][T] if modified + // condition: sg < s_end && tg_g <= t_end - neg_t_o + if (sg < s_end && tg_g <= t_end - neg_t_o) { + // px_grad[b][sg][tg_g] = p_grad[sg+1][tg_g+neg_t_o] * term1[sg][tg_g] + px_grad[b * params.px_stride_b + sg * params.px_stride_s + tg_g] = + p_buf[si + 1][ti + neg_t_o] * px_buf[si][ti]; + } + + // py_grad: shape [B][S+1][T] + if (tg_g < t_end) { + py_grad[b * params.py_stride_b + sg * params.py_stride_s + tg_g] = + p_buf[si][ti + 1] * py_buf[si][ti]; + } + } + } + + // Thread 0: optionally overwrite ans_grad[b] with recomputed value + // (origin block: 
p_buf[0][0] == p_grad[s_begin][t_begin]) + if (tid == 0 && s_bb == s_begin && t_bb == t_begin) + ans_grad[b] = p_buf[0][0]; + } +} +)MSL"; + +// ───────────────────────────────────────────────────────────────────────────── +// C++ wrapper: compile kernel once, cache pipeline states +// ───────────────────────────────────────────────────────────────────────────── +namespace { + +struct MIPipelineCache { + // __unsafe_unretained: skip ARC release on exit so we don't send messages + // to Metal objects after PyTorch has torn down the MPS device. + __unsafe_unretained id fwd = nil; + __unsafe_unretained id bwd = nil; +}; + +// Struct layouts must exactly match MSL structs above. +struct MIFwdParams { + int B, S, T; + int px_stride_b, px_stride_s; + int py_stride_b, py_stride_s; + int p_stride_b, p_stride_s; + int iter; + int num_blocks_this_iter; + int t_offset; +}; + +struct MIBwdParams { + int B, S, T; + int px_stride_b, px_stride_s; + int py_stride_b, py_stride_s; + int p_stride_b, p_stride_s; + int iter; + int num_blocks_this_iter; + int neg_t_offset; + int has_boundary; +}; + +MIPipelineCache* GetOrBuildPipelines() { + static dispatch_once_t token; + static MIPipelineCache* cache = nullptr; + dispatch_once(&token, ^{ + cache = new MIPipelineCache(); + id device = at::mps::MPSDevice::getInstance()->device(); + + NSError* err = nil; + MTLCompileOptions* opts = [[MTLCompileOptions alloc] init]; + opts.languageVersion = MTLLanguageVersion2_4; + NSString* src = [NSString stringWithUTF8String:kMIKernelSrc]; + id lib = [device newLibraryWithSource:src options:opts error:&err]; + if (!lib) { + NSLog(@"k2 MPS: Metal compile error: %@", err.localizedDescription); + return; + } + + id fwd_fn = [lib newFunctionWithName:@"mi_forward"]; + if (!fwd_fn) { + NSLog(@"k2 MPS: Metal function 'mi_forward' not found in compiled library"); + return; + } + id bwd_fn = [lib newFunctionWithName:@"mi_backward"]; + if (!bwd_fn) { + NSLog(@"k2 MPS: Metal function 'mi_backward' not found in 
compiled library"); + return; + } + + cache->fwd = [device newComputePipelineStateWithFunction:fwd_fn error:&err]; + if (!cache->fwd) + NSLog(@"k2 MPS: pipeline (fwd) error: %@", err.localizedDescription); + + cache->bwd = [device newComputePipelineStateWithFunction:bwd_fn error:&err]; + if (!cache->bwd) + NSLog(@"k2 MPS: pipeline (bwd) error: %@", err.localizedDescription); + }); + return cache; +} + +// Encode one kernel dispatch onto PyTorch's MPS command buffer. +void EncodeForward(id pso, + id cmdbuf, + const torch::Tensor& px, + const torch::Tensor& py, + torch::Tensor& p, + const torch::Tensor& boundary, // [B][4] int32, contiguous + torch::Tensor& ans, + const MIFwdParams& params, + int num_threadgroups) { + id enc = [cmdbuf computeCommandEncoder]; + [enc setComputePipelineState:pso]; + + auto bind = [&](int idx, const torch::Tensor& t) { + [enc setBuffer:at::native::mps::getMTLBufferStorage(t) + offset:t.storage_offset() * t.element_size() + atIndex:idx]; + }; + bind(0, px); + bind(1, py); + bind(2, p); + bind(3, boundary); + bind(4, ans); + [enc setBytes:¶ms length:sizeof(params) atIndex:5]; + + MTLSize tg_size = {128, 1, 1}; + MTLSize grid_sz = {(NSUInteger)num_threadgroups, 1, 1}; + [enc dispatchThreadgroups:grid_sz threadsPerThreadgroup:tg_size]; + [enc endEncoding]; +} + +void EncodeBackward(id pso, + id cmdbuf, + const torch::Tensor& px, + const torch::Tensor& py, + const torch::Tensor& p, + torch::Tensor& ans_grad, + torch::Tensor& p_grad, + torch::Tensor& px_grad, + torch::Tensor& py_grad, + const torch::Tensor& boundary, // [B][4] int32 + const MIBwdParams& params, + int num_threadgroups) { + id enc = [cmdbuf computeCommandEncoder]; + [enc setComputePipelineState:pso]; + + auto bind = [&](int idx, const torch::Tensor& t) { + [enc setBuffer:at::native::mps::getMTLBufferStorage(t) + offset:t.storage_offset() * t.element_size() + atIndex:idx]; + }; + bind(0, px); bind(1, py); bind(2, p); + bind(3, ans_grad); + bind(4, p_grad); bind(5, px_grad); 
bind(6, py_grad); + bind(7, boundary); + [enc setBytes:¶ms length:sizeof(params) atIndex:8]; + + MTLSize tg_size = {128, 1, 1}; + MTLSize grid_sz = {(NSUInteger)num_threadgroups, 1, 1}; + [enc dispatchThreadgroups:grid_sz threadsPerThreadgroup:tg_size]; + [enc endEncoding]; +} + +} // namespace + +namespace k2 { + +// ───────────────────────────────────────────────────────────────────────────── +// MutualInformationMps +// ───────────────────────────────────────────────────────────────────────────── +torch::Tensor MutualInformationMps(torch::Tensor px, torch::Tensor py, + torch::optional opt_boundary, + torch::Tensor p) { + TORCH_CHECK(px.dim() == 3 && py.dim() == 3 && p.dim() == 3); + TORCH_CHECK(px.scalar_type() == torch::kFloat, + "MutualInformationMps only supports float32; got ", + px.scalar_type()); + + // Ensure contiguous layout (required for stride-1 t-dimension assumption) + auto px_c = px.contiguous(); + auto py_c = py.contiguous(); + auto p_c = p.contiguous(); + + const int B = px_c.size(0), S = px_c.size(1), T = py_c.size(2); + const bool modified = (px_c.size(2) == (int64_t)T); + + const int BLOCK_SIZE = 32; + const int num_s_blocks = S / BLOCK_SIZE + 1; + const int num_t_blocks = T / BLOCK_SIZE + 1; + const int num_iters = num_s_blocks + num_t_blocks - 1; + + // Expand or create boundary tensor as int32 on MPS + torch::Tensor boundary_i32; + if (opt_boundary.has_value()) { + boundary_i32 = opt_boundary.value().to(torch::kInt32).contiguous(); + } else { + boundary_i32 = torch::tensor({0, 0, S, T}, + torch::TensorOptions().dtype(torch::kInt32).device(px.device())) + .reshape({1, 4}).expand({B, 4}).contiguous(); + } + + auto ans = torch::empty({B}, + torch::TensorOptions().dtype(px.scalar_type()).device(px.device())); + + MIPipelineCache* cache = GetOrBuildPipelines(); + TORCH_CHECK(cache && cache->fwd, "k2 MPS: failed to build MI forward pipeline"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + // Flush any open PyTorch encoder 
before creating our own. + // COMMIT_AND_CONTINUE ends the current encoder, commits the command buffer, + // and starts a fresh one — so our [cmdbuf computeCommandEncoder] calls below + // won't collide with PyTorch's internal encoder. + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); + id cmdbuf = stream->commandBuffer(); + + const int num_threadgroups = 256; // tunable; matches CUDA num_blocks + + for (int iter = 0; iter < num_iters; ++iter) { + int num_blocks_this_iter = std::min(iter + 1, num_s_blocks); + MIFwdParams params{}; + params.B = B; params.S = S; params.T = T; + params.px_stride_b = (int)px_c.stride(0); + params.px_stride_s = (int)px_c.stride(1); + params.py_stride_b = (int)py_c.stride(0); + params.py_stride_s = (int)py_c.stride(1); + params.p_stride_b = (int)p_c.stride(0); + params.p_stride_s = (int)p_c.stride(1); + params.iter = iter; + params.num_blocks_this_iter = num_blocks_this_iter; + params.t_offset = modified ? -1 : 0; + + EncodeForward(cache->fwd, cmdbuf, px_c, py_c, p_c, boundary_i32, ans, + params, num_threadgroups); + } + + // If p was not already contiguous we need to copy the result back + if (!p.is_contiguous()) p.copy_(p_c); + + // Commit asynchronously (PyTorch's MPS stream flushes on next sync point) + stream->synchronize(at::mps::SyncType::NONE); + + return ans; +} + +// ───────────────────────────────────────────────────────────────────────────── +// MutualInformationBackwardMps +// ───────────────────────────────────────────────────────────────────────────── +std::vector MutualInformationBackwardMps( + torch::Tensor px, torch::Tensor py, + torch::optional opt_boundary, + torch::Tensor p, torch::Tensor ans_grad, + bool overwrite_ans_grad) { + + TORCH_CHECK(px.scalar_type() == torch::kFloat, + "MutualInformationBackwardMps only supports float32; got ", + px.scalar_type()); + + auto px_c = px.contiguous(); + auto py_c = py.contiguous(); + auto p_c = p.contiguous(); + auto ans_grad_c = ans_grad.contiguous(); // will be 
modified in-place if overwrite + + const int B = px_c.size(0), S = px_c.size(1), T = py_c.size(2); + const bool modified = (px_c.size(2) == (int64_t)T); + const bool has_boundary = opt_boundary.has_value(); + + torch::Tensor boundary_i32; + if (has_boundary) { + boundary_i32 = opt_boundary.value().to(torch::kInt32).contiguous(); + } else { + boundary_i32 = torch::tensor({0, 0, S, T}, + torch::TensorOptions().dtype(torch::kInt32).device(px.device())) + .reshape({1, 4}).expand({B, 4}).contiguous(); + } + + auto opts = torch::TensorOptions().dtype(px.scalar_type()).device(px.device()); + int T1 = T + (modified ? 0 : 1); + torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts); + torch::Tensor px_grad = has_boundary ? torch::zeros({B, S, T1}, opts) + : torch::empty({B, S, T1}, opts); + torch::Tensor py_grad = has_boundary ? torch::zeros({B, S + 1, T}, opts) + : torch::empty({B, S + 1, T}, opts); + + const int BLOCK_SIZE = 32; + const int num_s_blocks = S / BLOCK_SIZE + 1; + const int num_t_blocks = T / BLOCK_SIZE + 1; + const int num_iters = num_s_blocks + num_t_blocks - 1; + + MIPipelineCache* cache = GetOrBuildPipelines(); + TORCH_CHECK(cache && cache->bwd, "k2 MPS: failed to build MI backward pipeline"); + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); + id cmdbuf = stream->commandBuffer(); + + const int num_threadgroups = 256; + + for (int iter = num_iters - 1; iter >= 0; --iter) { + int num_blocks_this_iter = std::min(iter + 1, num_s_blocks); + MIBwdParams params{}; + params.B = B; params.S = S; params.T = T; + params.px_stride_b = (int)px_c.stride(0); + params.px_stride_s = (int)px_c.stride(1); + params.py_stride_b = (int)py_c.stride(0); + params.py_stride_s = (int)py_c.stride(1); + params.p_stride_b = (int)p_c.stride(0); + params.p_stride_s = (int)p_c.stride(1); + params.iter = iter; + params.num_blocks_this_iter = num_blocks_this_iter; + params.neg_t_offset = modified ? 
1 : 0; + params.has_boundary = has_boundary ? 1 : 0; + + EncodeBackward(cache->bwd, cmdbuf, px_c, py_c, p_c, ans_grad_c, + p_grad, px_grad, py_grad, boundary_i32, params, + num_threadgroups); + } + + if (overwrite_ans_grad && !ans_grad.is_contiguous()) + ans_grad.copy_(ans_grad_c); + + stream->synchronize(at::mps::SyncType::NONE); + + return {px_grad, py_grad}; +} + +} // namespace k2 + +#endif // K2_WITH_MPS diff --git a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu index 415be116a..cd971e3b9 100644 --- a/k2/python/csrc/torch/pruned_ranges_to_lattice.cu +++ b/k2/python/csrc/torch/pruned_ranges_to_lattice.cu @@ -82,9 +82,15 @@ FsaVec PrunedRangesToLattice( context = GetCpuContext(); } else if (ranges.is_cuda()) { context = GetCudaContext(ranges.device().index()); +#ifdef K2_WITH_MPS + } else if (ranges.device().type() == torch::kMPS) { + // MPS: K2_EVAL routes to CPU sequential loop; tensor accessors work via + // unified memory on Apple Silicon. + context = GetMpsContext(); +#endif } else { K2_LOG(FATAL) << "Unsupported device: " << ranges.device() - << "\nOnly CPU and CUDA are verified"; + << "\nOnly CPU and CUDA are supported"; } // "_a" is short for accessor. diff --git a/k2/python/csrc/torch/v2/autograd/swoosh.cu b/k2/python/csrc/torch/v2/autograd/swoosh.cu index 02abcd9b3..a83efee98 100644 --- a/k2/python/csrc/torch/v2/autograd/swoosh.cu +++ b/k2/python/csrc/torch/v2/autograd/swoosh.cu @@ -158,6 +158,31 @@ class SwooshFunction return torch::logaddexp(zero, x - kShift) - kCoeff * x - kOffset; } +#ifdef K2_WITH_MPS + if (context->GetDeviceType() == kMps) { + // K2_EVAL reads raw tensor data via CPU pointers; on MPS this races + // against pending Metal writes. Use ATen ops that run natively on MPS. + // The 8-bit quantised gradient trick is skipped; standard autograd + // handles differentiation via the saved tensors below. 
+ torch::Tensor zero = torch::zeros({1}, x.options()); + torch::Tensor y = torch::logaddexp(zero, x - kShift) + - kCoeff * x - kOffset; + torch::Tensor mask; + if (dropout_prob != 0.0f) { + mask = torch::bernoulli(torch::full_like(y, 1.0f - dropout_prob)); + y = y * mask / (1.0f - dropout_prob); + } + ctx->saved_data["dropout_prob"] = static_cast(dropout_prob); + ctx->saved_data["mps"] = true; + if (dropout_prob != 0.0f) { + ctx->save_for_backward({x, mask}); + } else { + ctx->save_for_backward({x}); + } + return y; + } +#endif + x = x.contiguous(); torch::Tensor y = torch::empty_like(x).contiguous(); @@ -225,6 +250,25 @@ class SwooshFunction static torch::autograd::tensor_list backward( torch::autograd::AutogradContext *ctx, torch::autograd::tensor_list y_grad) { +#ifdef K2_WITH_MPS + if (ctx->saved_data.count("mps") && ctx->saved_data["mps"].toBool()) { + float dropout_prob = ctx->saved_data["dropout_prob"].toDouble(); + auto saved = ctx->get_saved_variables(); + torch::Tensor x = saved[0]; + torch::Tensor out_grad = y_grad[0].contiguous(); + // swoosh'(x) = sigmoid(x - kShift) - kCoeff + torch::Tensor deriv = torch::sigmoid(x - kShift) - kCoeff; + torch::Tensor in_grad; + if (dropout_prob != 0.0f) { + torch::Tensor mask = saved[1]; + in_grad = out_grad * deriv * mask / (1.0f - dropout_prob); + } else { + in_grad = out_grad * deriv; + } + return {in_grad, torch::Tensor()}; + } +#endif + float dropout_prob = ctx->saved_data["dropout_prob"].toDouble(); auto saved = ctx->get_saved_variables(); @@ -283,6 +327,16 @@ torch::Tensor SwooshForward(torch::Tensor x) { static constexpr float kOffset = SwooshConstants::kOffset; x = x.to(torch::kFloat32).contiguous(); + +#ifdef K2_WITH_MPS + if (context->GetDeviceType() == kMps) { + // K2_EVAL reads raw MPS tensor pointers from CPU. Use numerically-stable + // ATen logaddexp which runs natively on MPS without raw pointer access. 
+ torch::Tensor zero = torch::zeros({1}, x.options()); + return torch::logaddexp(zero, x - kShift) - kCoeff * x - kOffset; + } +#endif + const float *x_data = x.data_ptr(); torch::Tensor y = torch::empty_like(x).contiguous(); @@ -327,6 +381,18 @@ std::pair SwooshForwardAndDeriv( static constexpr float kOffset = SwooshConstants::kOffset; x = x.to(torch::kFloat32).contiguous(); + +#ifdef K2_WITH_MPS + if (context->GetDeviceType() == kMps) { + // K2_EVAL reads raw MPS tensor pointers from CPU. Compute via ATen ops + // that run natively on MPS. sigmoid(x-shift) = 1 - 1/(1+exp(x-shift)). + torch::Tensor zero = torch::zeros({1}, x.options()); + torch::Tensor y = torch::logaddexp(zero, x - kShift) - kCoeff * x - kOffset; + torch::Tensor deriv = torch::sigmoid(x - kShift) - kCoeff; + return {y, deriv}; + } +#endif + const float *x_data = x.data_ptr(); torch::Tensor y = torch::empty_like(x).contiguous(); diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index df685e4a2..3c4cf7905 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -375,6 +375,13 @@ RaggedAny RaggedAny::To(torch::Device device) const { return RaggedAny(any.To(GetCpuContext())); } +#ifdef K2_WITH_MPS + if (device.type() == torch::kMPS) { + if (context->GetDeviceType() == kMps) return *this; + return RaggedAny(any.To(GetMpsContext())); + } +#endif + K2_CHECK(device.is_cuda()) << device.str(); int32_t device_index = device.index(); diff --git a/k2/python/csrc/torch/v2/ragged_shape.cu b/k2/python/csrc/torch/v2/ragged_shape.cu index a5cfc064f..54cdc47c4 100644 --- a/k2/python/csrc/torch/v2/ragged_shape.cu +++ b/k2/python/csrc/torch/v2/ragged_shape.cu @@ -73,6 +73,12 @@ void PybindRaggedShape(py::module &m) { if (device.type() == torch::kCPU) return self.To(GetCpuContext()); +#ifdef K2_WITH_MPS + if (device.type() == torch::kMPS) { + return self.To(GetMpsContext()); + } +#endif + K2_CHECK(device.is_cuda()); { DeviceGuard 
g(GetContext(device)); diff --git a/k2/python/csrc/version.cu b/k2/python/csrc/version.cu index bd7e00744..f615bea16 100644 --- a/k2/python/csrc/version.cu +++ b/k2/python/csrc/version.cu @@ -46,4 +46,5 @@ void PybindVersion(py::module &m) { version.attr("enable_nvtx") = k2::kEnableNvtx; version.attr("disable_debug") = k2::internal::kDisableDebug; version.attr("with_cuda") = k2::kWithCuda; + version.attr("with_mps") = k2::kWithMps; } diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 5d0691c1e..b793342d2 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -119,6 +119,7 @@ from .utils import random_fsa from .utils import random_fsa_vec from _k2.version import with_cuda +from _k2.version import with_mps from _k2 import pruned_ranges_to_lattice from .decode import get_aux_labels @@ -127,3 +128,4 @@ cmake_prefix_path = _Path(__file__).parent / "share" / "cmake" del _Path +__dev_version__ = '1.24.4.dev20260314+cpu.torch2.10.0' diff --git a/k2/python/k2/autograd.py b/k2/python/k2/autograd.py index 5d62b472b..a05ffac16 100644 --- a/k2/python/k2/autograd.py +++ b/k2/python/k2/autograd.py @@ -27,6 +27,22 @@ from .dense_fsa_vec import DenseFsaVec +class _MpsScoresBridge(torch.autograd.Function): + """Move scores MPS→CPU in forward; CPU→MPS in backward. + + This lets k2's CPU-only C++ algorithms operate on the scores while keeping + the gradient connected so that ``fsa.scores.grad`` ends up on MPS. 
+ """ + + @staticmethod + def forward(ctx, mps_scores: torch.Tensor) -> torch.Tensor: + return mps_scores.to('cpu') + + @staticmethod + def backward(ctx, cpu_grad: torch.Tensor) -> torch.Tensor: + return cpu_grad.to('mps') + + class _GetTotScoresFunction(torch.autograd.Function): @staticmethod @@ -94,6 +110,7 @@ def backward(ctx, tot_scores_grad: torch.Tensor use_double_scores = ctx.use_double_scores scores, = ctx.saved_tensors + target_device = fsas.scores.device if log_semiring is False: entering_arcs = fsas._get_entering_arcs(use_double_scores) _, ragged_int = _k2.shortest_path(fsas.arcs, entering_arcs) @@ -106,7 +123,7 @@ def backward(ctx, tot_scores_grad: torch.Tensor # We return four values since the `forward` method accepts four # arguments (excluding ctx). # fsas, log_semiring, use_double_scores, unused_scores - return None, None, None, scores_grad + return None, None, None, scores_grad.to(target_device) else: arc_post = fsas._get_arc_post(use_double_scores, log_semiring) if use_double_scores: @@ -114,7 +131,7 @@ def backward(ctx, tot_scores_grad: torch.Tensor else: bprop_func = _k2.get_tot_scores_float_log_backward scores_grad = bprop_func(fsas.arcs, arc_post, tot_scores_grad) - return None, None, None, scores_grad + return None, None, None, scores_grad.to(target_device) class _GetForwardScoresFunction(torch.autograd.Function): @@ -169,30 +186,55 @@ def backward(ctx, forward_scores_grad: torch.Tensor use_double_scores = ctx.use_double_scores forward_scores, = ctx.saved_tensors - if log_semiring: - entering_arcs = None - else: - entering_arcs = fsas._get_entering_arcs(use_double_scores) - state_batches = fsas._get_state_batches() - leaving_arc_batches = fsas._get_leaving_arc_batches() - bprop_func = (_k2.backprop_get_forward_scores_double if use_double_scores else _k2.backprop_get_forward_scores_float) - scores_grad = bprop_func(fsas.arcs, - state_batches=state_batches, - leaving_arc_batches=leaving_arc_batches, - log_semiring=log_semiring, - 
entering_arcs=entering_arcs, - forward_scores=forward_scores, - forward_scores_deriv=forward_scores_grad) + if fsas.scores.device.type == 'mps': + # C++ batch-traversal ops use K2_EVAL which reads raw pointers + # and crashes on MPS. Run backward on CPU; the final .to() + # moves the gradient back to MPS for the upstream graph. + cpu_fsas = fsas.to('cpu') + cpu_forward_scores = forward_scores.cpu().contiguous() + # .expand_as ensures we have a true 1D tensor (not an expanded + # scalar with stride 0) when the upstream passes ones_like(fwd). + cpu_fwd_grad = (forward_scores_grad + .expand_as(forward_scores) + .cpu() + .contiguous()) + if log_semiring: + cpu_entering_arcs = None + else: + cpu_entering_arcs = cpu_fsas._get_entering_arcs( + use_double_scores) + scores_grad = bprop_func( + cpu_fsas.arcs, + state_batches=cpu_fsas._get_state_batches(), + leaving_arc_batches=cpu_fsas._get_leaving_arc_batches(), + log_semiring=log_semiring, + entering_arcs=cpu_entering_arcs, + forward_scores=cpu_forward_scores, + forward_scores_deriv=cpu_fwd_grad) + else: + if log_semiring: + entering_arcs = None + else: + entering_arcs = fsas._get_entering_arcs(use_double_scores) + scores_grad = bprop_func( + fsas.arcs, + state_batches=fsas._get_state_batches(), + leaving_arc_batches=fsas._get_leaving_arc_batches(), + log_semiring=log_semiring, + entering_arcs=entering_arcs, + forward_scores=forward_scores.contiguous(), + forward_scores_deriv=forward_scores_grad.expand_as( + forward_scores).contiguous()) return ( None, # fsas None, # log_semiring None, # use_double_scores - scores_grad # unused_scores + scores_grad.to(fsas.scores.device) # unused_scores ) @@ -224,9 +266,17 @@ def forward(ctx, fsas: Fsa, log_semiring: bool, use_double_scores: bool, # that, the backward_fn of backward_scores, which is cached in `fsas`, # would be set to this object, giving `fsas` a reference to this object, # which also has a reference to `fsas`. 
- backward_scores = fsas._get_backward_scores( - use_double_scores=use_double_scores, - log_semiring=log_semiring).detach() + if fsas.scores.device.type == 'mps': + # C++ batch-traversal ops use K2_EVAL which reads raw pointers + # and crashes on MPS. Run forward on CPU; move result to MPS. + cpu_fsas = fsas.to('cpu') + backward_scores = cpu_fsas._get_backward_scores( + use_double_scores=use_double_scores, + log_semiring=log_semiring).detach().to(fsas.scores.device) + else: + backward_scores = fsas._get_backward_scores( + use_double_scores=use_double_scores, + log_semiring=log_semiring).detach() # NOTE: since `fsas`, `log_semiring` and `use_double_scores` are # not tensors, they are saved as attributes of `ctx`. @@ -245,25 +295,42 @@ def backward(ctx, backward_scores_grad: torch.Tensor use_double_scores = ctx.use_double_scores backward_scores, = ctx.saved_tensors - state_batches = fsas._get_state_batches() - entering_arc_batches = fsas._get_entering_arc_batches() - bprop_func = (_k2.backprop_get_backward_scores_double if use_double_scores else _k2.backprop_get_backward_scores_float) - scores_grad = bprop_func(fsas.arcs, - state_batches=state_batches, - entering_arc_batches=entering_arc_batches, - log_semiring=log_semiring, - backward_scores=backward_scores, - backward_scores_deriv=backward_scores_grad) + if fsas.scores.device.type == 'mps': + cpu_fsas = fsas.to('cpu') + cpu_backward_scores = backward_scores.cpu().contiguous() + cpu_bwd_grad = (backward_scores_grad + .expand_as(backward_scores) + .cpu() + .contiguous()) + state_batches = cpu_fsas._get_state_batches() + entering_arc_batches = cpu_fsas._get_entering_arc_batches() + scores_grad = bprop_func(cpu_fsas.arcs, + state_batches=state_batches, + entering_arc_batches=entering_arc_batches, + log_semiring=log_semiring, + backward_scores=cpu_backward_scores, + backward_scores_deriv=cpu_bwd_grad) + else: + state_batches = fsas._get_state_batches() + entering_arc_batches = fsas._get_entering_arc_batches() + 
scores_grad = bprop_func(fsas.arcs, + state_batches=state_batches, + entering_arc_batches=entering_arc_batches, + log_semiring=log_semiring, + backward_scores=( + backward_scores.contiguous()), + backward_scores_deriv=backward_scores_grad + .expand_as(backward_scores).contiguous()) return ( None, # fsas None, # log_semiring None, # use_double_scores - scores_grad # unused_scores + scores_grad.to(fsas.scores.device) # unused_scores ) @@ -308,8 +375,15 @@ def forward(ctx, fsas: Fsa, log_semiring: bool, use_double_scores: bool, # if we didn't do that, the backward_fn of the arc_post, which is cached # in `fsas`, would be set to this object, giving `fsas` a reference to # this object, which also has a reference to `fsas`. - arc_post = fsas._get_arc_post(use_double_scores=use_double_scores, - log_semiring=log_semiring).detach() + if fsas.scores.device.type == 'mps': + # C++ batch-traversal ops (K2_EVAL) crash on MPS; bridge to CPU. + cpu_fsas = fsas.to('cpu') + arc_post = cpu_fsas._get_arc_post( + use_double_scores=use_double_scores, + log_semiring=log_semiring).detach().to(fsas.scores.device) + else: + arc_post = fsas._get_arc_post(use_double_scores=use_double_scores, + log_semiring=log_semiring).detach() # NOTE: since `fsas`, `log_semiring` and `use_double_scores` are # not tensors, they are saved as attributes of `ctx`. 
@@ -331,6 +405,22 @@ def backward( bprop_func = (_k2.backprop_get_arc_post_double if use_double_scores else _k2.backprop_get_arc_post_float) + target_device = fsas.scores.device + if target_device.type == 'mps': + cpu_fsas = fsas.to('cpu') + cpu_arc_scores_grad = arc_post_grad.detach().cpu().clone() + incoming_arcs = cpu_fsas._get_incoming_arcs() + forward_scores_grad, backward_scores_grad = bprop_func( + cpu_fsas.arcs, incoming_arcs, cpu_arc_scores_grad) + return ( + None, # fsas + None, # log_semiring + None, # use_double_scores + cpu_arc_scores_grad.to(target_device), # unused_scores + forward_scores_grad.to(target_device), # forward_scores + backward_scores_grad.to(target_device) # backward_scores + ) + incoming_arcs = fsas._get_incoming_arcs() arc_scores_grad = arc_post_grad.detach().clone() @@ -341,9 +431,9 @@ def backward( None, # fsas None, # log_semiring None, # use_double_scores - arc_scores_grad, # unused_scores - forward_scores_grad, # forward_scores - backward_scores_grad # backward_scores + arc_scores_grad.to(target_device), # unused_scores + forward_scores_grad.to(target_device), # forward_scores + backward_scores_grad.to(target_device) # backward_scores ) @@ -418,9 +508,19 @@ def forward(ctx, ''' assert len(out_fsa) == 1 + # MPS bridge: _k2.intersect_dense_pruned uses K2_EVAL (CPU-only loops) + # and _k2.index_select / index_add are not MPS-safe. Run the entire + # forward on CPU copies and move the result back to MPS at the end. 
+ on_mps = a_fsas.scores.device.type == 'mps' + if on_mps: + a_fsas_fwd = a_fsas.to('cpu') + b_fsas_fwd = b_fsas.to('cpu') + else: + a_fsas_fwd, b_fsas_fwd = a_fsas, b_fsas + ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned( - a_fsas=a_fsas.arcs, - b_fsas=b_fsas.dense_fsa_vec, + a_fsas=a_fsas_fwd.arcs, + b_fsas=b_fsas_fwd.dense_fsa_vec, search_beam=search_beam, output_beam=output_beam, min_active_states=min_active_states, @@ -429,7 +529,7 @@ def forward(ctx, out_fsa[0] = Fsa(ragged_arc) - for name, a_value in a_fsas.named_tensor_attr(include_scores=False): + for name, a_value in a_fsas_fwd.named_tensor_attr(include_scores=False): if isinstance(a_value, torch.Tensor): value = _k2.index_select(a_value, arc_map_a) else: @@ -442,19 +542,21 @@ def forward(ctx, setattr(out_fsa[0], name, value) - for name, a_value in a_fsas.named_non_tensor_attr(): + for name, a_value in a_fsas_fwd.named_non_tensor_attr(): setattr(out_fsa[0], name, a_value) + # arc_map_a/b stay on CPU — _k2.index_add is not MPS-safe. 
ctx.arc_map_a = arc_map_a ctx.arc_map_b = arc_map_b + ctx.on_mps = on_mps ctx.save_for_backward(unused_scores_a, unused_scores_b) seqframe_idx = None if frame_idx_name is not None: - num_cols = b_fsas.dense_fsa_vec.scores_dim1() + num_cols = b_fsas_fwd.dense_fsa_vec.scores_dim1() seqframe_idx = arc_map_b // num_cols - shape = b_fsas.dense_fsa_vec.shape() + shape = b_fsas_fwd.dense_fsa_vec.shape() fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx) frame_idx = seqframe_idx - _k2.index_select( shape.row_splits(1), fsa_idx0) @@ -463,12 +565,15 @@ def forward(ctx, if seqframe_idx_name is not None: if seqframe_idx is None: - num_cols = b_fsas.dense_fsa_vec.scores_dim1() + num_cols = b_fsas_fwd.dense_fsa_vec.scores_dim1() seqframe_idx = arc_map_b // num_cols assert not hasattr(out_fsa[0], seqframe_idx_name) setattr(out_fsa[0], seqframe_idx_name, seqframe_idx) + if on_mps: + out_fsa[0] = out_fsa[0].to('mps') + return out_fsa[0].scores @staticmethod @@ -478,19 +583,36 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \ arc_map_a = ctx.arc_map_a arc_map_b = ctx.arc_map_b - grad_a = torch.zeros(a_scores.size(0), - dtype=out_fsa_grad.dtype, - device=a_scores.device, - requires_grad=False) - - grad_b = torch.zeros( - *b_scores.shape, - dtype=out_fsa_grad.dtype, - device=b_scores.device, - requires_grad=False).contiguous() # will use its `view()` later - - _k2.index_add(arc_map_a, out_fsa_grad, grad_a) - _k2.index_add(arc_map_b, out_fsa_grad, grad_b.view(-1)) + if ctx.on_mps: + # arc_map_a/b are on CPU; move grad to CPU, scatter, move back. + # Use out_fsa_grad.dtype (not fsa_grad_cpu.dtype) as the + # authoritative dtype: matches non-MPS branch, supports fp16/bf16. 
+ fsa_grad_cpu = out_fsa_grad.cpu() + grad_a = torch.zeros(a_scores.size(0), + dtype=out_fsa_grad.dtype, + device='cpu', + requires_grad=False) + grad_b = torch.zeros( + *b_scores.shape, + dtype=out_fsa_grad.dtype, + device='cpu', + requires_grad=False).contiguous() + _k2.index_add(arc_map_a, fsa_grad_cpu, grad_a) + _k2.index_add(arc_map_b, fsa_grad_cpu, grad_b.view(-1)) + grad_a = grad_a.to('mps') + grad_b = grad_b.to('mps') + else: + grad_a = torch.zeros(a_scores.size(0), + dtype=out_fsa_grad.dtype, + device=a_scores.device, + requires_grad=False) + grad_b = torch.zeros( + *b_scores.shape, + dtype=out_fsa_grad.dtype, + device=b_scores.device, + requires_grad=False).contiguous() # will use its `view()` later + _k2.index_add(arc_map_a, out_fsa_grad, grad_a) + _k2.index_add(arc_map_b, out_fsa_grad, grad_b.view(-1)) return ( None, # a_fass @@ -567,17 +689,28 @@ def forward(ctx, ''' assert len(out_fsa) == 1 + # MPS bridge: same as _IntersectDensePrunedFunction — run on CPU. + on_mps = a_fsas.scores.device.type == 'mps' + if on_mps: + a_fsas_fwd = a_fsas.to('cpu') + b_fsas_fwd = b_fsas.to('cpu') + a_to_b_map_fwd = (a_to_b_map.cpu() + if a_to_b_map is not None else None) + else: + a_fsas_fwd, b_fsas_fwd = a_fsas, b_fsas + a_to_b_map_fwd = a_to_b_map + ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense( - a_fsas=a_fsas.arcs, - b_fsas=b_fsas.dense_fsa_vec, - a_to_b_map=a_to_b_map, + a_fsas=a_fsas_fwd.arcs, + b_fsas=b_fsas_fwd.dense_fsa_vec, + a_to_b_map=a_to_b_map_fwd, output_beam=output_beam, max_states=max_states, max_arcs=max_arcs) out_fsa[0] = Fsa(ragged_arc) - for name, a_value in a_fsas.named_tensor_attr(include_scores=False): + for name, a_value in a_fsas_fwd.named_tensor_attr(include_scores=False): if isinstance(a_value, torch.Tensor): value = _k2.index_select(a_value, arc_map_a) else: @@ -589,24 +722,25 @@ def forward(ctx, setattr(out_fsa[0], name, value) - for name, a_value in a_fsas.named_non_tensor_attr(): + for name, a_value in 
a_fsas_fwd.named_non_tensor_attr(): setattr(out_fsa[0], name, a_value) ctx.arc_map_a = arc_map_a ctx.arc_map_b = arc_map_b + ctx.on_mps = on_mps ctx.save_for_backward(unused_scores_a, unused_scores_b) seqframe_idx = None if frame_idx_name is not None: - num_cols = b_fsas.dense_fsa_vec.scores_dim1() + num_cols = b_fsas_fwd.dense_fsa_vec.scores_dim1() if tuple(map(int, torch.__version__.split(".")[:2])) < (1, 8): seqframe_idx = arc_map_b // num_cols else: seqframe_idx = torch.div( arc_map_b, num_cols, rounding_mode="floor" ) - shape = b_fsas.dense_fsa_vec.shape() + shape = b_fsas_fwd.dense_fsa_vec.shape() fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx) frame_idx = seqframe_idx - _k2.index_select( shape.row_splits(1), fsa_idx0) @@ -615,7 +749,7 @@ def forward(ctx, if seqframe_idx_name is not None: if seqframe_idx is None: - num_cols = b_fsas.dense_fsa_vec.scores_dim1() + num_cols = b_fsas_fwd.dense_fsa_vec.scores_dim1() if tuple(map(int, torch.__version__.split(".")[:2])) < (1, 8): seqframe_idx = arc_map_b // num_cols else: @@ -626,6 +760,9 @@ def forward(ctx, assert not hasattr(out_fsa[0], seqframe_idx_name) setattr(out_fsa[0], seqframe_idx_name, seqframe_idx) + if on_mps: + out_fsa[0] = out_fsa[0].to('mps') + return out_fsa[0].scores @staticmethod @@ -635,19 +772,33 @@ def backward(ctx, out_fsa_grad: torch.Tensor) \ arc_map_a = ctx.arc_map_a arc_map_b = ctx.arc_map_b - grad_a = torch.zeros(a_scores.size(0), - dtype=torch.float32, - device=a_scores.device, - requires_grad=False) - - grad_b = torch.zeros( - *b_scores.shape, - dtype=torch.float32, - device=b_scores.device, - requires_grad=False).contiguous() # will use its `view()` later - - _k2.index_add(arc_map_a, out_fsa_grad, grad_a) - _k2.index_add(arc_map_b, out_fsa_grad, grad_b.view(-1)) + if ctx.on_mps: + fsa_grad_cpu = out_fsa_grad.cpu() + grad_a = torch.zeros(a_scores.size(0), + dtype=out_fsa_grad.dtype, + device='cpu', + requires_grad=False) + grad_b = torch.zeros( + *b_scores.shape, + 
dtype=out_fsa_grad.dtype, + device='cpu', + requires_grad=False).contiguous() + _k2.index_add(arc_map_a, fsa_grad_cpu, grad_a) + _k2.index_add(arc_map_b, fsa_grad_cpu, grad_b.view(-1)) + grad_a = grad_a.to('mps') + grad_b = grad_b.to('mps') + else: + grad_a = torch.zeros(a_scores.size(0), + dtype=out_fsa_grad.dtype, + device=a_scores.device, + requires_grad=False) + grad_b = torch.zeros( + *b_scores.shape, + dtype=out_fsa_grad.dtype, + device=b_scores.device, + requires_grad=False).contiguous() # will use its `view()` later + _k2.index_add(arc_map_a, out_fsa_grad, grad_a) + _k2.index_add(arc_map_b, out_fsa_grad, grad_b.view(-1)) return ( None, # a_fsas diff --git a/k2/python/k2/dense_fsa_vec.py b/k2/python/k2/dense_fsa_vec.py index 580fc21ce..8b4d3f3e6 100644 --- a/k2/python/k2/dense_fsa_vec.py +++ b/k2/python/k2/dense_fsa_vec.py @@ -209,8 +209,8 @@ def to(self, device: Union[torch.device, str]) -> 'DenseFsaVec': Args: device: An instance of `torch.device` or a string that can be used to - construct a `torch.device`, e.g., 'cpu', 'cuda:0'. - It supports only cpu and cuda devices. + construct a `torch.device`, e.g., 'cpu', 'cuda:0', 'mps'. + Supports cpu, cuda, and mps devices. Returns: Returns a new DenseFsaVec which is this object copied to the given @@ -219,7 +219,7 @@ def to(self, device: Union[torch.device, str]) -> 'DenseFsaVec': if isinstance(device, str): device = torch.device(device) - assert device.type in ('cpu', 'cuda') + assert device.type in ('cpu', 'cuda', 'mps') if device == self.scores.device: return self diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py index 965ed5ce2..773fdc7dd 100644 --- a/k2/python/k2/fsa.py +++ b/k2/python/k2/fsa.py @@ -547,6 +547,41 @@ def _get_incoming_arcs(self) -> k2.RaggedTensor: self._get_dest_states()) return cache[name] + def _compute_bfs_arc_batches_mps(self): + '''Compute entering-arc BFS batches for MPS without a full CPU copy. 
+ + Exploits the k2 invariant that FSAs are topologically sorted + (arc.src_state_local < arc.dest_state_local), so the local dest-state + index is a valid BFS-level assignment. Only the dest_state column + (4 bytes × num_arcs) crosses the MPS/CPU boundary, versus 16 bytes × + num_arcs + all tensor attributes for a full self.to("cpu") copy. + + Returns: + sorted_arc_ids_mps: int32 MPS tensor of arc indices sorted by + dest_state_local (ascending, stable). + batch_sizes: list[int] — number of arcs per BFS level (index 0 is + always 0 for valid FSAs since no arc enters the start state). + ''' + arcs_vals = self.arcs.values() # [num_arcs, 4] int32 MPS + num_arcs = arcs_vals.shape[0] + if num_arcs == 0: + return (torch.zeros(0, dtype=torch.int32, device='mps'), []) + + # Transfer only the dest_state column to CPU — 4 bytes × num_arcs. + arc_dst_local = arcs_vals[:, 1].contiguous().cpu() # [num_arcs] int32 + + # Sort arc indices by dest_state_local on CPU (fast); move to MPS. + sorted_arc_ids = torch.argsort( + arc_dst_local.to(torch.int64), stable=True).to(torch.int32) + sorted_arc_ids_mps = sorted_arc_ids.to('mps') + + # Count arcs per BFS level (= per dest_state_local value). + num_levels = int(arc_dst_local.max().item()) + 1 + batch_sizes = torch.bincount( + arc_dst_local, minlength=num_levels).tolist() + + return sorted_arc_ids_mps, batch_sizes + def _get_entering_arc_batches(self) -> k2.RaggedTensor: '''Get (and compute if necessary) cached property `self.entering_arc_batches`. @@ -610,17 +645,36 @@ def _get_forward_scores(self, use_double_scores: bool, ('log' if log_semiring else 'tropical') cache = self._cache if name not in cache: - if use_double_scores: - func = _k2.get_forward_scores_double + if self.scores.device.type == 'mps' and not use_double_scores: + # Zero-copy MPS path: transfer only dest_state column to CPU + # for sorting; move sorted arc indices back to MPS. Avoids + # the full self.to('cpu') arc-data copy of Priority 3. 
+ # + # Priority 6 (assoc scan): for single-FSA tropical semiring + # with 4 ≤ N ≤ 128, use a Hillis-Steele prefix scan over + # per-state transition matrices, reducing encoder calls from + # N to ⌈log₂N⌉. Falls back to native sequential internally. + sorted_arc_ids, batch_sizes = \ + self._compute_bfs_arc_batches_mps() + cache[name] = _k2.get_forward_scores_mps_assoc_scan( + self.arcs, sorted_arc_ids, batch_sizes, log_semiring) + elif self.scores.device.type == 'mps' and use_double_scores: + raise NotImplementedError( + '_get_forward_scores with use_double_scores=True is not ' + 'supported on MPS. Use get_forward_scores (differentiable)' + ' or move the FSA to CPU first with fsa.to("cpu").') else: - func = _k2.get_forward_scores_float - cache[name], entering_arcs = func( - self.arcs, - state_batches=self._get_state_batches(), - entering_arc_batches=self._get_entering_arc_batches(), - log_semiring=log_semiring) - if not log_semiring: - cache['entering_arcs'] = entering_arcs + if use_double_scores: + func = _k2.get_forward_scores_double + else: + func = _k2.get_forward_scores_float + cache[name], entering_arcs = func( + self.arcs, + state_batches=self._get_state_batches(), + entering_arc_batches=self._get_entering_arc_batches(), + log_semiring=log_semiring) + if not log_semiring: + cache['entering_arcs'] = entering_arcs return cache[name] def get_forward_scores(self, use_double_scores: bool, @@ -663,6 +717,11 @@ def _get_tot_scores(self, use_double_scores: bool, ('log' if log_semiring else 'tropical') cache = self._cache if name not in cache: + if self.scores.device.type == 'mps': + raise NotImplementedError( + '_get_tot_scores is not supported on MPS. ' + 'Use the differentiable get_tot_scores or move ' + 'the FSA to CPU first with fsa.to("cpu").') if use_double_scores is True: func = _k2.get_tot_scores_double else: @@ -688,6 +747,25 @@ def get_tot_scores(self, use_double_scores: bool, True to use log semiring (log-sum), false to use tropical (i.e. max on scores). 
''' + if self.scores.device.type == 'mps': + # k2's C++ algorithms access raw data pointers and only work on + # CPU (or CUDA). Run on CPU and bridge the gradient back to MPS. + cpu_fsa = self.to('cpu') + # _MpsScoresBridge keeps MPS→CPU in forward and CPU→MPS in + # backward, so that mps_fsa.scores.grad ends up on MPS. + # cpu_scores is passed as `unused_scores` to _GetTotScoresFunction + # solely to anchor the backward gradient path; the actual forward + # computation uses cpu_fsa.scores internally. + cpu_scores = k2.autograd._MpsScoresBridge.apply(self.scores) + cpu_tot = k2.autograd._GetTotScoresFunction.apply( + cpu_fsa, log_semiring, use_double_scores, cpu_scores) + # cpu_tot.to() is gradient-safe here: the backward flows entirely + # through cpu_scores (the unused_scores arg above), not through + # cpu_tot itself, so this device transfer does not detach the grad. + # MPS doesn't support float64; downcast to float32 before moving. + result = (cpu_tot.float() + if cpu_tot.dtype == torch.float64 else cpu_tot) + return result.to(self.scores.device) tot_scores = k2.autograd._GetTotScoresFunction.apply( self, log_semiring, use_double_scores, self.scores) return tot_scores @@ -717,6 +795,11 @@ def _get_backward_scores(self, use_double_scores: bool, ('log' if log_semiring else 'tropical') cache = self._cache if name not in cache: + if self.scores.device.type == 'mps': + raise NotImplementedError( + '_get_backward_scores is not supported on MPS. ' + 'Use the differentiable get_backward_scores or move ' + 'the FSA to CPU first with fsa.to("cpu").') if use_double_scores: func = _k2.get_backward_scores_double else: @@ -1093,8 +1176,8 @@ def to(self, device: Union[str, torch.device]) -> 'Fsa': Args: device: An instance of `torch.device` or a string that can be used to - construct a `torch.device`, e.g., 'cpu', 'cuda:0'. - It supports only cpu and cuda devices. + construct a `torch.device`, e.g., 'cpu', 'cuda:0', 'mps'. + Supports cpu, cuda, and mps devices. 
Returns: Returns a new Fsa which is this object copied to the given device @@ -1104,7 +1187,7 @@ def to(self, device: Union[str, torch.device]) -> 'Fsa': if isinstance(device, str): device = torch.device(device) - assert device.type in ('cpu', 'cuda') + assert device.type in ('cpu', 'cuda', 'mps') if device == self.scores.device: return self @@ -1474,7 +1557,7 @@ def get_aux_label_info(acceptor: Optional[bool], num_aux_labels: Optional[int], if num_aux_labels != 0: aux_label_names.append('aux_labels') for i in range(1, num_aux_labels): - aux_label_names.append(f'aux_labels{i+1}') + aux_label_names.append(f'aux_labels{i + 1}') return num_aux_labels, aux_label_names else: return (0, []) diff --git a/k2/python/k2/version/version.py b/k2/python/k2/version/version.py index 97669ed3a..231abb8a3 100644 --- a/k2/python/k2/version/version.py +++ b/k2/python/k2/version/version.py @@ -55,6 +55,7 @@ def main(): torch_cuda_version = _k2.version.torch_cuda_version enable_nvtx = _k2.version.enable_nvtx with_cuda = _k2.version.with_cuda + with_mps = _k2.version.with_mps disable_debug = _k2.version.disable_debug sync_kernels = os.getenv('K2_SYNC_KERNELS', None) is not None disable_checks = os.getenv('K2_DISABLE_CHECKS', None) is not None @@ -78,10 +79,11 @@ def main(): PyTorch is using Cuda: {torch_cuda_version} NVTX enabled: {enable_nvtx} With CUDA: {with_cuda} +With MPS: {with_mps} Disable debug: {disable_debug} Sync kernels : {sync_kernels} Disable checks: {disable_checks} -Max cpu memory allocate: {max_cpu_mem_allocate} bytes (or {max_cpu_mem_allocate/1024/1024/1024} GB) +Max cpu memory allocate: {max_cpu_mem_allocate} bytes (or {max_cpu_mem_allocate / 1024 / 1024 / 1024} GB) k2 abort: {k2_abort} __file__: {__file__} _k2.__file__: {_k2.__file__} diff --git a/k2/python/tests/test_mps.py b/k2/python/tests/test_mps.py new file mode 100644 index 000000000..bb2faba49 --- /dev/null +++ b/k2/python/tests/test_mps.py @@ -0,0 +1,1226 @@ +"""Tests for k2 Apple Metal (MPS) backend 
support.""" +import pytest +import torch +import k2 + +mps_available = pytest.mark.skipif( + not torch.backends.mps.is_available(), + reason="MPS not available" +) + + +@mps_available +class TestMpsContext: + def test_linear_fsa_to_mps(self): + fsa = k2.linear_fsa([1, 2, 3]) + mps_fsa = fsa.to('mps') + assert mps_fsa.device.type == 'mps' + + def test_round_trip(self): + fsa = k2.linear_fsa([1, 2, 3]) + mps_fsa = fsa.to('mps') + cpu_fsa = mps_fsa.to('cpu') + assert (fsa.arcs.values() == cpu_fsa.arcs.values()).all() + + def test_ctc_topo_mps(self): + fsa = k2.ctc_topo(5) + mps_fsa = fsa.to('mps') + assert mps_fsa.device.type == 'mps' + cpu_back = mps_fsa.to('cpu') + assert (fsa.arcs.values() == cpu_back.arcs.values()).all() + + def test_arc_sort_mps(self): + """arc_sort on MPS must produce correctly sorted arcs.""" + # Use an unsorted FSA so we can verify the sort actually ran. + fsa = k2.linear_fsa([3, 1, 2]) + # arc_sort on CPU gives the reference order. + sorted_cpu = k2.arc_sort(fsa) + # Move to MPS and sort there; round-trip back for comparison. + mps_fsa = fsa.to('mps') + sorted_mps = k2.arc_sort(mps_fsa) + assert sorted_mps.device.type == 'mps' + assert ( + sorted_cpu.arcs.values() + == sorted_mps.to('cpu').arcs.values() + ).all() + + def test_ragged_to_mps(self): + # Includes a single-element row to exercise ExclusiveSum/InclusiveSum + # n=1. + ragged = k2.RaggedTensor([[1, 2, 3], [4, 5], [6]]) + mps_ragged = ragged.to(device='mps') + cpu_back = mps_ragged.to(device='cpu') + assert (ragged.values == cpu_back.values).all() + # Verify row structure (row_ids) round-trips correctly. 
+ assert (ragged.shape.row_ids(1) == cpu_back.shape.row_ids(1)).all() + + def test_with_mps_flag(self): + """k2.with_mps must be True when built with MPS support.""" + assert k2.with_mps, "k2 was not built with MPS support" + + +@mps_available +class TestMpsTraining: + def test_mutual_information_backward_mps(self): + B, S, T = 2, 4, 5 + px = torch.randn(B, S, T + 1, device='mps', requires_grad=True) + py = torch.randn(B, S + 1, T, device='mps', requires_grad=True) + mi = k2.mutual_information_recursion(px, py) + mi.sum().backward() + assert px.grad is not None + assert px.grad.device.type == 'mps' + assert py.grad is not None + assert py.grad.device.type == 'mps' + + def test_mutual_information_parity(self): + """MPS result must match CPU to within tolerance.""" + B, S, T = 2, 4, 5 + px_cpu = torch.randn(B, S, T + 1) + py_cpu = torch.randn(B, S + 1, T) + mi_cpu = k2.mutual_information_recursion(px_cpu, py_cpu) + + px_mps = px_cpu.to('mps') + py_mps = py_cpu.to('mps') + mi_mps = k2.mutual_information_recursion(px_mps, py_mps) + + assert torch.allclose(mi_cpu, mi_mps.cpu(), atol=1e-4) + + def test_mutual_information_with_boundary_mps(self): + """mutual_information_recursion with boundary tensor on MPS.""" + B, S, T = 2, 4, 5 + px_cpu = torch.randn(B, S, T + 1) + py_cpu = torch.randn(B, S + 1, T) + boundary = torch.zeros(B, 4, dtype=torch.int64) + boundary[:, 2] = S + boundary[:, 3] = T + + mi_cpu = k2.mutual_information_recursion(px_cpu, py_cpu, + boundary=boundary) + + px_mps = px_cpu.to('mps') + py_mps = py_cpu.to('mps') + boundary_mps = boundary.to('mps') + mi_mps = k2.mutual_information_recursion(px_mps, py_mps, + boundary=boundary_mps) + + assert torch.allclose(mi_cpu, mi_mps.cpu(), atol=1e-4) + + def test_mps_scores_bridge(self): + """_MpsScoresBridge: forward→CPU, backward→MPS with correct grads.""" + mps_t = torch.randn(4, device='mps', requires_grad=True) + cpu_t = k2.autograd._MpsScoresBridge.apply(mps_t) + assert cpu_t.device.type == 'cpu' + 
cpu_t.sum().backward() + assert mps_t.grad is not None + assert mps_t.grad.device.type == 'mps' + # d(sum)/d(x_i) = 1 for all i — verify gradient magnitudes are correct. + assert torch.allclose(mps_t.grad.cpu(), torch.ones(4)) + + def test_tot_scores_log_backward_mps(self): + """log-semiring get_tot_scores backward: grads on MPS, match CPU.""" + fsa_a = k2.linear_fsa([1, 2]) + fsa_b = k2.linear_fsa([3]) + fsa_vec = k2.create_fsa_vec([fsa_a, fsa_b]) + + # CPU reference + cpu_fsa = fsa_vec.clone() + cpu_fsa.scores.requires_grad_(True) + cpu_tot = cpu_fsa.get_tot_scores(log_semiring=True, + use_double_scores=False) + cpu_tot.sum().backward() + + # MPS under test + mps_fsa = fsa_vec.to('mps') + mps_fsa.scores.requires_grad_(True) + mps_tot = mps_fsa.get_tot_scores(log_semiring=True, + use_double_scores=False) + mps_tot.sum().backward() + + assert mps_fsa.scores.grad is not None + assert mps_fsa.scores.grad.device.type == 'mps' + assert torch.allclose(cpu_fsa.scores.grad, + mps_fsa.scores.grad.cpu(), atol=1e-5) + + def test_tot_scores_tropical_backward_mps(self): + """tropical-semiring get_tot_scores backward: grads land on MPS.""" + fsa_a = k2.linear_fsa([1, 2]) + fsa_b = k2.linear_fsa([3]) + fsa_vec = k2.create_fsa_vec([fsa_a, fsa_b]) + + cpu_fsa = fsa_vec.clone() + cpu_fsa.scores.requires_grad_(True) + cpu_tot = cpu_fsa.get_tot_scores(log_semiring=False, + use_double_scores=False) + cpu_tot.sum().backward() + + mps_fsa = fsa_vec.to('mps') + mps_fsa.scores.requires_grad_(True) + mps_tot = mps_fsa.get_tot_scores(log_semiring=False, + use_double_scores=False) + mps_tot.sum().backward() + + assert mps_fsa.scores.grad is not None + assert mps_fsa.scores.grad.device.type == 'mps' + assert torch.allclose(cpu_fsa.scores.grad, + mps_fsa.scores.grad.cpu(), atol=1e-5) + + def test_tot_scores_double_mps(self): + """use_double_scores=True on MPS: result as float32 (no float64).""" + fsa = k2.linear_fsa([1, 2]) + fsa_vec = k2.create_fsa_vec([fsa]) + + cpu_tot = 
fsa_vec.get_tot_scores(log_semiring=True, + use_double_scores=True) + + mps_fsa = fsa_vec.to('mps') + mps_fsa.scores.requires_grad_(True) + mps_tot = mps_fsa.get_tot_scores(log_semiring=True, + use_double_scores=True) + + # MPS has no float64 support; result is downcast to float32. + assert mps_tot.device.type == 'mps' + assert mps_tot.dtype == torch.float32 + assert torch.allclose(cpu_tot.float(), mps_tot.cpu(), atol=1e-5) + + # Gradients must flow back to MPS scores even with the float64 downcast. + mps_tot.sum().backward() + assert mps_fsa.scores.grad is not None + assert mps_fsa.scores.grad.device.type == 'mps' + + def test_tot_scores_nonunit_gradient_mps(self): + """Non-unit incoming gradient scales score grads correctly on MPS.""" + fsa_a = k2.linear_fsa([1, 2]) + fsa_b = k2.linear_fsa([3]) + fsa_vec = k2.create_fsa_vec([fsa_a, fsa_b]) + + # CPU reference with weighted upstream gradient. + cpu_fsa = fsa_vec.clone() + cpu_fsa.scores.requires_grad_(True) + cpu_tot = cpu_fsa.get_tot_scores(log_semiring=True, + use_double_scores=False) + upstream = torch.tensor([2.0, 0.5]) + cpu_tot.backward(upstream) + + # MPS under test. 
+ mps_fsa = fsa_vec.to('mps') + mps_fsa.scores.requires_grad_(True) + mps_tot = mps_fsa.get_tot_scores(log_semiring=True, + use_double_scores=False) + mps_tot.backward(upstream.to('mps')) + + assert mps_fsa.scores.grad is not None + assert mps_fsa.scores.grad.device.type == 'mps' + assert torch.allclose(cpu_fsa.scores.grad, + mps_fsa.scores.grad.cpu(), atol=1e-5) + + +@mps_available +class TestMpsForwardScores: + """Tests for native Metal GetForwardScores (Priority 3).""" + + def _make_fsa_vec(self, fsa_str: str, device: str): + fsa = k2.Fsa.from_str(fsa_str) + return k2.create_fsa_vec([fsa]).to(device) + + def test_forward_scores_log_parity(self): + """Metal log-semiring forward scores must match CPU.""" + s = ''' + 0 1 0 0.1 + 0 1 1 0.2 + 1 2 -1 0.3 + 2 + ''' + fsa_mps = self._make_fsa_vec(s, 'mps') + fsa_cpu = self._make_fsa_vec(s, 'cpu') + fwd_mps = fsa_mps._get_forward_scores(use_double_scores=False, + log_semiring=True) + fwd_cpu = fsa_cpu._get_forward_scores(use_double_scores=False, + log_semiring=True) + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_forward_scores_tropical_parity(self): + """Metal tropical-semiring forward scores must match CPU.""" + s = ''' + 0 1 0 0.1 + 0 1 1 0.2 + 1 2 -1 0.3 + 2 + ''' + fsa_mps = self._make_fsa_vec(s, 'mps') + fsa_cpu = self._make_fsa_vec(s, 'cpu') + fwd_mps = fsa_mps._get_forward_scores(use_double_scores=False, + log_semiring=False) + fwd_cpu = fsa_cpu._get_forward_scores(use_double_scores=False, + log_semiring=False) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_forward_scores_differentiable(self): + """Differentiable get_forward_scores must compute correct gradients.""" + s = ''' + 0 1 0 0.1 + 0 1 1 0.2 + 1 2 -1 0.3 + 2 + ''' + fsa_mps = self._make_fsa_vec(s, 'mps') + fsa_mps.scores.requires_grad_(True) + fsa_cpu = self._make_fsa_vec(s, 'cpu') + fsa_cpu.scores.requires_grad_(True) + + fwd_mps = 
fsa_mps.get_forward_scores(use_double_scores=False, + log_semiring=True) + fwd_cpu = fsa_cpu.get_forward_scores(use_double_scores=False, + log_semiring=True) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + scale = torch.arange(fwd_mps.numel()).float() + (scale.to('mps') * fwd_mps).sum().backward() + (scale * fwd_cpu).sum().backward() + assert torch.allclose(fsa_mps.scores.grad.cpu(), + fsa_cpu.scores.grad, atol=1e-5) + + def test_forward_scores_multi_fsa(self): + """Metal kernel must handle batched FSAs correctly.""" + s1 = ''' + 0 1 0 0.5 + 1 2 -1 1.0 + 2 + ''' + s2 = ''' + 0 1 0 0.1 + 0 1 1 0.2 + 1 2 -1 0.3 + 2 + ''' + fsa_vec_mps = k2.create_fsa_vec( + [k2.Fsa.from_str(s1), k2.Fsa.from_str(s2)]).to('mps') + fsa_vec_cpu = k2.create_fsa_vec( + [k2.Fsa.from_str(s1), k2.Fsa.from_str(s2)]) + fwd_mps = fsa_vec_mps._get_forward_scores(use_double_scores=False, + log_semiring=True) + fwd_cpu = fsa_vec_cpu._get_forward_scores(use_double_scores=False, + log_semiring=True) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + +@mps_available +class TestMpsIntersectDense: + """Tests for intersect_dense and intersect_dense_pruned on MPS.""" + + # Simple decoding graph shared across tests. + _FSA_STR = ''' + 0 1 1 1.0 + 1 2 2 2.0 + 2 3 -1 3.0 + 3 + ''' + + def _make_fsa_vec(self, device): + # Build on CPU first — create_fsa_vec calls Fsa() which accesses + # `properties` (K2_EVAL), which crashes when the FSA is on MPS. + fsa = k2.Fsa.from_str(self._FSA_STR) + fsa_vec = k2.create_fsa_vec([fsa]).to(device) + fsa_vec.scores.requires_grad_(True) + return fsa_vec + + def _make_dense(self, device): + # DenseFsaVec.__init__ calls _k2.DenseFsaVec(scores, row_splits) where + # row_splits is read via data_ptr() — crashes on MPS. Build on CPU + # then move to device; set requires_grad on the device-resident scores. 
+ log_prob_cpu = torch.tensor( + [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]], dtype=torch.float32) + segs = torch.tensor([[0, 0, 2]], dtype=torch.int32) + dense = k2.DenseFsaVec(log_prob_cpu, segs) + if device != 'cpu': + dense = dense.to(device) + dense.scores.requires_grad_(True) + return dense + + def test_intersect_dense_forward_parity(self): + """intersect_dense on MPS must produce scores matching CPU.""" + fsa_cpu = self._make_fsa_vec('cpu') + dense_cpu = self._make_dense('cpu') + out_cpu = k2.intersect_dense(fsa_cpu, dense_cpu, output_beam=100000) + + fsa_mps = self._make_fsa_vec('mps') + dense_mps = self._make_dense('mps') + out_mps = k2.intersect_dense(fsa_mps, dense_mps, output_beam=100000) + + assert out_mps.device.type == 'mps' + assert torch.allclose(out_mps.scores.cpu(), out_cpu.scores, atol=1e-5) + + def test_intersect_dense_backward_parity(self): + """intersect_dense backward on MPS: grads land on MPS, match CPU.""" + fsa_cpu = self._make_fsa_vec('cpu') + dense_cpu = self._make_dense('cpu') + out_cpu = k2.intersect_dense(fsa_cpu, dense_cpu, output_beam=100000) + out_cpu.get_tot_scores(log_semiring=False, + use_double_scores=False).sum().backward() + + fsa_mps = self._make_fsa_vec('mps') + dense_mps = self._make_dense('mps') + out_mps = k2.intersect_dense(fsa_mps, dense_mps, output_beam=100000) + out_mps.get_tot_scores(log_semiring=False, + use_double_scores=False).sum().backward() + + # Graph-arc score gradients. + assert fsa_mps.scores.grad is not None + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose(fsa_mps.scores.grad.cpu(), + fsa_cpu.scores.grad, atol=1e-5) + # Acoustic (b_fsa) score gradients. 
+ assert dense_mps.scores.grad is not None + assert dense_mps.scores.grad.device.type == 'mps' + assert torch.allclose(dense_mps.scores.grad.cpu(), + dense_cpu.scores.grad, atol=1e-5) + + def test_intersect_dense_seqframe_attr(self): + """seqframe_idx_name and frame_idx_name attributes must work on MPS.""" + fsa_mps = self._make_fsa_vec('mps') + dense_mps = self._make_dense('mps') + out_mps = k2.intersect_dense(fsa_mps, dense_mps, output_beam=100000, + seqframe_idx_name='seqframe', + frame_idx_name='frame') + assert hasattr(out_mps, 'seqframe') + assert hasattr(out_mps, 'frame') + # Verify against CPU reference. + fsa_cpu = self._make_fsa_vec('cpu') + dense_cpu = self._make_dense('cpu') + out_cpu = k2.intersect_dense(fsa_cpu, dense_cpu, output_beam=100000, + seqframe_idx_name='seqframe', + frame_idx_name='frame') + assert torch.equal(out_mps.seqframe.cpu(), out_cpu.seqframe) + assert torch.equal(out_mps.frame.cpu(), out_cpu.frame) + + def test_intersect_dense_pruned_forward_parity(self): + """intersect_dense_pruned on MPS forward scores must match CPU.""" + fsa_cpu = self._make_fsa_vec('cpu') + dense_cpu = self._make_dense('cpu') + out_cpu = k2.intersect_dense_pruned(fsa_cpu, dense_cpu, + search_beam=100000, + output_beam=100000, + min_active_states=0, + max_active_states=10000) + + fsa_mps = self._make_fsa_vec('mps') + dense_mps = self._make_dense('mps') + out_mps = k2.intersect_dense_pruned(fsa_mps, dense_mps, + search_beam=100000, + output_beam=100000, + min_active_states=0, + max_active_states=10000) + + assert out_mps.device.type == 'mps' + assert torch.allclose(out_mps.scores.cpu(), out_cpu.scores, atol=1e-5) + + def test_intersect_dense_pruned_backward_parity(self): + """intersect_dense_pruned backward on MPS: grads land on MPS.""" + fsa_cpu = self._make_fsa_vec('cpu') + dense_cpu = self._make_dense('cpu') + out_cpu = k2.intersect_dense_pruned(fsa_cpu, dense_cpu, + search_beam=100000, + output_beam=100000, + min_active_states=0, + max_active_states=10000) + 
out_cpu.get_tot_scores(log_semiring=False, + use_double_scores=False).sum().backward() + + fsa_mps = self._make_fsa_vec('mps') + dense_mps = self._make_dense('mps') + out_mps = k2.intersect_dense_pruned(fsa_mps, dense_mps, + search_beam=100000, + output_beam=100000, + min_active_states=0, + max_active_states=10000) + out_mps.get_tot_scores(log_semiring=False, + use_double_scores=False).sum().backward() + + assert fsa_mps.scores.grad is not None + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose(fsa_mps.scores.grad.cpu(), + fsa_cpu.scores.grad, atol=1e-5) + assert dense_mps.scores.grad is not None + assert dense_mps.scores.grad.device.type == 'mps' + assert torch.allclose(dense_mps.scores.grad.cpu(), + dense_cpu.scores.grad, atol=1e-5) + + +@mps_available +class TestMpsAssocScan: + """Tests for Priority 6: Hillis-Steele associative-scan forward scores.""" + + # 4-state linear-chain FSA: 3 arcs, 4 states → assoc-scan threshold met. + _LINEAR_STR = ''' + 0 1 1 0.5 + 1 2 2 1.0 + 2 3 -1 1.5 + 3 + ''' + + # 8-state FSA with branching to exercise prefix products more thoroughly. 
+ _BRANCHING_STR = ''' + 0 1 1 1.0 + 0 2 2 2.0 + 1 3 3 0.5 + 2 3 3 1.5 + 3 4 4 0.25 + 4 5 5 0.75 + 5 6 6 0.5 + 6 7 -1 1.0 + 7 + ''' + + def _make_fsa(self, fsa_str, device='cpu'): + fsa = k2.Fsa.from_str(fsa_str.strip()) + return k2.create_fsa_vec([fsa]).to(device) + + def test_assoc_scan_linear_tropical_parity(self): + """Single-FSA tropical forward scores via assoc scan match CPU.""" + fsa_cpu = self._make_fsa(self._LINEAR_STR) + fsa_mps = self._make_fsa(self._LINEAR_STR, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_assoc_scan_linear_log_parity(self): + """Log-semiring falls back to native path; results still match CPU.""" + fsa_cpu = self._make_fsa(self._LINEAR_STR) + fsa_mps = self._make_fsa(self._LINEAR_STR, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=True) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=True) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_assoc_scan_branching_tropical_parity(self): + """Branching 8-state FSA tropical forward scores match CPU.""" + fsa_cpu = self._make_fsa(self._BRANCHING_STR) + fsa_mps = self._make_fsa(self._BRANCHING_STR, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_assoc_scan_tot_scores_parity(self): + """Total (best-path) scores via assoc-scan forward path match CPU. 
+ + Uses the differentiable get_tot_scores which bridges MPS→CPU for the + final C++ score extraction (the non-differentiable _get_tot_scores is + not MPS-safe, as its C++ path reads MPS arcs via K2_EVAL). + """ + fsa_cpu = self._make_fsa(self._BRANCHING_STR) + fsa_mps = self._make_fsa(self._BRANCHING_STR, 'mps') + + tot_cpu = fsa_cpu.get_tot_scores( + use_double_scores=False, log_semiring=False) + tot_mps = fsa_mps.get_tot_scores( + use_double_scores=False, log_semiring=False) + + assert tot_mps.device.type == 'mps' + assert torch.allclose(tot_mps.cpu(), tot_cpu, atol=1e-5) + + def test_assoc_scan_large_fallback(self): + """FSA with >128 states falls back to native sequential path.""" + # Build a long linear chain with 200 states: exceeds N_MAX=128. + # The last arc must use label -1 (final/epsilon in k2 convention). + arcs = [] + for i in range(198): + arcs.append(f'{i} {i + 1} {i % 100 + 1} 0.01') + arcs.append('198 199 -1 0.01') + arcs.append('199') + fsa_str = '\n'.join(arcs) + fsa_cpu = self._make_fsa(fsa_str) + fsa_mps = self._make_fsa(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-4) + + +@mps_available +class TestMpsNumericalParity: + """Verify MPS results match CPU within tolerance.""" + + def test_linear_fsa_scores_parity(self): + """Round-tripping non-trivial scores through MPS preserves values.""" + fsa = k2.linear_fsa([1, 2, 3]) + fsa_vec = k2.create_fsa_vec([fsa]) + # Assign non-trivial scores so a no-op copy would be detected. 
+        fsa_vec.scores = torch.tensor([1.5, -0.5, 2.0, 0.0])
+        mps_fsa = fsa_vec.to('mps')
+        assert torch.allclose(fsa_vec.scores,
+                              mps_fsa.scores.cpu(), atol=1e-6)
+
+    def test_dense_fsa_vec_mps(self):
+        """DenseFsaVec.to('mps') round-trip must preserve score values."""
+        # Build a small DenseFsaVec on CPU.
+        T, num_classes = 5, 4
+        log_probs = torch.randn(1, T, num_classes)
+        supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
+        dense = k2.DenseFsaVec(log_probs, supervision_segments)
+        mps_dense = dense.to('mps')
+        assert mps_dense.scores.device.type == 'mps'
+        # Round-trip back and verify values.
+        cpu_back = mps_dense.to('cpu')
+        assert torch.allclose(dense.scores, cpu_back.scores, atol=1e-6)
+
+# =============================================================================
+# PR Audit — Extended Test Suite
+# =============================================================================
+
+
+def _make_fsa_vec(fsa_str, device='cpu'):
+    """Helper: build FsaVec from multi-line FSA string, move to device."""
+    fsa = k2.Fsa.from_str(fsa_str.strip())
+    return k2.create_fsa_vec([fsa]).to(device)
+
+
+@mps_available
+class TestMpsEdgeCases:
+    """Edge-case tests: empty FSA, unreachable states, guard paths."""
+
+    def test_empty_arcs_forward_scores(self):
+        """Minimal single-arc FSA returns correct forward scores."""
+        # A single-arc FSA with just a final arc to the accept state.
+        # This has 2 states: state 0 (start) and state 1 (accept).
+ fsa_str = '0 1 -1 0.0\n1' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_unreachable_state_forward_scores(self): + """State with no entering arcs gets -inf forward score.""" + # State 1 is unreachable: arcs go 0→2 and 2→3. State 1 is dangling. + # k2 requires valid topological structure; use isolated final state. + fsa_str = ''' + 0 2 1 1.0 + 0 2 2 2.0 + 2 3 -1 0.5 + 3 + ''' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_forward_scores_double_mps_raises(self): + """_get_forward_scores with use_double_scores=True raises error.""" + fsa_mps = _make_fsa_vec('0 1 -1 1.0\n1', 'mps') + with pytest.raises(NotImplementedError, match='use_double_scores'): + fsa_mps._get_forward_scores( + use_double_scores=True, log_semiring=False) + + def test_backward_scores_mps_raises(self): + """_get_backward_scores on MPS raises NotImplementedError.""" + fsa_mps = _make_fsa_vec('0 1 -1 1.0\n1', 'mps') + with pytest.raises(NotImplementedError, match='_get_backward_scores'): + fsa_mps._get_backward_scores( + use_double_scores=False, log_semiring=False) + + def test_single_path_forward_scores(self): + """Single-path chain: MPS forward scores equal manual computation.""" + # 0→1 (w=1.0) → 2 (w=2.0) → 3 (w=3.0, final) + fsa_str = '0 1 1 1.0\n1 2 2 2.0\n2 3 -1 3.0\n3' + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + fwd = fsa_mps._get_forward_scores( + 
use_double_scores=False, log_semiring=False) + # Expected: state 0=0, state 1=1, state 2=3, state 3=6 + expected = torch.tensor([0.0, 1.0, 3.0, 6.0]) + assert torch.allclose(fwd.cpu(), expected, atol=1e-5) + + def test_parallel_arcs_max_score(self): + """Parallel arcs same src→dst: tropical forward takes maximum.""" + # Two arcs from 0→1: weights 3.0 and 5.0. Max wins. + fsa_str = '0 1 1 3.0\n0 1 2 5.0\n1 2 -1 1.0\n2' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + # State 1 score = max(3.0, 5.0) = 5.0 + assert abs(fwd_mps[1].item() - 5.0) < 1e-5 + + def test_parallel_arcs_log_semiring(self): + """Multiple arcs with same src→dst: log-semiring sums contributions.""" + fsa_str = '0 1 1 1.0\n0 1 2 2.0\n1 2 -1 0.0\n2' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=True) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=True) + + assert fwd_mps.device.type == 'mps' + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + +@mps_available +class TestMpsAssocScanBoundaries: + """Boundary conditions for the Priority-6 Hillis-Steele associative scan.""" + + def _chain_fsa(self, n_states, device='cpu'): + """Build a linear-chain FsaVec with n_states states.""" + arcs = [f'{i} {i + 1} {i + 1} {float(i + 1) * 0.1:.1f}' + for i in range(n_states - 2)] + arcs.append( + f'{n_states - 2} {n_states - 1} -1 ' + f'{float(n_states - 1) * 0.1:.1f}' + ) + arcs.append(str(n_states - 1)) + fsa = k2.Fsa.from_str('\n'.join(arcs).strip()) + return k2.create_fsa_vec([fsa]).to(device) + + def _parity(self, n_states): + """Assert MPS tropical 
forward scores match CPU for chain FSA.""" + fsa_cpu = self._chain_fsa(n_states) + fsa_mps = self._chain_fsa(n_states, 'mps') + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-4) + + def test_n_at_lower_bound(self): + """N=4: minimum threshold — uses assoc scan, not native sequential.""" + self._parity(4) + + def test_n_just_above_lower(self): + """N=5: T_pow2=8, exercises identity-padding in Hillis-Steele.""" + self._parity(5) + + def test_n_nonpower_of_two_small(self): + """N=7: T_pow2=8, two extra identity padding matrices.""" + self._parity(7) + + def test_n_at_power_of_two_mid(self): + """N=16: T_pow2=16, no padding needed.""" + self._parity(16) + + def test_n_at_upper_bound(self): + """N=128: maximum threshold — still uses assoc scan.""" + self._parity(128) + + def test_n_just_above_upper(self): + """N=129: above threshold — falls back to native sequential.""" + self._parity(129) + + def test_diamond_topology_assoc_scan(self): + """Diamond: two paths to same dest; assoc scan atomic-max is correct.""" + # 0 → 1 (w=1.0), 0 → 2 (w=2.0), 1 → 3 (w=1.5), 2 → 3 (w=0.5), 3 final + fsa_str = ''' + 0 1 1 1.0 + 0 2 2 2.0 + 1 3 3 1.5 + 2 3 3 0.5 + 3 4 -1 0.0 + 4 + ''' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + # Best path: 0→2→3 = 2.0+0.5=2.5; 0→1→3 = 1.0+1.5=2.5 (tie → max = 2.5) + assert abs(fwd_mps[3].item() - 2.5) < 1e-5 + + def test_multi_arc_single_dest_assoc_scan(self): + """Multiple arcs into same destination in assoc scan build_level.""" + # State 2: arcs from 0 (w=1.0), 1a (w=3.0), 1b (w=2.0). 
+ # The build_level kernel's atomic-max must keep 3.0. + fsa_str = ''' + 0 1 1 1.0 + 0 2 2 1.0 + 1 2 3 3.0 + 1 2 4 2.0 + 2 3 -1 0.0 + 3 + ''' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=False) + + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + def test_assoc_scan_log_semiring_falls_back(self): + """Log semiring falls back to native; result still matches CPU.""" + fsa_str = '0 1 1 0.5\n1 2 -1 0.5\n2' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + fwd_cpu = fsa_cpu._get_forward_scores( + use_double_scores=False, log_semiring=True) + fwd_mps = fsa_mps._get_forward_scores( + use_double_scores=False, log_semiring=True) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-5) + + +@mps_available +class TestMpsArcPost: + """Tests for arc posteriors (differentiable get_arc_post) on MPS.""" + + _FSA_STR = ''' + 0 1 1 1.0 + 0 1 2 2.0 + 1 2 -1 0.5 + 2 + ''' + + def _make(self, device='cpu'): + fsa = k2.Fsa.from_str(self._FSA_STR.strip()) + fsa_vec = k2.create_fsa_vec([fsa]).to(device) + fsa_vec.scores.requires_grad_(True) + return fsa_vec + + def test_arc_post_tropical_parity(self): + """Tropical arc posteriors on MPS match CPU.""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + post_cpu = fsa_cpu.get_arc_post( + use_double_scores=False, log_semiring=False) + post_mps = fsa_mps.get_arc_post( + use_double_scores=False, log_semiring=False) + + assert post_mps.device.type == 'mps' + assert torch.allclose(post_mps.cpu(), post_cpu, atol=1e-5) + + def test_arc_post_log_parity(self): + """Log-semiring arc posteriors on MPS match CPU.""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + post_cpu = fsa_cpu.get_arc_post( + use_double_scores=False, log_semiring=True) + post_mps = fsa_mps.get_arc_post( + 
use_double_scores=False, log_semiring=True) + + assert post_mps.device.type == 'mps' + assert torch.allclose(post_mps.cpu(), post_cpu, atol=1e-5) + + def test_arc_post_tropical_gradient(self): + """Gradients flow correctly through get_arc_post on MPS (tropical).""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + post_cpu = fsa_cpu.get_arc_post( + use_double_scores=False, log_semiring=False) + post_mps = fsa_mps.get_arc_post( + use_double_scores=False, log_semiring=False) + + post_cpu.sum().backward() + post_mps.sum().backward() + + assert fsa_mps.scores.grad is not None + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose( + fsa_mps.scores.grad.cpu(), fsa_cpu.scores.grad, atol=1e-5) + + def test_arc_post_log_gradient(self): + """Gradients flow correctly through get_arc_post on MPS (log).""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + post_cpu = fsa_cpu.get_arc_post( + use_double_scores=False, log_semiring=True) + post_mps = fsa_mps.get_arc_post( + use_double_scores=False, log_semiring=True) + + post_cpu.sum().backward() + post_mps.sum().backward() + + assert fsa_mps.scores.grad is not None + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose( + fsa_mps.scores.grad.cpu(), fsa_cpu.scores.grad, atol=1e-5) + + +@mps_available +class TestMpsGetForwardScoresDifferentiable: + """Differentiable get_forward_scores tests on MPS.""" + + _FSA_STR = ''' + 0 1 1 0.5 + 0 2 2 1.0 + 1 3 3 0.5 + 2 3 3 1.0 + 3 4 -1 0.0 + 4 + ''' + + def _make(self, device='cpu'): + fsa = k2.Fsa.from_str(self._FSA_STR.strip()) + fsa_vec = k2.create_fsa_vec([fsa]).to(device) + fsa_vec.scores.requires_grad_(True) + return fsa_vec + + def test_get_forward_scores_tropical_gradient(self): + """Differentiable tropical forward scores: grad on MPS matches CPU.""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + fwd_cpu = fsa_cpu.get_forward_scores( + use_double_scores=False, log_semiring=False) + fwd_mps = 
fsa_mps.get_forward_scores( + use_double_scores=False, log_semiring=False) + + fwd_cpu.sum().backward() + fwd_mps.sum().backward() + + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose( + fsa_mps.scores.grad.cpu(), fsa_cpu.scores.grad, atol=1e-5) + + def test_get_forward_scores_log_gradient(self): + """Differentiable log-semiring forward scores: MPS grad matches CPU.""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + fwd_cpu = fsa_cpu.get_forward_scores( + use_double_scores=False, log_semiring=True) + fwd_mps = fsa_mps.get_forward_scores( + use_double_scores=False, log_semiring=True) + + fwd_cpu.sum().backward() + fwd_mps.sum().backward() + + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose( + fsa_mps.scores.grad.cpu(), fsa_cpu.scores.grad, atol=1e-5) + + def test_get_forward_scores_nonunit_gradient(self): + """Non-unit upstream gradient correctly scales MPS grad.""" + fsa_cpu = self._make() + fsa_mps = self._make('mps') + + fwd_cpu = fsa_cpu.get_forward_scores( + use_double_scores=False, log_semiring=True) + fwd_mps = fsa_mps.get_forward_scores( + use_double_scores=False, log_semiring=True) + + upstream = torch.ones_like(fwd_cpu) * 2.0 + fwd_cpu.backward(upstream) + fwd_mps.backward(upstream.to('mps')) + + assert torch.allclose( + fsa_mps.scores.grad.cpu(), fsa_cpu.scores.grad, atol=1e-5) + + +@mps_available +class TestMpsIntersectDenseExtended: + """Extended IntersectDense/IntersectDensePruned tests (Priority 5).""" + + _FSA_STR = ''' + 0 1 1 0.0 + 1 2 2 0.0 + 2 3 -1 0.0 + 3 + ''' + + def _make_fsa(self, device='cpu'): + fsa = k2.Fsa.from_str(self._FSA_STR.strip()) + fsa_vec = k2.create_fsa_vec([fsa]).to(device) + fsa_vec.scores.requires_grad_(True) + return fsa_vec + + def _make_dense(self, n_utterances, device='cpu', seed=0): + """Build a DenseFsaVec with n_utterances independent segments. 
+ + Uses a fixed seed so that paired CPU/MPS calls with the same seed + produce identical inputs, enabling genuine parity checks. + """ + torch.manual_seed(seed) + T, V = 3, 3 + log_probs = torch.randn(n_utterances, T, V) + segs = torch.tensor( + [[i, 0, T] for i in range(n_utterances)], dtype=torch.int32) + dense = k2.DenseFsaVec(log_probs, segs) + if device != 'cpu': + dense = dense.to(device) + dense.scores.requires_grad_(True) + return dense + + def test_intersect_dense_pruned_2utterances(self): + """IntersectDensePruned with 2-utterance batch: scores match CPU.""" + fsa_cpu = self._make_fsa() + fsa_mps = self._make_fsa('mps') + dense_cpu = self._make_dense(2, seed=7) + dense_mps = self._make_dense(2, 'mps', seed=7) + + result_cpu = k2.intersect_dense_pruned( + fsa_cpu, dense_cpu, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000) + result_mps = k2.intersect_dense_pruned( + fsa_mps, dense_mps, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000) + + assert result_mps.scores.device.type == 'mps' + assert torch.allclose( + result_mps.scores.cpu(), result_cpu.scores, atol=1e-5) + + def test_intersect_dense_pruned_backward_2utterances(self): + """IntersectDensePruned 2-utterance backward: grads on MPS match CPU.""" + fsa_cpu = self._make_fsa() + fsa_mps = self._make_fsa('mps') + dense_cpu = self._make_dense(2, seed=7) + dense_mps = self._make_dense(2, 'mps', seed=7) + + result_cpu = k2.intersect_dense_pruned( + fsa_cpu, dense_cpu, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000) + result_mps = k2.intersect_dense_pruned( + fsa_mps, dense_mps, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000) + + result_cpu.scores.sum().backward() + result_mps.scores.sum().backward() + + assert fsa_mps.scores.grad.device.type == 'mps' + assert torch.allclose( + fsa_mps.scores.grad.cpu(), fsa_cpu.scores.grad, atol=1e-5) + assert torch.allclose( + 
dense_mps.scores.grad.cpu(), dense_cpu.scores.grad, atol=1e-5) + + def test_intersect_dense_with_seqframe_attribute(self): + """seqframe attribute is set correctly on MPS output.""" + fsa_mps = self._make_fsa('mps') + dense_mps = self._make_dense(1, 'mps', seed=3) + + result = k2.intersect_dense_pruned( + fsa_mps, dense_mps, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000, + seqframe_idx_name='seqframe_idx') + + assert hasattr(result, 'seqframe_idx') + + def test_intersect_dense_seqframe_parity(self): + """seqframe/frame attributes match between CPU and MPS paths.""" + fsa_cpu = self._make_fsa() + fsa_mps = self._make_fsa('mps') + dense_cpu = self._make_dense(1, seed=3) + dense_mps = self._make_dense(1, 'mps', seed=3) + + result_cpu = k2.intersect_dense_pruned( + fsa_cpu, dense_cpu, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000, + seqframe_idx_name='seqframe_idx', + frame_idx_name='frame_idx') + result_mps = k2.intersect_dense_pruned( + fsa_mps, dense_mps, + search_beam=20.0, output_beam=8.0, + min_active_states=30, max_active_states=10000, + seqframe_idx_name='seqframe_idx', + frame_idx_name='frame_idx') + + assert torch.equal( + result_mps.seqframe_idx.cpu(), result_cpu.seqframe_idx) + assert torch.equal(result_mps.frame_idx.cpu(), result_cpu.frame_idx) + + def test_intersect_dense_function_parity(self): + """IntersectDense (non-pruned) forward/backward match CPU.""" + fsa_cpu = self._make_fsa() + fsa_mps = self._make_fsa('mps') + dense_cpu = self._make_dense(1, seed=5) + dense_mps = self._make_dense(1, 'mps', seed=5) + + result_cpu = k2.intersect_dense(fsa_cpu, dense_cpu, output_beam=100.0) + result_mps = k2.intersect_dense(fsa_mps, dense_mps, output_beam=100.0) + + assert torch.allclose( + result_mps.scores.cpu(), result_cpu.scores, atol=1e-5) + + result_cpu.scores.sum().backward() + result_mps.scores.sum().backward() + + assert torch.allclose( + fsa_mps.scores.grad.cpu(), 
fsa_cpu.scores.grad, atol=1e-5) + assert torch.allclose( + dense_mps.scores.grad.cpu(), dense_cpu.scores.grad, atol=1e-5) + + +@mps_available +class TestMpsMutualInformationExtended: + """Extended mutual_information tests for MPS (Priority 1).""" + + def test_mutual_information_varied_sizes(self): + """mutual_information with several (S, T) sizes matches CPU.""" + for S, T in [(2, 3), (5, 10), (10, 5), (20, 30)]: + px = torch.randn(1, S, T).requires_grad_(True) + py = torch.randn(1, S + 1, T).requires_grad_(True) + px_mps = px.detach().to('mps').requires_grad_(True) + py_mps = py.detach().to('mps').requires_grad_(True) + + mi_cpu = k2.mutual_information_recursion(px, py) + mi_mps = k2.mutual_information_recursion(px_mps, py_mps) + + assert torch.allclose(mi_mps.cpu(), mi_cpu, atol=1e-4), ( + f"Failed at S={S}, T={T}: " + f"MPS={mi_mps.item():.4f} CPU={mi_cpu.item():.4f}" + ) + + def test_mutual_information_gradient_varied_sizes(self): + """mutual_information backward is correct for several (S, T) sizes.""" + # Sizes where S >= T trigger a pre-existing k2 CPU MI backward warning + # that leaves CPU gradients as zero (upstream issue unrelated to MPS). + # Only test shapes where S < T to get reliable CPU reference gradients. 
+ for S, T in [(3, 4), (5, 7), (4, 8)]: + px = torch.randn(1, S, T).requires_grad_(True) + py = torch.randn(1, S + 1, T).requires_grad_(True) + px_mps = px.detach().to('mps').requires_grad_(True) + py_mps = py.detach().to('mps').requires_grad_(True) + + k2.mutual_information_recursion(px, py).backward() + k2.mutual_information_recursion(px_mps, py_mps).backward() + + assert torch.allclose(px_mps.grad.cpu(), px.grad, atol=1e-4), \ + f"px.grad mismatch at S={S}, T={T}" + assert torch.allclose(py_mps.grad.cpu(), py.grad, atol=1e-4), \ + f"py.grad mismatch at S={S}, T={T}" + + def test_mutual_information_batch(self): + """Batch of B sequences: results match CPU for each item.""" + B, S, T = 4, 5, 8 + px = torch.randn(B, S, T).requires_grad_(True) + py = torch.randn(B, S + 1, T).requires_grad_(True) + px_mps = px.detach().to('mps').requires_grad_(True) + py_mps = py.detach().to('mps').requires_grad_(True) + + mi_cpu = k2.mutual_information_recursion(px, py) + mi_mps = k2.mutual_information_recursion(px_mps, py_mps) + + assert torch.allclose(mi_mps.cpu(), mi_cpu, atol=1e-4) + + mi_cpu.sum().backward() + mi_mps.sum().backward() + assert torch.allclose(px_mps.grad.cpu(), px.grad, atol=1e-4) + assert torch.allclose(py_mps.grad.cpu(), py.grad, atol=1e-4) + + def test_mutual_information_with_boundary(self): + """mutual_information with explicit boundary tensor matches CPU.""" + S, T = 6, 8 + px = torch.randn(1, S, T) + py = torch.randn(1, S + 1, T) + boundary = torch.tensor([[0, 0, S, T]], dtype=torch.int64) + + mi_cpu = k2.mutual_information_recursion(px, py, boundary=boundary) + mi_mps = k2.mutual_information_recursion( + px.to('mps'), py.to('mps'), boundary=boundary.to('mps')) + + assert torch.allclose(mi_mps.cpu(), mi_cpu, atol=1e-4) + + +@mps_available +class TestMpsForwardScoresNumericalStress: + """Numerical stress tests: larger FSAs, various topologies, precision.""" + + def _random_dag_fsa(self, n_states, n_arcs, seed=42, device='cpu'): + """Generate a connected 
random DAG FSA on device. + + A linear backbone guarantees every state is reachable and + co-reachable; extra skip arcs add branching for stress testing. + """ + torch.manual_seed(seed) + arcs = [] + # Backbone: guaranteed linear chain 0→1→…→(n-2)→(n-1, final arc). + for s in range(n_states - 2): + w = torch.randn(1).item() + arcs.append(f'{s} {s + 1} 1 {w:.4f}') + arcs.append(f'{n_states - 2} {n_states - 1} -1 0.0') + # Extra skip arcs between non-final states only (keeps validity). + extra = max(0, n_arcs - (n_states - 1)) + for _ in range(extra): + src = int(torch.randint(0, n_states - 2, (1,)).item()) + dst = int(torch.randint(src + 1, n_states - 1, (1,)).item()) + w = torch.randn(1).item() + arcs.append(f'{src} {dst} 1 {w:.4f}') + # Deduplicate and sort. + arc_lines = list(dict.fromkeys(a for a in arcs)) + arc_lines.sort(key=lambda a: (int(a.split()[0]), int(a.split()[1]))) + fsa_str = '\n'.join(arc_lines + [str(n_states - 1)]) + fsa = k2.Fsa.from_str(fsa_str) + return k2.create_fsa_vec([fsa]).to(device) + + def test_medium_fsa_tropical_parity(self): + """50-state random DAG: MPS tropical forward scores match CPU.""" + fsa_cpu = self._random_dag_fsa(50, 100) + fsa_mps = self._random_dag_fsa(50, 100, device='mps') + fwd_cpu = fsa_cpu._get_forward_scores(False, False) + fwd_mps = fsa_mps._get_forward_scores(False, False) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-4) + + def test_medium_fsa_log_parity(self): + """50-state random DAG: MPS log-semiring forward scores match CPU.""" + fsa_cpu = self._random_dag_fsa(50, 100) + fsa_mps = self._random_dag_fsa(50, 100, device='mps') + fwd_cpu = fsa_cpu._get_forward_scores(False, True) + fwd_mps = fsa_mps._get_forward_scores(False, True) + assert torch.allclose(fwd_mps.cpu(), fwd_cpu, atol=1e-4) + + def test_gradient_precision_with_repeated_scores(self): + """Repeated arc scores should not cause NaN in MPS gradients.""" + # All arcs have the same score → softmax-like gradient should sum to 1. 
+ fsa_str = '0 1 1 0.0\n0 1 2 0.0\n1 2 -1 0.0\n2' + fsa_cpu = _make_fsa_vec(fsa_str) + fsa_mps = _make_fsa_vec(fsa_str, 'mps') + fsa_cpu.scores.requires_grad_(True) + fsa_mps.scores.requires_grad_(True) + + tot_cpu = fsa_cpu.get_tot_scores( + use_double_scores=False, log_semiring=True) + tot_mps = fsa_mps.get_tot_scores( + use_double_scores=False, log_semiring=True) + + tot_cpu.backward() + tot_mps.backward() + + grad_mps = fsa_mps.scores.grad.cpu() + assert not torch.any(torch.isnan(grad_mps)), "NaN in MPS gradient" + assert torch.allclose(grad_mps, fsa_cpu.scores.grad, atol=1e-5)