Skip to content

Commit cdc2c30

Browse files
committed
cleaning up torch and mori deps
1 parent e3cab4b commit cdc2c30

File tree

13 files changed

+150
-48
lines changed

13 files changed

+150
-48
lines changed

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
2222
option(USE_ROCM "Whether to use rocm" ON)
2323
option(USE_BNXT "Whether to use BNXT NIC" OFF)
2424
option(USE_IONIC "Whether to use IONIC" OFF)
25+
# When ON, -mcode-object-version=5 is added to the mori_shmem compile and link options.
option(FORCE_CODE_OBJECT_VERSION_5 "Force code object version 5 for compatibility with Triton" ON)
2526
option(BUILD_EXAMPLES "Whether to build examples" ON)
2627
option(BUILD_APPLICATION "Whether to build application library" ON)
2728
option(BUILD_SHMEM "Whether to build shmem library" ON)
2829
option(BUILD_OPS "Whether to build mori operation kernels" ON)
2930
option(BUILD_IO "Whether to build mori io library" ON)
3031
option(BUILD_PYBINDS "Whether to build mori python bindings" ON)
3132
option(BUILD_TESTS "Whether to build mori CPP tests" ON)
33+
option(ENABLE_TORCH "Whether to enable Torch bootstrap and python bindings" ON)
34+
option(ENABLE_MPI "Whether to enable MPI bootstrap" ON)
3235
option(ENABLE_PROFILER "Enable kernel profiling" OFF)
3336
option(ENABLE_DEBUG_PRINTF "Enable debug printf in device kernels" OFF)
3437
option(ENABLE_STANDARD_MOE_ADAPT "Enable standard moe adapt" OFF)
@@ -67,6 +70,16 @@ if(USE_IONIC)
6770
add_compile_definitions(ENABLE_IONIC)
6871
endif()
6972

73+
message(STATUS "ENABLE_TORCH = ${ENABLE_TORCH}")
74+
if(ENABLE_TORCH)
75+
add_compile_definitions(ENABLE_TORCH)
76+
endif()
77+
78+
message(STATUS "ENABLE_MPI = ${ENABLE_MPI}")
79+
if(ENABLE_MPI)
80+
add_compile_definitions(ENABLE_MPI)
81+
endif()
82+
7083
message(STATUS "ENABLE_PROFILER = ${ENABLE_PROFILER}")
7184
if(ENABLE_PROFILER)
7285
add_compile_definitions(ENABLE_PROFILER)
@@ -88,6 +101,9 @@ endif()
88101
message(STATUS "WARP_ACCUM_UNROLL is set to: ${WARP_ACCUM_UNROLL}")
89102
add_definitions(-DWARP_ACCUM_UNROLL=${WARP_ACCUM_UNROLL})
90103
add_definitions(-DHIP_ENABLE_WARP_SYNC_BUILTINS)
104+
105+
# Append -Wno-deprecated-literal-operator to the global C++ flags.
string(APPEND CMAKE_CXX_FLAGS " -Wno-deprecated-literal-operator")
106+
91107
if(USE_ROCM)
92108
list(APPEND CMAKE_PREFIX_PATH "/opt/rocm")
93109
project(mori LANGUAGES HIP CXX C)

include/mori/application/bootstrap/bootstrap.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
#pragma once
2323

2424
#include "mori/application/bootstrap/base_bootstrap.hpp"
25+
#ifdef ENABLE_MPI
2526
#include "mori/application/bootstrap/mpi_bootstrap.hpp"
27+
#endif
28+
#ifdef ENABLE_TORCH
2629
#include "mori/application/bootstrap/torch_bootstrap.hpp"
30+
#endif
2731
#include "mori/application/bootstrap/socket_bootstrap.hpp"

include/mori/shmem/shmem_api.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
// SOFTWARE.
2222
#pragma once
2323

24+
#ifdef ENABLE_MPI
2425
#include <mpi.h>
26+
#endif
2527

2628
#include <array>
2729
#include <cstddef>
@@ -51,10 +53,13 @@ constexpr unsigned int MORI_SHMEM_INIT_WITH_UNIQUEID = 1;
5153

