Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions .github/workflows/build-mps-macos.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2024 k2 contributors

# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: build-mps-macos

# Triggers:
#   * pushes to master that touch one of the build inputs listed under `paths`
#   * pull requests, at the moment a label is added (the job-level `if` below
#     only lets the 'ready' label through)
#   * manual runs through workflow_dispatch
on:
  push:
    branches:
      - master
    paths:
      - '.github/workflows/build-mps-macos.yml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'k2/csrc/**'
      - 'k2/python/**'
  pull_request:
    types: [labeled]
    paths:
      - '.github/workflows/build-mps-macos.yml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'k2/csrc/**'
      - 'k2/python/**'
  workflow_dispatch:

# Runs are grouped per ref; starting a new run cancels the one in progress
# for the same ref.
concurrency:
  group: build-mps-macos-${{ github.ref }}
  cancel-in-progress: true

env:
  BUILD_TYPE: Release

jobs:
  build-mps-macos:
    # For pull_request events only the 'ready' label starts the job; push and
    # manual dispatch events always run it.
    if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'workflow_dispatch'
    name: Build (macOS, CPU+MPS)
    runs-on: macos-14  # Apple Silicon M1/M2 — MPS is available

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Display clang version
        run: clang --version

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Display Python version
        run: python -c "import sys; print(sys.version)"

      - name: Install PyTorch (macOS arm64 with MPS)
        shell: bash
        run: |
          python3 -m pip install -qq --upgrade pip
          python3 -m pip install -qq wheel twine
          python3 -m pip install -qq torch --index-url https://download.pytorch.org/whl/cpu
          python3 -c "import torch; print('torch version:', torch.__version__); print('MPS available:', torch.backends.mps.is_available())"

      - name: Build wheel
        shell: bash
        run: |
          export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF -DK2_WITH_MPS=ON"
          export K2_MAKE_ARGS="-j3"
          python3 setup.py bdist_wheel
          ls -lh dist/
          ls -lh build/*

      - name: Upload wheel
        uses: actions/upload-artifact@v4
        with:
          name: k2-macos14-mps-python3.11
          path: dist/*.whl

      # Install the wheel that was just built and fail the job if the
      # installed package does not report MPS support.
      - name: Install wheel and verify MPS support
        shell: bash
        run: |
          # Use "python3 -m pip" like the earlier steps so the wheel is
          # installed into the same interpreter that runs the checks below.
          python3 -m pip install dist/*.whl
          python3 -m k2.version
          python3 -c "
          import k2
          assert k2.with_mps, 'k2 was not built with MPS support'
          print('k2.with_cuda:', k2.with_cuda)
          print('k2.with_mps:', k2.with_mps)
          print('MPS build verified.')
          "

      - name: Run MPS tests
        shell: bash
        env:
          PYTORCH_ENABLE_MPS_FALLBACK: "1"
        run: |
          python3 -m pip install pytest
          python3 -m pytest k2/python/tests/test_mps.py -v --tb=short

      - name: Run CPU tests (sanity check)
        shell: bash
        run: |
          # "-v" and "-q" were both passed before; they offset each other, so
          # both are dropped and the default verbosity is used. -x stops at
          # the first failure.
          python3 -m pytest k2/python/tests/ --ignore=k2/python/tests/test_mps.py -x --tb=short

      # Separate out-of-tree CMake build, used by the ctest step below.
      - name: Build k2 (cmake + C++ tests)
        shell: bash
        run: |
          mkdir -p build_cmake
          cd build_cmake
          cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF -DK2_WITH_MPS=ON ..
          cat k2/csrc/version.h
          make VERBOSE=1 -j3

      - name: Run C++ tests (ctest)
        shell: bash
        run: |
          cd build_cmake
          ctest --output-on-failure
8 changes: 4 additions & 4 deletions .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,23 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: ["3.10"]
fail-fast: false

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 2

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip typing_extensions
python3 -m pip install --upgrade flake8==3.8.3
python3 -m pip install --upgrade flake8

- name: Run flake8
shell: bash
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,4 @@ dkms.conf
!.github/**
!k2/torch/bin
*-bak
.claude/
38 changes: 38 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,34 @@ if(APPLE OR (DEFINED K2_WITH_CUDA AND NOT K2_WITH_CUDA))
endif()
endif()

# Mirror PYTHON_EXECUTABLE into Python3_EXECUTABLE so that every
# find_package(Python3) call resolves the same interpreter (e.g. the active
# venv) rather than the system Python located by CMake's automatic search.
#
# The guard only lets this run when no Python3_EXECUTABLE variable (normal or
# cache) exists yet, so a plain cache set is enough: there is never an
# existing entry to override and FORCE is not needed.
if(DEFINED PYTHON_EXECUTABLE AND NOT DEFINED Python3_EXECUTABLE)
  set(Python3_EXECUTABLE "${PYTHON_EXECUTABLE}" CACHE FILEPATH
      "Path to the Python 3 interpreter")
endif()

# Detect MPS (Metal Performance Shaders) availability on macOS.
#
# _K2_WITH_MPS is the default value later handed to the K2_WITH_MPS option.
# It is switched ON only when the Python interpreter found by CMake can run
#   import torch; print(torch.backends.mps.is_available())
# and that prints exactly "True". In every other case it stays OFF.
set(_K2_WITH_MPS OFF)
if(APPLE)
  find_package(Python3 QUIET COMPONENTS Interpreter)
  if(Python3_FOUND)
    execute_process(
      COMMAND "${Python3_EXECUTABLE}" -c
        "import torch; print(torch.backends.mps.is_available())"
      RESULT_VARIABLE _K2_MPS_PROBE_STATUS
      OUTPUT_VARIABLE _K2_MPS_AVAILABLE
      OUTPUT_STRIP_TRAILING_WHITESPACE
      ERROR_QUIET)
    if("${_K2_MPS_AVAILABLE}" STREQUAL "True")
      message(STATUS "MPS is available -- enabling K2_WITH_MPS")
      set(_K2_WITH_MPS ON)
    elseif(NOT "${_K2_MPS_PROBE_STATUS}" STREQUAL "0")
      # The probe itself failed (for instance torch could not be imported).
      # Report that separately so it is not mistaken for "this machine has
      # no MPS".
      message(STATUS
        "Could not query PyTorch for MPS support using "
        "${Python3_EXECUTABLE} (exit status: ${_K2_MPS_PROBE_STATUS}) "
        "-- K2_WITH_MPS defaults to OFF")
    else()
      message(STATUS "MPS is NOT available on this machine")
    endif()
  else()
    message(STATUS
      "Python3 interpreter not found -- skipping MPS detection, "
      "K2_WITH_MPS defaults to OFF")
  endif()
endif()

if(_K2_WITH_CUDA)
set(languages ${languages} CUDA)
endif()
Expand Down Expand Up @@ -81,6 +109,7 @@ option(BUILD_SHARED_LIBS "Whether to build shared or static lib" ON)
option(K2_USE_PYTORCH "Whether to build with PyTorch" ON)
option(K2_ENABLE_BENCHMARK "Whether to enable benchmark" ON)
option(K2_WITH_CUDA "Whether to build k2 with CUDA" ${_K2_WITH_CUDA})
option(K2_WITH_MPS "Whether to build k2 with MPS (Apple Metal)" ${_K2_WITH_MPS})
option(K2_ENABLE_NVTX "Whether to build k2 with the NVTX library" ON)
option(K2_ENABLE_TESTS "Whether to build tests" ON)

Expand Down Expand Up @@ -334,6 +363,11 @@ endif()

list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)

# Ensure the Python3 development artifacts are found before pybind11 is
# included, so that python3_add_library is available (required by pybind11
# NewTools with CMake 4+).
#
# When the running CMake supports it (3.18 or newer), only Development.Module
# is requested. The plain "Development" component additionally requires
# Development.Embed, i.e. a linkable libpython, which some Python
# installations do not ship; with REQUIRED that would abort configuration even
# though building an extension module does not need it.
if(NOT Python3_Development_FOUND AND NOT Python3_Development.Module_FOUND)
  if(CMAKE_VERSION VERSION_LESS 3.18)
    find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
  else()
    find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
  endif()
endif()

include(pybind11)
if(K2_USE_PYTORCH)
Expand All @@ -350,6 +384,10 @@ if(K2_WITH_CUDA)
add_definitions(-DK2_WITH_CUDA)
endif()

# When the K2_WITH_MPS option is enabled, pass -DK2_WITH_MPS to the compiler
# so the sources can test for it with the preprocessor. This mirrors the
# K2_WITH_CUDA definition directly above.
if(K2_WITH_MPS)
add_definitions(-DK2_WITH_MPS)
endif()

if(WIN32)
add_definitions(-DNOMINMAX) # Otherwise, std::max() and std::min() won't work
endif()
Expand Down
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,29 @@ LF-MMI training. This won't give a direct advantage in terms of Word Error Rate
compared with existing technology; but the point is to do this in a much more
general and extensible framework to allow further development of ASR technology.

## Apple Silicon (MPS)

k2 supports Apple Silicon (M-series) via PyTorch's Metal Performance Shaders
(MPS) backend. To build with MPS enabled:

```bash
git clone https://github.com/k2-fsa/k2.git
cd k2
export K2_CMAKE_ARGS="-DK2_WITH_MPS=ON"
python3 setup.py install
```

You can verify MPS support is active with:

```python
import k2
print(k2.with_mps) # True
```

Hot paths (mutual information recursion, forward scores, associative scan) use
native Metal kernels. Topology-dependent operations run on the CPU; their
results are returned as MPS-resident tensors that stay connected to the
autograd graph, so gradients still flow through them.

## Implementation

A few key points on our implementation strategy.
Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ k2
k2 is able to seamlessly integrate Finite State
Automaton (FSA) and Finite State Transducer (FST) algorithms into
autograd-based machine learning toolkits like PyTorch [#f1]_.
k2 supports CPU as well as CUDA. It can process a batch of FSTs
at the same time.
k2 supports CPU, CUDA, and Apple Silicon via Metal Performance Shaders (MPS).
It can process a batch of FSTs at the same time.

.. [#f1] Support for TensorFlow will be added in the future.

Expand Down
33 changes: 32 additions & 1 deletion docs/source/installation/from_source.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Install from source

.. hint::

It supports Linux (CPU + CUDA), macOS (CPU), and Windows (CPU + CUDA).
It supports Linux (CPU + CUDA), macOS (CPU + MPS), and Windows (CPU + CUDA).

.. hint::

Expand Down Expand Up @@ -51,6 +51,37 @@ After setting up the environment, we are ready to build k2:

That is all you need to run.

.. _install k2 with mps:

Building with Apple Silicon (MPS) support
------------------------------------------

On macOS with an M-series chip, you can enable the Metal Performance Shaders
backend by passing ``-DK2_WITH_MPS=ON`` to CMake:

.. code-block:: bash

git clone https://github.com/k2-fsa/k2.git
cd k2
export K2_CMAKE_ARGS="-DK2_WITH_MPS=ON"
export K2_MAKE_ARGS="-j6"
python3 setup.py install

To verify that MPS support is active:

.. code-block:: python

import k2
print(k2.with_mps) # True on Apple Silicon with MPS build

.. hint::

PyTorch >= 2.2 is required for MPS support. Float64 operations and a small
number of non-differentiable internal functions are not available on MPS;
the differentiable public API (``get_forward_scores``, ``get_tot_scores``,
``get_arc_post``, ``intersect_dense``, ``mutual_information_recursion``)
works fully on MPS.

.. hint::

We use ``export K2_MAKE_ARGS="-j6"`` to pass ``-j6`` to ``make``
Expand Down
8 changes: 8 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ if(K2_USE_PYTORCH)
if(DEFINED ENV{CONDA_PREFIX} AND APPLE)
target_link_libraries(context PUBLIC "-L $ENV{CONDA_PREFIX}/lib")
endif()

if(APPLE)
# On macOS with venv-style PyTorch installs, libtorch_python.dylib
# (which provides autograd metadata symbols) is loaded at Python runtime,
# not at link time. Allow undefined symbols so intermediate shared
# libraries (libk2context.dylib etc.) link successfully.
#
# The flag is PUBLIC, so it is also applied when linking targets that link
# against `context`. Trade-off: with dynamic_lookup a genuinely missing
# symbol is no longer reported by the linker and only shows up when the
# library is loaded.
target_link_libraries(context PUBLIC "-undefined dynamic_lookup")
endif()
endif()

target_include_directories(context PUBLIC ${PYTHON_INCLUDE_DIRS})
Expand Down
4 changes: 3 additions & 1 deletion k2/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@ ToType(int64_t, Long)
K2_CHECK_LT(i, Dim());
const T *data = Data() + i;
DeviceType type = Context()->GetDeviceType();
if (type == kCpu) {
if (type == kCpu || type == kMps) {
// MPS uses MTLStorageModeShared so data_ptr() is CPU-dereferenceable,
// but callers must ensure pending Metal writes are flushed (synchronize).
return *data;
} else {
K2_CHECK_EQ(type, kCuda);
Expand Down
Loading
Loading