Skip to content

Commit 772781e

Browse files
Ahead-of-time build and PyTorch Stable ABI (#184)
* Making progress. * More progress. * Saving temporarily. * Fixed JIT issue. * Managed to build stable extension. * Made some further changes. * Fixed the dynamic versioning bugs. * Back to a working state. * Ready to begin the import testing process. * Temporary save. * Fixed some more details about the C++ backend. * Even more things working. * Ready to test on HIP. * Minor comment fix. * Working on AOTI update. * AOTI loading works. * Added careful conditions about when to use a precompiled extension. * Added detailed warning messages. * Updated CI. * Tried defining symbol. * Avoided symbol name mangling. * Ruff. * Removed accelerator.h from original library. * Updated warning message. * Updated upgrade message. * Updated PyTorch version in CI. * Updated CI script. * Updated JAX dependency list.
1 parent 5f33074 commit 772781e

24 files changed

+1297
-786
lines changed

.github/workflows/requirements_cuda_ci.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
numpy==2.2.5
2-
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
2+
torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128
33
pytest==8.3.5
44
ninja==1.11.1.4
55
nanobind==2.10.2

.github/workflows/verify_extension_build.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ jobs:
2929
sudo apt-get update
3030
sudo apt install nvidia-cuda-toolkit
3131
pip install -r .github/workflows/requirements_cuda_ci.txt
32-
pip install -e "./openequivariance"
32+
pip install -e "./openequivariance[jax]"
3333
3434
- name: Test CUDA extension build via import
3535
run: |
36-
pytest \
37-
tests/import_test.py::test_extension_built \
38-
tests/import_test.py::test_torch_extension_built
36+
pytest tests/import_test.py
37+
38+
export OEQ_JIT_EXTENSION=1
39+
40+
pytest tests/import_test.py
3941
4042
- name: Test JAX extension build
4143
run: |

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ __pycache__
99
# working folders
1010
dist
1111
build
12+
cbuild
1213
outputs/*
1314
visualization/*
1415
figures/*
@@ -40,4 +41,5 @@ paper_benchmarks_v2
4041
paper_benchmarks_v3
4142

4243
get_node.sh
43-
*.egg-info
44+
*.egg-info
45+
_version.py

openequivariance/CMakeLists.txt

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
cmake_minimum_required(VERSION 3.15...3.30)
2+
project(openequivariance_stable_ext)
3+
4+
find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
5+
6+
# Download LibTorch
7+
include(FetchContent)
8+
9+
FetchContent_Declare(
10+
libtorch
11+
URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.10.0%2Bcpu.zip"
12+
)
13+
14+
message(STATUS "Downloading LibTorch...")
15+
FetchContent_MakeAvailable(libtorch)
16+
17+
set(LIBTORCH_INCLUDE_DIR "${libtorch_SOURCE_DIR}/include")
18+
set(LIBTORCH_LIB_DIR "${libtorch_SOURCE_DIR}/lib")
19+
find_library(TORCH_CPU_LIB NAMES torch_cpu PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH)
20+
find_library(C10_LIB NAMES c10 PATHS "${LIBTORCH_LIB_DIR}" NO_DEFAULT_PATH)
21+
22+
message(STATUS "LibTorch Include: ${LIBTORCH_INCLUDE_DIR}")
23+
message(STATUS "LibTorch Lib: ${LIBTORCH_LIB_DIR}")
24+
25+
message(STATUS "Torch CPU Library: ${TORCH_CPU_LIB}")
26+
message(STATUS "Torch C10 Library: ${C10_LIB}")
27+
28+
# Setup Nanobind
29+
execute_process(
30+
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
31+
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT
32+
)
33+
message(STATUS "nanobind cmake directory: ${nanobind_ROOT}")
34+
35+
find_package(nanobind CONFIG REQUIRED)
36+
37+
set(EXT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/extension")
38+
set(EXT_BACKEND_DIR "${EXT_DIR}/backend")
39+
set(EXT_JSON_DIR "${EXT_DIR}/json11")
40+
41+
# Source files
42+
set(OEQ_SOURCES
43+
${EXT_DIR}/libtorch_tp_jit_stable.cpp
44+
${EXT_JSON_DIR}/json11.cpp
45+
)
46+
47+
set(OEQ_INSTALL_DIR "${CMAKE_CURRENT_SOURCE_DIR}/openequivariance/_torch/extlib")
48+
49+
function(add_stable_extension target_name backend_define link_libraries)
50+
# Create nanobind extension
51+
nanobind_add_module(${target_name} NB_STATIC ${OEQ_SOURCES})
52+
53+
set_target_properties(${target_name} PROPERTIES
54+
CXX_STANDARD 17
55+
CXX_STANDARD_REQUIRED ON
56+
POSITION_INDEPENDENT_CODE ON
57+
)
58+
59+
# Enforce CXX11 ABI to match LibTorch
60+
target_compile_definitions(${target_name} PRIVATE
61+
${backend_define}=1
62+
_GLIBCXX_USE_CXX11_ABI=1
63+
INCLUDE_NB_EXTENSION
64+
)
65+
66+
target_include_directories(${target_name} PRIVATE
67+
${EXT_DIR}
68+
${EXT_BACKEND_DIR}
69+
${EXT_JSON_DIR}
70+
${LIBTORCH_INCLUDE_DIR}
71+
)
72+
target_link_libraries(${target_name} PRIVATE
73+
${TORCH_CPU_LIB}
74+
${C10_LIB}
75+
${link_libraries}
76+
)
77+
78+
install(TARGETS ${target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}")
79+
80+
# AOTI C++ library (identical except without nanobind and without INCLUDE_NB_EXTENSION)
81+
set(aoti_target_name ${target_name}_aoti)
82+
add_library(${aoti_target_name} SHARED ${OEQ_SOURCES})
83+
84+
set_target_properties(${aoti_target_name} PROPERTIES
85+
CXX_STANDARD 17
86+
CXX_STANDARD_REQUIRED ON
87+
POSITION_INDEPENDENT_CODE ON
88+
)
89+
90+
target_compile_definitions(${aoti_target_name} PRIVATE
91+
${backend_define}=1
92+
_GLIBCXX_USE_CXX11_ABI=1
93+
)
94+
95+
target_include_directories(${aoti_target_name} PRIVATE
96+
${EXT_DIR}
97+
${EXT_BACKEND_DIR}
98+
${EXT_JSON_DIR}
99+
${LIBTORCH_INCLUDE_DIR}
100+
)
101+
target_link_libraries(${aoti_target_name} PRIVATE
102+
${TORCH_CPU_LIB}
103+
${C10_LIB}
104+
${link_libraries}
105+
)
106+
107+
install(TARGETS ${aoti_target_name} LIBRARY DESTINATION "${OEQ_INSTALL_DIR}")
108+
endfunction()
109+
110+
find_package(CUDAToolkit QUIET)
111+
find_package(hip QUIET)
112+
113+
if(CUDAToolkit_FOUND)
114+
message(STATUS "Building stable extension with CUDA backend.")
115+
116+
add_library(cuda_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp)
117+
118+
target_include_directories(cuda_stub_lib PRIVATE
119+
${LIBTORCH_INCLUDE_DIR}
120+
)
121+
122+
set_target_properties(cuda_stub_lib PROPERTIES
123+
OUTPUT_NAME "torch_cuda"
124+
POSITION_INDEPENDENT_CODE ON
125+
CXX_STANDARD 17
126+
)
127+
128+
set(CUDA_LINK_LIBS
129+
CUDA::cudart
130+
CUDA::cuda_driver
131+
CUDA::nvrtc
132+
cuda_stub_lib
133+
)
134+
add_stable_extension(oeq_stable_cuda CUDA_BACKEND "${CUDA_LINK_LIBS}")
135+
endif()
136+
137+
if(hip_FOUND)
138+
message(STATUS "Building stable extension with HIP backend.")
139+
140+
add_library(hip_stub_lib SHARED ${EXT_DIR}/stubs/stream.cpp)
141+
142+
target_include_directories(hip_stub_lib PRIVATE
143+
${LIBTORCH_INCLUDE_DIR}
144+
)
145+
146+
set_target_properties(hip_stub_lib PROPERTIES
147+
OUTPUT_NAME "torch_hip"
148+
POSITION_INDEPENDENT_CODE ON
149+
CXX_STANDARD 17
150+
)
151+
152+
set(HIP_LINK_LIBS
153+
hiprtc
154+
hip_stub_lib
155+
)
156+
add_stable_extension(torch_stable_hip HIP_BACKEND "${HIP_LINK_LIBS}")
157+
endif()
158+
159+
if(NOT CUDAToolkit_FOUND AND NOT hip_FOUND)
160+
message(WARNING "Neither CUDAToolkit nor HIP was found. The stable extension will not be built.")
161+
endif()

openequivariance/openequivariance/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ def _check_package_editable():
6161
LINKED_LIBPYTHON_ERROR,
6262
BUILT_EXTENSION,
6363
BUILT_EXTENSION_ERROR,
64-
TORCH_COMPILE,
65-
TORCH_COMPILE_ERROR,
64+
USE_PRECOMPILED_EXTENSION,
6665
)
6766

6867

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def register_autocast():
228228
)
229229

230230

231-
register_torch_fakes()
232-
register_autograd()
233-
register_autocast()
231+
if extlib.BUILT_EXTENSION:
232+
register_torch_fakes()
233+
register_autograd()
234+
register_autocast()

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from openequivariance._torch.extlib import (
77
postprocess_kernel,
88
DeviceProp,
9+
BUILT_EXTENSION,
910
)
1011

1112
from openequivariance.core.ConvolutionBase import (
@@ -403,9 +404,10 @@ def register_autocast():
403404
)
404405

405406

406-
register_torch_fakes()
407-
register_autograd()
408-
register_autocast()
407+
if BUILT_EXTENSION:
408+
register_torch_fakes()
409+
register_autograd()
410+
register_autocast()
409411

410412

411413
# ==================================================================

0 commit comments

Comments
 (0)