Skip to content

Commit 30c4d01

Browse files
committed
Feat: add mori_shmem_barrier_all to wrapper
1 parent 618dfd4 commit 30c4d01

File tree

4 files changed

+14
-0
lines changed

4 files changed

+14
-0
lines changed

include/mori/shmem/shmem_device_api_wrapper.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ __device__ __attribute__((visibility("default"))) void mori_shmem_fence_thread_p
4444
__device__ __attribute__((visibility("default"))) void mori_shmem_fence_thread_pe_qp(int pe,
4545
int qpId);
4646

47+
__device__ __attribute__((visibility("default"))) void mori_shmem_barrier_all_thread();
48+
__device__ __attribute__((visibility("default"))) void mori_shmem_barrier_all_block();
49+
4750
// ============================================================================
4851
// PutNbi APIs - Thread Scope (Address-based only)
4952
// ============================================================================

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def build_extension(self, ext: Extension) -> None:
8686
build_type = os.environ.get("CMAKE_BUILD_TYPE", "Release")
8787
unroll_value = os.environ.get("WARP_ACCUM_UNROLL", "1")
8888
use_bnxt = os.environ.get("USE_BNXT", "OFF")
89+
build_shmem_device_wrapper = os.environ.get("BUILD_SHMEM_DEVICE_WRAPPER", "ON")
8990
use_ionic = os.environ.get("USE_IONIC", "OFF")
9091
enable_profiler = os.environ.get("ENABLE_PROFILER", "OFF")
9192
enable_debug_printf = os.environ.get("ENABLE_DEBUG_PRINTF", "OFF")
@@ -97,6 +98,7 @@ def build_extension(self, ext: Extension) -> None:
9798
f"-DCMAKE_BUILD_TYPE={build_type}",
9899
f"-DWARP_ACCUM_UNROLL={unroll_value}",
99100
f"-DUSE_BNXT={use_bnxt}",
101+
f"-DBUILD_SHMEM_DEVICE_WRAPPER={build_shmem_device_wrapper}",
100102
f"-DUSE_IONIC={use_ionic}",
101103
f"-DENABLE_DEBUG_PRINTF={enable_debug_printf}",
102104
f"-DGPU_TARGETS={gpu_archs}",

src/shmem/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ option(BUILD_SHMEM_DEVICE_WRAPPER "Build SHMEM device API wrapper" OFF)
44

55
if(BUILD_SHMEM_DEVICE_WRAPPER)
66
add_library(mori_shmem init.cpp memory.cpp shmem_device_api_wrapper.cpp)
7+
message(STATUS "BUILD_SHMEM_DEVICE_WRAPPER enabled for mori_shmem")
78
else()
89
add_library(mori_shmem init.cpp memory.cpp)
910
endif()

src/shmem/shmem_device_api_wrapper.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ __device__ __attribute__((visibility("default"))) void mori_shmem_fence_thread_p
5757
mori::shmem::ShmemFenceThread(pe, qpId);
5858
}
5959

60+
__device__ __attribute__((visibility("default"))) void mori_shmem_barrier_all_thread() {
61+
mori::shmem::ShmemBarrierAllThread();
62+
}
63+
64+
__device__ __attribute__((visibility("default"))) void mori_shmem_barrier_all_block() {
65+
mori::shmem::ShmemBarrierAllBlock();
66+
}
67+
6068
// ============================================================================
6169
// PutNbi APIs - Address-based only
6270
// ============================================================================

0 commit comments

Comments
 (0)