Skip to content

Commit db7cf6d

Browse files
authored
Feat: Enable intra-node FP4 dispatch and BF16 cast to FP8 combine (#169)
- Add BF16→FP8 direct-cast quantization in the intranode combine kernel.
- Extend intranode dispatch/combine to support float4_e2m1fn_x2 (FP4).
- Add RDMA/IO env vars: MORI_IO_SL, MORI_RDMA_SL, MORI_IO_TC_DISABLE, MORI_IB_ENABLE_RELAXED_ORDERING.
1 parent 95ad1dd commit db7cf6d

File tree

14 files changed

+441
-117
lines changed

14 files changed

+441
-117
lines changed

examples/ops/dispatch_combine/test_dispatch_combine.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,23 @@
2828

2929
os.environ["MORI_SHMEM_HEAP_SIZE"] = "6G"
3030

31+
32+
def _is_fp4x2_dtype(dtype):
33+
return dtype is torch.float4_e2m1fn_x2
34+
35+
3136
class EpDispatchCombineTestCase:
32-
def __init__(self, rank, world_size, dtype=torch.bfloat16):
37+
def __init__(self, rank, world_size, dtype=torch.bfloat16, quant_type="none", hidden_dim=7168):
3338
self.rank = rank
3439
self.world_size = world_size
40+
# fp8_direct_cast requires use_external_inp_buf=True (not zero-copy)
41+
use_external_inp_buf = (quant_type == "fp8_direct_cast")
42+
cfg_hidden_dim = hidden_dim // 2 if _is_fp4x2_dtype(dtype) else hidden_dim
3543
self.config = mori.ops.EpDispatchCombineConfig(
3644
data_type=dtype,
3745
rank=self.rank,
3846
world_size=self.world_size,
39-
hidden_dim=7168,
47+
hidden_dim=cfg_hidden_dim,
4048
# scale_dim=32,
4149
scale_dim=0,
4250
scale_type_size=torch.tensor(
@@ -46,7 +54,8 @@ def __init__(self, rank, world_size, dtype=torch.bfloat16):
4654
max_num_inp_token_per_rank=4096,
4755
num_experts_per_rank=32,
4856
num_experts_per_token=8,
49-
use_external_inp_buf=False,
57+
use_external_inp_buf=use_external_inp_buf,
58+
quant_type=quant_type,
5059
)
5160

5261
def setup(self):
@@ -177,10 +186,22 @@ def gen_test_data(self):
177186
generator=self.rng,
178187
device=self.device,
179188
)
189+
if _is_fp4x2_dtype(self.config.data_type):
190+
input_bytes = torch.randint(
191+
0,
192+
256,
193+
(num_tokens, self.config.hidden_dim),
194+
dtype=torch.uint8,
195+
generator=self.rng,
196+
device=self.device,
197+
)
198+
input = input_bytes.view(torch.float4_e2m1fn_x2)
199+
else:
200+
input = input_fp32.to(self.config.data_type)
201+
180202
input_list = self._allgather_with_token_num_padding(
181-
input_fp32, self.config.max_num_inp_token_per_rank
203+
input, self.config.max_num_inp_token_per_rank
182204
)
183-
input_list = [tensor.to(self.config.data_type) for tensor in input_list]
184205

185206
return (
186207
num_tokens,
@@ -189,7 +210,7 @@ def gen_test_data(self):
189210
# None,
190211
# scales_fp32,
191212
scales_fp32.to(torch.float8_e4m3fnuz),
192-
input_fp32.to(self.config.data_type),
213+
input,
193214
indices_list,
194215
weights_list,
195216
# None,
@@ -233,7 +254,13 @@ def run_test_once(self, op, test_data):
233254
for i, pos in enumerate(src_token_pos):
234255
src_rank = int(pos) // self.config.max_num_inp_token_per_rank
235256
src_id = int(pos) % self.config.max_num_inp_token_per_rank
236-
assert torch.equal(input_list[src_rank][src_id], dispatch_output[i])
257+
if _is_fp4x2_dtype(self.config.data_type):
258+
assert torch.equal(
259+
input_list[src_rank][src_id].view(torch.uint8),
260+
dispatch_output[i].view(torch.uint8),
261+
)
262+
else:
263+
assert torch.equal(input_list[src_rank][src_id], dispatch_output[i])
237264
assert torch.equal(weights_list[src_rank][src_id], dispatch_weights[i])
238265
if scales_list is not None and self.config.scale_dim != 0:
239266
assert torch.equal(scales_list[src_rank][src_id], dispatch_scales[i])
@@ -263,6 +290,8 @@ def run_test_once(self, op, test_data):
263290
torch.cuda.synchronize()
264291

265292
for i in range(num_tokens):
293+
# if _is_fp4x2_dtype(self.config.data_type):
294+
# continue
266295
pes = [
267296
(idx // self.config.num_experts_per_rank)
268297
for idx in indices[i].cpu().tolist()
@@ -274,7 +303,10 @@ def run_test_once(self, op, test_data):
274303
# ).to(self.config.data_type)
275304
got, expected = combine_output[i], input[i].to(torch.bfloat16) * unique_pes
276305

277-
assert torch.allclose(got.float(), expected.float(), atol=1e-2, rtol=1e-2)
306+
atol, rtol = 1e-2, 1e-2
307+
if self.config.quant_type == "fp8_direct_cast":
308+
atol, rtol = 1e-1, 1e-1
309+
assert torch.allclose(got.float(), expected.float(), atol=atol, rtol=rtol)
278310

279311
got_weight, expected_weight = (
280312
combine_output_weight[i],
@@ -309,16 +341,45 @@ def test_dispatch_combine(self):
309341
del op
310342

311343

312-
def test_dispatch_combine(rank, world_size):
344+
def test_dispatch_combine(rank, world_size, dtype, quant_type="none"):
313345
# test_case = EpDispatchCombineTestCase(rank, world_size, torch.float8_e4m3fnuz)
314-
test_case = EpDispatchCombineTestCase(rank, world_size, torch.bfloat16)
346+
test_case = EpDispatchCombineTestCase(rank, world_size, dtype, quant_type)
315347
test_case.setup()
316348
test_case.test_dispatch_combine()
317349
test_case.cleanup()
318350

319351

320352
if __name__ == "__main__":
353+
import argparse
354+
355+
parser = argparse.ArgumentParser()
356+
parser.add_argument(
357+
"--dtype",
358+
type=str,
359+
default="bf16",
360+
choices=["bf16", "fp4"],
361+
help="Data type of dispatch / combine",
362+
)
363+
parser.add_argument(
364+
"--quant-type",
365+
type=str,
366+
default="none",
367+
choices=["none", "fp8_direct_cast"],
368+
help="Quantization method used inside Combine.",
369+
)
370+
args = parser.parse_args()
371+
372+
_DATA_TYPE_MAP = {
373+
"bf16": torch.bfloat16,
374+
"fp4": torch.float4_e2m1fn_x2,
375+
}
376+
if args.quant_type == "fp8_direct_cast" and _DATA_TYPE_MAP[args.dtype] is torch.float4_e2m1fn_x2:
377+
raise ValueError("fp8_direct_cast is not supported for fp4 data type")
378+
321379
world_size = 8
322380
torch.multiprocessing.spawn(
323-
test_dispatch_combine, args=(world_size,), nprocs=world_size, join=True
381+
test_dispatch_combine,
382+
args=(world_size, _DATA_TYPE_MAP[args.dtype], args.quant_type),
383+
nprocs=world_size,
384+
join=True,
324385
)

include/mori/application/transport/rdma/rdma.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ class RdmaDevice;
190190

191191
std::optional<uint8_t> ReadRdmaServiceLevelEnv();
192192
std::optional<uint8_t> ReadRdmaTrafficClassEnv();
193+
std::optional<uint8_t> ReadIoServiceLevelEnv();
194+
std::optional<uint8_t> ReadIoTrafficClassEnv();
195+
bool ReadIoTrafficClassDisableEnv();
196+
197+
bool ReadIbEnableRelaxedOrderingEnv();
198+
int MaybeAddRelaxedOrderingFlag(int accessFlag);
193199

194200
/* -------------------------------------------------------------------------- */
195201
/* RdmaDeviceContext */

include/mori/core/transport/p2p/device_primitives.hpp

Lines changed: 52 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -796,11 +796,11 @@ __forceinline__ __device__ void WarpCastBf16ToCombineInternalFp8(
796796
}
797797
}
798798
}
799+
// Note: when T != hip_bfloat16, this function is a no-op.
800+
// Callers should guard with if constexpr or ensure T is hip_bfloat16.
799801
#else
800-
(void)dst;
801-
(void)src;
802-
(void)hiddenDim;
803-
(void)laneId;
802+
static_assert(!sizeof(T*), "WarpCastBf16ToCombineInternalFp8 requires FP8 type support "
803+
"(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)");
804804
#endif
805805
}
806806

@@ -809,31 +809,28 @@ namespace detail {
809809
using CombineInternalFp8T = CombineInternalFp8;
810810
using CombineInternalFp8x4T = CombineInternalFp8x4;
811811

812-
template <int NNodes>
813-
__forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed(
814-
hip_bfloat16* __restrict__ out, const CombineInternalFp8T* const* __restrict__ srcPtrs,
815-
int laneId, int hiddenDimSize);
816-
817-
template <>
818-
__forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
812+
template <int AccumNum>
813+
__forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16Fixed(
819814
hip_bfloat16* __restrict__ out, const CombineInternalFp8T* const* __restrict__ srcPtrs,
820815
int laneId, int hiddenDimSize) {
816+
static_assert(AccumNum > 0, "AccumNum must be positive");
817+
821818
using Fp8T = CombineInternalFp8T;
822819
using Fp8x4T = CombineInternalFp8x4T;
823820
constexpr int kVec8 = 8;
824821
constexpr int kVec4 = 4;
825822

826-
const Fp8T* src0 = srcPtrs[0];
827-
const Fp8T* src1 = srcPtrs[1];
828-
829823
const uintptr_t outAddr = reinterpret_cast<uintptr_t>(out);
830-
const uintptr_t src0Addr = reinterpret_cast<uintptr_t>(src0);
831-
const uintptr_t src1Addr = reinterpret_cast<uintptr_t>(src1);
832-
833-
const bool canVec8 = ((outAddr & 0x7) == 0) && ((src0 == nullptr) || ((src0Addr & 0x7) == 0)) &&
834-
((src1 == nullptr) || ((src1Addr & 0x7) == 0));
835-
const bool canVec4 = ((src0 == nullptr) || ((src0Addr & 0x3) == 0)) &&
836-
((src1 == nullptr) || ((src1Addr & 0x3) == 0));
824+
bool canVec8 = ((outAddr & 0x7) == 0);
825+
bool canVec4 = true;
826+
#pragma unroll
827+
for (int n = 0; n < AccumNum; n++) {
828+
const Fp8T* src = srcPtrs[n];
829+
if (src == nullptr) continue;
830+
const uintptr_t srcAddr = reinterpret_cast<uintptr_t>(src);
831+
canVec8 &= ((srcAddr & 0x7) == 0);
832+
canVec4 &= ((srcAddr & 0x3) == 0);
833+
}
837834

838835
const int vecEnd8 = (hiddenDimSize / kVec8) * kVec8;
839836
const int vecEnd4 = (hiddenDimSize / kVec4) * kVec4;
@@ -846,7 +843,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
846843
float4 sumLo = {0.0f, 0.0f, 0.0f, 0.0f};
847844
float4 sumHi = {0.0f, 0.0f, 0.0f, 0.0f};
848845
#pragma unroll
849-
for (int n = 0; n < 2; n++) {
846+
for (int n = 0; n < AccumNum; n++) {
850847
const Fp8T* src = srcPtrs[n];
851848
if (src == nullptr) continue;
852849
const auto* srcAligned = static_cast<const Fp8T*>(__builtin_assume_aligned(src, 8));
@@ -892,7 +889,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
892889
for (int j = vecEnd8 + laneId * kVec4; j < vecEnd4; j += warpSize * kVec4) {
893890
float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
894891
#pragma unroll
895-
for (int n = 0; n < 2; n++) {
892+
for (int n = 0; n < AccumNum; n++) {
896893
const Fp8T* src = srcPtrs[n];
897894
if (src == nullptr) continue;
898895
Fp8x4T v;
@@ -914,7 +911,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
914911
for (int j = laneId * kVec4; j < vecEnd4; j += warpSize * kVec4) {
915912
float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
916913
#pragma unroll
917-
for (int n = 0; n < 2; n++) {
914+
for (int n = 0; n < AccumNum; n++) {
918915
const Fp8T* src = srcPtrs[n];
919916
if (src == nullptr) continue;
920917
Fp8x4T v;
@@ -936,7 +933,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
936933
for (int j = scalarStart + laneId; j < hiddenDimSize; j += warpSize) {
937934
float sum = 0.0f;
938935
#pragma unroll
939-
for (int n = 0; n < 2; n++) {
936+
for (int n = 0; n < AccumNum; n++) {
940937
const Fp8T* src = srcPtrs[n];
941938
if (src == nullptr) continue;
942939
sum += float(src[j]);
@@ -945,9 +942,9 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
945942
}
946943
}
947944

948-
__forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
945+
__forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16Dynamic(
949946
hip_bfloat16* __restrict__ out, const CombineInternalFp8T* const* __restrict__ srcPtrs,
950-
int nNodes, int laneId, int hiddenDimSize) {
947+
int accumNum, int laneId, int hiddenDimSize) {
951948
using Fp8T = CombineInternalFp8T;
952949
using Fp8x4T = CombineInternalFp8x4T;
953950

@@ -956,7 +953,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
956953

957954
bool canVec4 = true;
958955
#pragma unroll 4
959-
for (int n = 0; n < nNodes; n++) {
956+
for (int n = 0; n < accumNum; n++) {
960957
const Fp8T* src = srcPtrs[n];
961958
if (src == nullptr) continue;
962959
canVec4 &= ((reinterpret_cast<uintptr_t>(src) & 0x3) == 0);
@@ -966,7 +963,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
966963
for (int j = laneId * kVec4; j < vecEnd; j += warpSize * kVec4) {
967964
float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
968965
#pragma unroll 4
969-
for (int n = 0; n < nNodes; n++) {
966+
for (int n = 0; n < accumNum; n++) {
970967
const Fp8T* src = srcPtrs[n];
971968
if (src == nullptr) continue;
972969
Fp8x4T v;
@@ -988,7 +985,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
988985
for (int j = scalarStart + laneId; j < hiddenDimSize; j += warpSize) {
989986
float sum = 0.0f;
990987
#pragma unroll 4
991-
for (int n = 0; n < nNodes; n++) {
988+
for (int n = 0; n < accumNum; n++) {
992989
const Fp8T* src = srcPtrs[n];
993990
if (src == nullptr) continue;
994991
sum += float(src[j]);
@@ -1001,29 +998,37 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
1001998
#endif
1002999

10031000
template <typename T>
1004-
__forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16(
1005-
T* __restrict__ out, const CombineInternalFp8* const* __restrict__ srcPtrs, int nNodes,
1001+
__forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16(
1002+
T* __restrict__ out, const CombineInternalFp8* const* __restrict__ srcPtrs, int accumNum,
10061003
int laneId, int hiddenDimSize) {
10071004
#if defined(MORI_FP8_TYPE_OCP_ENABLED) || defined(MORI_FP8_TYPE_FNUZ_ENABLED)
10081005
if constexpr (std::is_same_v<T, hip_bfloat16>) {
1009-
if (nNodes == 2) {
1010-
detail::SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
1011-
reinterpret_cast<hip_bfloat16*>(out),
1012-
reinterpret_cast<const detail::CombineInternalFp8T* const*>(srcPtrs), laneId,
1013-
hiddenDimSize);
1014-
} else {
1015-
detail::SumCombineInternalFp8AcrossNodesToBf16Dynamic(
1016-
reinterpret_cast<hip_bfloat16*>(out),
1017-
reinterpret_cast<const detail::CombineInternalFp8T* const*>(srcPtrs), nNodes, laneId,
1018-
hiddenDimSize);
1006+
switch (accumNum) {
1007+
case 2:
1008+
detail::WarpAccumCombineInternalFp8ToBf16Fixed<2>(
1009+
reinterpret_cast<hip_bfloat16*>(out),
1010+
reinterpret_cast<const detail::CombineInternalFp8T* const*>(srcPtrs), laneId,
1011+
hiddenDimSize);
1012+
break;
1013+
case 8:
1014+
detail::WarpAccumCombineInternalFp8ToBf16Fixed<8>(
1015+
reinterpret_cast<hip_bfloat16*>(out),
1016+
reinterpret_cast<const detail::CombineInternalFp8T* const*>(srcPtrs), laneId,
1017+
hiddenDimSize);
1018+
break;
1019+
default:
1020+
detail::WarpAccumCombineInternalFp8ToBf16Dynamic(
1021+
reinterpret_cast<hip_bfloat16*>(out),
1022+
reinterpret_cast<const detail::CombineInternalFp8T* const*>(srcPtrs), accumNum, laneId,
1023+
hiddenDimSize);
1024+
break;
10191025
}
10201026
}
1027+
// Note: when T != hip_bfloat16, this function is a no-op.
1028+
// Callers should guard with if constexpr or ensure T is hip_bfloat16.
10211029
#else
1022-
(void)out;
1023-
(void)srcPtrs;
1024-
(void)nNodes;
1025-
(void)laneId;
1026-
(void)hiddenDimSize;
1030+
static_assert(!sizeof(T*), "WarpAccumCombineInternalFp8ToBf16 requires FP8 type support "
1031+
"(MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)");
10271032
#endif
10281033
}
10291034

src/application/transport/rdma/providers/bnxt/bnxt.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,10 @@ BnxtQpContainer::BnxtQpContainer(ibv_context* context, const RdmaEndpointConfig&
286286
}
287287

288288
// Register atomic ibuf as independent memory region
289-
atomicIbufMr = ibv_reg_mr(pd, atomicIbufAddr, atomicIbufSize,
290-
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE |
291-
IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC);
289+
int atomicIbufAccessFlag =
290+
MaybeAddRelaxedOrderingFlag(IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE |
291+
IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC);
292+
atomicIbufMr = ibv_reg_mr(pd, atomicIbufAddr, atomicIbufSize, atomicIbufAccessFlag);
292293
assert(atomicIbufMr);
293294

294295
MORI_APP_TRACE(

src/application/transport/rdma/providers/ibverbs/ibverbs.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,21 @@ void IBVerbsDeviceContext::ConnectEndpoint(const RdmaEndpointHandle& local,
133133
attr.min_rnr_timer = 12;
134134
attr.ah_attr.src_path_bits = 0;
135135
attr.ah_attr.port_num = local.portId;
136-
attr.ah_attr.sl = ReadRdmaServiceLevelEnv().value_or(0);
137-
std::optional<uint8_t> tc = ReadRdmaTrafficClassEnv();
138-
if (tc.has_value()) {
139-
attr.ah_attr.grh.traffic_class = tc.value();
136+
std::optional<uint8_t> sl = ReadIoServiceLevelEnv();
137+
if (!sl.has_value()) {
138+
sl = ReadRdmaServiceLevelEnv();
139+
}
140+
attr.ah_attr.sl = sl.value_or(0);
141+
142+
bool disableIoTc = ReadIoTrafficClassDisableEnv();
143+
if (!disableIoTc) {
144+
std::optional<uint8_t> tc = ReadIoTrafficClassEnv();
145+
if (!tc.has_value()) {
146+
tc = ReadRdmaTrafficClassEnv();
147+
}
148+
if (tc.has_value()) {
149+
attr.ah_attr.grh.traffic_class = tc.value();
150+
}
140151
}
141152
MORI_APP_INFO("ibverbs attr.ah_attr.sl:{} attr.ah_attr.grh.traffic_class:{}", attr.ah_attr.sl,
142153
attr.ah_attr.grh.traffic_class);

0 commit comments

Comments (0)