5254
// TODO: provide unified initialize / finalize APIs
5355
int ShmemInit(application::BootstrapNetwork* bootNet);
56+
#ifdef ENABLE_MPI
5457
int ShmemInit(); // Default initialization using MPI_COMM_WORLD
5558
int ShmemMpiInit(MPI_Comm);
59+
#endif // ENABLE_MPI
60+
#ifdef ENABLE_TORCH
5661
int ShmemTorchProcessGroupInit(const std::string& groupName);
57-
62+
#endif // ENABLE_TORCH
5863
// UniqueId-based initialization APIs (nvshmem/rocshmem compatible)
5964
int ShmemGetUniqueId(mori_shmem_uniqueid_t* uid);
6065
int ShmemSetAttrUniqueIdArgs(int rank, int nranks, mori_shmem_uniqueid_t* uid,

include/mori/shmem/shmem_device_api.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#pragma once
2323

2424
#include <assert.h>
25-
#include <mpi.h>
2625

2726
#include "mori/application/application.hpp"
2827
#include "mori/core/core.hpp"

include/mori/shmem/shmem_ibgda_kernels.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#pragma once
2323

2424
#include <assert.h>
25-
#include <mpi.h>
2625

2726
#include "mori/application/application.hpp"
2827
#include "mori/core/core.hpp"

include/mori/shmem/shmem_p2p_kernels.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
#pragma once
2323

2424
#include <assert.h>
25-
#include <mpi.h>
26-
2725
#include <type_traits>
2826

2927
#include "mori/application/application.hpp"

include/mori/shmem/shmem_sdma_kernels.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
#pragma once
2323

2424
#include <assert.h>
25-
#include <mpi.h>
26-
2725
#include <type_traits>
2826

2927
#include "mori/application/application.hpp"

src/application/CMakeLists.txt

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
1-
find_package(MPI REQUIRED)
1+
if(ENABLE_MPI)
2+
find_package(MPI REQUIRED)
3+
endif()
24
find_package(hsa-runtime64 REQUIRED)
35
find_package(hsakmt REQUIRED)
46
#find_library(IONIC_LIBRARY
57
# NAMES ionic
68
# HINTS /lib/x86_64-linux-gnu
79
# REQUIRED
810
#)
9-
execute_process(
10-
COMMAND python -c "import torch; print(torch.utils.cmake_prefix_path)"
11-
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
12-
OUTPUT_VARIABLE TORCH_DIR
13-
OUTPUT_STRIP_TRAILING_WHITESPACE)
14-
cmake_path(SET TORCH_CMAKE_DIR NORMALIZE "${TORCH_DIR}/Torch")
15-
list(APPEND CMAKE_PREFIX_PATH ${TORCH_CMAKE_DIR})
16-
message(STATUS "Found LibTorch CMake Path: ${CMAKE_PREFIX_PATH}")
1711

18-
find_package(Torch REQUIRED)
12+
if(ENABLE_TORCH)
13+
execute_process(
14+
COMMAND python -c "import torch; print(torch.utils.cmake_prefix_path)"
15+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
16+
OUTPUT_VARIABLE TORCH_DIR
17+
OUTPUT_STRIP_TRAILING_WHITESPACE)
18+
cmake_path(SET TORCH_CMAKE_DIR NORMALIZE "${TORCH_DIR}/Torch")
19+
list(APPEND CMAKE_PREFIX_PATH ${TORCH_CMAKE_DIR})
20+
message(STATUS "Found LibTorch CMake Path: ${CMAKE_PREFIX_PATH}")
21+
find_package(Torch REQUIRED)
22+
endif()
1923

2024
add_library(
2125
mori_application SHARED
22-
bootstrap/mpi_bootstrap.cpp
23-
bootstrap/torch_bootstrap.cpp
2426
bootstrap/socket_bootstrap.cpp
2527
bootstrap/local_bootstrap.cpp
2628
transport/rdma/rdma.cpp
@@ -40,11 +42,18 @@ add_library(
4042
topology/pci.cpp
4143
topology/system.cpp)
4244

45+
if(ENABLE_TORCH)
46+
target_sources(mori_application PRIVATE bootstrap/torch_bootstrap.cpp)
47+
endif()
48+
49+
if(ENABLE_MPI)
50+
target_sources(mori_application PRIVATE bootstrap/mpi_bootstrap.cpp)
51+
endif()
52+
4353
target_include_directories(mori_application PUBLIC ${CMAKE_SOURCE_DIR}/include)
4454
target_include_directories(mori_application PUBLIC ${CMAKE_SOURCE_DIR})
4555
target_link_libraries(
4656
mori_application
47-
MPI::MPI_CXX
4857
ibverbs
4958
hip::host
5059
hip::device
@@ -61,8 +70,14 @@ if(USE_BNXT)
6170
endif()
6271

6372
if(USE_IONIC)
64-
target_link_libraries(mori_application ${IONIC_LIB})
73+
target_link_libraries(mori_application ${IONIC_LIB})
74+
endif()
75+
76+
if(ENABLE_MPI)
77+
target_link_libraries(mori_application MPI::MPI_CXX)
6578
endif()
6679

67-
target_include_directories(mori_application PUBLIC ${TORCH_INCLUDE_DIRS})
68-
target_link_libraries(mori_application c10 torch torch_cpu c10_hip torch_hip)
80+
if(ENABLE_TORCH)
81+
target_include_directories(mori_application PUBLIC ${TORCH_INCLUDE_DIRS})
82+
target_link_libraries(mori_application c10 torch torch_cpu c10_hip torch_hip)
83+
endif()

src/pybind/CMakeLists.txt

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
find_package(Torch REQUIRED)
1+
if (ENABLE_TORCH)
2+
find_package(Torch REQUIRED)
3+
endif()
24

35
# You can set PYTHON_EXECUTABLE via command line: -DPYTHON_EXECUTABLE=/path/to/python
46
if(NOT DEFINED PYTHON_EXECUTABLE)
@@ -37,21 +39,23 @@ if(ENABLE_PROFILER)
3739
endif()
3840

3941
target_include_directories(
40-
mori_pybinds PUBLIC ${PYTHON_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}
41-
${CMAKE_BINARY_DIR}/generated/include)
42-
target_link_directories(mori_pybinds PUBLIC ${TORCH_INSTALL_PREFIX}/lib)
42+
mori_pybinds PUBLIC ${PYTHON_INCLUDE_DIRS} ${CMAKE_BINARY_DIR}/generated/include)
4343
target_link_libraries(
4444
mori_pybinds
4545
mori_ops
4646
mori_io
47-
${TORCH_LIBRARIES}
48-
torch_python
4947
hip::host
5048
hip::device)
5149

52-
# For python packages to find dependent libraries
53-
set_target_properties(
54-
mori_pybinds
55-
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../torch/lib"
50+
if(ENABLE_TORCH)
51+
target_include_directories(
52+
mori_pybinds PUBLIC ${TORCH_INCLUDE_DIRS})
53+
target_link_directories(mori_pybinds PUBLIC ${TORCH_INSTALL_PREFIX}/lib)
54+
# Link Torch and its Python bindings into mori_pybinds. The first argument must
# be the target name; the plain (keyword-less) signature matches the other
# target_link_libraries call on this target.
target_link_libraries(mori_pybinds ${TORCH_LIBRARIES} torch_python)
55+
# For python packages to find dependent libraries
56+
set_target_properties(
57+
mori_pybinds
58+
PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../torch/lib"
5659
INSTALL_RPATH "$ORIGIN;$ORIGIN/../torch/lib"
5760
BUILD_WITH_INSTALL_RPATH TRUE)
61+
endif()

src/shmem/CMakeLists.txt

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
find_package(MPI REQUIRED)
1+
if(ENABLE_MPI)
2+
find_package(MPI REQUIRED)
3+
endif()
24

35
option(BUILD_SHMEM_DEVICE_WRAPPER "Build SHMEM device API wrapper" OFF)
46

@@ -8,21 +10,33 @@ else()
810
add_library(mori_shmem init.cpp memory.cpp)
911
endif()
1012

11-
target_include_directories(mori_shmem PUBLIC ${CMAKE_SOURCE_DIR}/include)
12-
target_include_directories(mori_shmem PUBLIC ${CMAKE_SOURCE_DIR})
13+
target_include_directories(mori_shmem PUBLIC ${CMAKE_SOURCE_DIR}
14+
${CMAKE_SOURCE_DIR}/include)
1315
target_link_libraries(
1416
mori_shmem
1517
mori_application
1618
mori_logging
17-
MPI::MPI_CXX
1819
ibverbs
1920
hip::host
2021
hip::device
2122
mlx5)
2223

23-
# Use code object version 5 for compatibility with Triton
24-
target_compile_options(mori_shmem PUBLIC "-fgpu-rdc" "-mcode-object-version=5")
25-
target_link_options(mori_shmem PUBLIC "-fgpu-rdc" "-mcode-object-version=5")
24+
if(ENABLE_MPI)
25+
target_link_libraries(mori_shmem MPI::MPI_CXX)
26+
endif()
27+
28+
target_compile_options(mori_shmem PUBLIC "-fgpu-rdc")
29+
target_link_options(mori_shmem PUBLIC "-fgpu-rdc")
30+
if(FORCE_CODE_OBJECT_VERSION_5)
31+
# Use code object version 5 for compatibility with Triton
32+
target_compile_options(mori_shmem PUBLIC "-mcode-object-version=5")
33+
target_link_options(mori_shmem PUBLIC "-mcode-object-version=5")
34+
endif()
35+
36+
if(NOT ENABLE_TORCH)
37+
# Position-independent code appears to be required when building without Torch (e.g. for JAX)
38+
set_property(TARGET mori_shmem PROPERTY POSITION_INDEPENDENT_CODE ON)
39+
endif()
2640

2741
option(SAVE_TEMPS "Save intermediate compilation files" ON)
2842
if(SAVE_TEMPS)

0 commit comments

Comments
 (0)