@@ -796,11 +796,11 @@ __forceinline__ __device__ void WarpCastBf16ToCombineInternalFp8(
796796 }
797797 }
798798 }
799+ // Note: when T != hip_bfloat16, this function is a no-op.
800+ // Callers should guard with if constexpr or ensure T is hip_bfloat16.
799801#else
800- (void )dst;
801- (void )src;
802- (void )hiddenDim;
803- (void )laneId;
802+ static_assert (!sizeof (T*), " WarpCastBf16ToCombineInternalFp8 requires FP8 type support "
803+ " (MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)" );
804804#endif
805805}
806806
@@ -809,31 +809,28 @@ namespace detail {
809809using CombineInternalFp8T = CombineInternalFp8;
810810using CombineInternalFp8x4T = CombineInternalFp8x4;
811811
812- template <int NNodes>
813- __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed (
814- hip_bfloat16* __restrict__ out, const CombineInternalFp8T* const * __restrict__ srcPtrs,
815- int laneId, int hiddenDimSize);
816-
817- template <>
818- __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2 >(
812+ template <int AccumNum>
813+ __forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16Fixed (
819814 hip_bfloat16* __restrict__ out, const CombineInternalFp8T* const * __restrict__ srcPtrs,
820815 int laneId, int hiddenDimSize) {
816+ static_assert (AccumNum > 0 , " AccumNum must be positive" );
817+
821818 using Fp8T = CombineInternalFp8T;
822819 using Fp8x4T = CombineInternalFp8x4T;
823820 constexpr int kVec8 = 8 ;
824821 constexpr int kVec4 = 4 ;
825822
826- const Fp8T* src0 = srcPtrs[0 ];
827- const Fp8T* src1 = srcPtrs[1 ];
828-
829823 const uintptr_t outAddr = reinterpret_cast <uintptr_t >(out);
830- const uintptr_t src0Addr = reinterpret_cast <uintptr_t >(src0);
831- const uintptr_t src1Addr = reinterpret_cast <uintptr_t >(src1);
832-
833- const bool canVec8 = ((outAddr & 0x7 ) == 0 ) && ((src0 == nullptr ) || ((src0Addr & 0x7 ) == 0 )) &&
834- ((src1 == nullptr ) || ((src1Addr & 0x7 ) == 0 ));
835- const bool canVec4 = ((src0 == nullptr ) || ((src0Addr & 0x3 ) == 0 )) &&
836- ((src1 == nullptr ) || ((src1Addr & 0x3 ) == 0 ));
824+ bool canVec8 = ((outAddr & 0x7 ) == 0 );
825+ bool canVec4 = true ;
826+ #pragma unroll
827+ for (int n = 0 ; n < AccumNum; n++) {
828+ const Fp8T* src = srcPtrs[n];
829+ if (src == nullptr ) continue ;
830+ const uintptr_t srcAddr = reinterpret_cast <uintptr_t >(src);
831+ canVec8 &= ((srcAddr & 0x7 ) == 0 );
832+ canVec4 &= ((srcAddr & 0x3 ) == 0 );
833+ }
837834
838835 const int vecEnd8 = (hiddenDimSize / kVec8 ) * kVec8 ;
839836 const int vecEnd4 = (hiddenDimSize / kVec4 ) * kVec4 ;
@@ -846,7 +843,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
846843 float4 sumLo = {0 .0f , 0 .0f , 0 .0f , 0 .0f };
847844 float4 sumHi = {0 .0f , 0 .0f , 0 .0f , 0 .0f };
848845#pragma unroll
849- for (int n = 0 ; n < 2 ; n++) {
846+ for (int n = 0 ; n < AccumNum ; n++) {
850847 const Fp8T* src = srcPtrs[n];
851848 if (src == nullptr ) continue ;
852849 const auto * srcAligned = static_cast <const Fp8T*>(__builtin_assume_aligned (src, 8 ));
@@ -892,7 +889,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
892889 for (int j = vecEnd8 + laneId * kVec4 ; j < vecEnd4; j += warpSize * kVec4 ) {
893890 float4 sum4 = {0 .0f , 0 .0f , 0 .0f , 0 .0f };
894891#pragma unroll
895- for (int n = 0 ; n < 2 ; n++) {
892+ for (int n = 0 ; n < AccumNum ; n++) {
896893 const Fp8T* src = srcPtrs[n];
897894 if (src == nullptr ) continue ;
898895 Fp8x4T v;
@@ -914,7 +911,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
914911 for (int j = laneId * kVec4 ; j < vecEnd4; j += warpSize * kVec4 ) {
915912 float4 sum4 = {0 .0f , 0 .0f , 0 .0f , 0 .0f };
916913#pragma unroll
917- for (int n = 0 ; n < 2 ; n++) {
914+ for (int n = 0 ; n < AccumNum ; n++) {
918915 const Fp8T* src = srcPtrs[n];
919916 if (src == nullptr ) continue ;
920917 Fp8x4T v;
@@ -936,7 +933,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
936933 for (int j = scalarStart + laneId; j < hiddenDimSize; j += warpSize) {
937934 float sum = 0 .0f ;
938935#pragma unroll
939- for (int n = 0 ; n < 2 ; n++) {
936+ for (int n = 0 ; n < AccumNum ; n++) {
940937 const Fp8T* src = srcPtrs[n];
941938 if (src == nullptr ) continue ;
942939 sum += float (src[j]);
@@ -945,9 +942,9 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Fixed<2>(
945942 }
946943}
947944
948- __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic (
945+ __forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16Dynamic (
949946 hip_bfloat16* __restrict__ out, const CombineInternalFp8T* const * __restrict__ srcPtrs,
950- int nNodes , int laneId, int hiddenDimSize) {
947+ int accumNum , int laneId, int hiddenDimSize) {
951948 using Fp8T = CombineInternalFp8T;
952949 using Fp8x4T = CombineInternalFp8x4T;
953950
@@ -956,7 +953,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
956953
957954 bool canVec4 = true ;
958955#pragma unroll 4
959- for (int n = 0 ; n < nNodes ; n++) {
956+ for (int n = 0 ; n < accumNum ; n++) {
960957 const Fp8T* src = srcPtrs[n];
961958 if (src == nullptr ) continue ;
962959 canVec4 &= ((reinterpret_cast <uintptr_t >(src) & 0x3 ) == 0 );
@@ -966,7 +963,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
966963 for (int j = laneId * kVec4 ; j < vecEnd; j += warpSize * kVec4 ) {
967964 float4 sum4 = {0 .0f , 0 .0f , 0 .0f , 0 .0f };
968965#pragma unroll 4
969- for (int n = 0 ; n < nNodes ; n++) {
966+ for (int n = 0 ; n < accumNum ; n++) {
970967 const Fp8T* src = srcPtrs[n];
971968 if (src == nullptr ) continue ;
972969 Fp8x4T v;
@@ -988,7 +985,7 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
988985 for (int j = scalarStart + laneId; j < hiddenDimSize; j += warpSize) {
989986 float sum = 0 .0f ;
990987#pragma unroll 4
991- for (int n = 0 ; n < nNodes ; n++) {
988+ for (int n = 0 ; n < accumNum ; n++) {
992989 const Fp8T* src = srcPtrs[n];
993990 if (src == nullptr ) continue ;
994991 sum += float (src[j]);
@@ -1001,29 +998,37 @@ __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16Dynamic(
1001998#endif
1002999
// WarpAccumCombineInternalFp8ToBf16: entry point that picks an accumulation helper in
// namespace detail based on accumNum and forwards all arguments to it.
//
// NOTE(review): this span is patch text. Lines marked '-' are the removed version
// (named SumCombineInternalFp8AcrossNodesToBf16, parameter nNodes, two-way if/else
// dispatch); lines marked '+' are the replacement that the comments here describe.
//
// Parameters:
//   out           - destination buffer of type T; only used when T is hip_bfloat16.
//   srcPtrs       - array of accumNum source pointers; the helpers skip entries that
//                   are nullptr.
//   accumNum      - number of entries in srcPtrs to accumulate; also selects the helper.
//   laneId        - calling lane's index; the helpers start their loops at an offset
//                   derived from laneId and advance by warpSize.
//   hiddenDimSize - number of elements per source row that the helpers iterate over.
//
// Behaviour by configuration:
//   * FP8 support macro defined and T == hip_bfloat16: dispatches as described below.
//   * FP8 support macro defined and T != hip_bfloat16: the body is empty (no-op).
//   * Neither FP8 support macro defined: instantiating this template fails to compile.
//
// NOTE(review): the helpers' store to `out` lies outside the visible hunks; presumably
// the per-element float sums are written back as bf16 - confirm against the full file.
10031000template <typename T>
1004- __forceinline__ __device__ void SumCombineInternalFp8AcrossNodesToBf16 (
1005- T* __restrict__ out, const CombineInternalFp8* const * __restrict__ srcPtrs, int nNodes ,
1001+ __forceinline__ __device__ void WarpAccumCombineInternalFp8ToBf16 (
1002+ T* __restrict__ out, const CombineInternalFp8* const * __restrict__ srcPtrs, int accumNum ,
10061003 int laneId, int hiddenDimSize) {
10071004#if defined(MORI_FP8_TYPE_OCP_ENABLED) || defined(MORI_FP8_TYPE_FNUZ_ENABLED)
10081005 if constexpr (std::is_same_v<T, hip_bfloat16>) {
1009- if (nNodes == 2 ) {
1010- detail::SumCombineInternalFp8AcrossNodesToBf16Fixed<2 >(
1011- reinterpret_cast <hip_bfloat16*>(out),
1012- reinterpret_cast <const detail::CombineInternalFp8T* const *>(srcPtrs), laneId,
1013- hiddenDimSize);
1014- } else {
1015- detail::SumCombineInternalFp8AcrossNodesToBf16Dynamic (
1016- reinterpret_cast <hip_bfloat16*>(out),
1017- reinterpret_cast <const detail::CombineInternalFp8T* const *>(srcPtrs), nNodes, laneId,
1018- hiddenDimSize);
// Dispatch on the runtime count: 2 and 8 map to compile-time instantiations of the
// Fixed helper (the count becomes the template argument AccumNum); every other value
// goes to the Dynamic helper, which receives accumNum as a runtime argument.
1006+ switch (accumNum) {
1007+ case 2 :
1008+ detail::WarpAccumCombineInternalFp8ToBf16Fixed<2 >(
1009+ reinterpret_cast <hip_bfloat16*>(out),
1010+ reinterpret_cast <const detail::CombineInternalFp8T* const *>(srcPtrs), laneId,
1011+ hiddenDimSize);
1012+ break ;
1013+ case 8 :
1014+ detail::WarpAccumCombineInternalFp8ToBf16Fixed<8 >(
1015+ reinterpret_cast <hip_bfloat16*>(out),
1016+ reinterpret_cast <const detail::CombineInternalFp8T* const *>(srcPtrs), laneId,
1017+ hiddenDimSize);
1018+ break ;
1019+ default :
1020+ detail::WarpAccumCombineInternalFp8ToBf16Dynamic (
1021+ reinterpret_cast <hip_bfloat16*>(out),
1022+ reinterpret_cast <const detail::CombineInternalFp8T* const *>(srcPtrs), accumNum, laneId,
1023+ hiddenDimSize);
1024+ break ;
10191025 }
10201026 }
1027+ // Note: when T != hip_bfloat16, this function is a no-op.
1028+ // Callers should guard with if constexpr or ensure T is hip_bfloat16.
10211029#else
1022- (void )out;
1023- (void )srcPtrs;
1024- (void )nNodes;
1025- (void )laneId;
1026- (void )hiddenDimSize;
// The condition !sizeof(T*) is always false but depends on T, so the assertion is
// evaluated when the template is instantiated rather than when it is defined. This
// replaces the removed version's silent no-op (the (void) casts above) with a
// compile-time error for builds without FP8 support.
1030+ static_assert (!sizeof (T*), " WarpAccumCombineInternalFp8ToBf16 requires FP8 type support "
1031+ " (MORI_FP8_TYPE_OCP_ENABLED or MORI_FP8_TYPE_FNUZ_ENABLED)" );
10271032#endif
10281033}
10291034
0 commit comments