From d128f3f77e549f38006d5142e38998aa606588f5 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 25 Mar 2026 10:23:03 +0800 Subject: [PATCH 1/3] Reapply "Support premul_sum for xccl (#1947)" (#2913) This reverts commit 770ea427fb6808ce04294ee602baba91e3d11fa6. --- src/xccl/xccl.cpp | 18 +++-- src/xccl/xccl.h | 188 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 190 insertions(+), 16 deletions(-) diff --git a/src/xccl/xccl.cpp b/src/xccl/xccl.cpp index 7702b4ca1a..bbefee3581 100644 --- a/src/xccl/xccl.cpp +++ b/src/xccl/xccl.cpp @@ -38,7 +38,8 @@ void onecclAllReduce( at::xpu::XPUStream& stream) { if (isCCLV2EnabledCached()) { auto xcclDataType = getXcclDataTypeV2(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOpV2(reduceOp, input); + auto xcclReduceOp = + getXcclReduceOpV2(reduceOp, input, xcclDataType, comm.onecclComm); onecclAllReduce( input.data_ptr(), output.data_ptr(), @@ -49,7 +50,8 @@ void onecclAllReduce( &stream.queue()); } else { auto xcclDataType = getXcclDataTypeV1(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOpV1(reduceOp, input); + auto xcclReduceOp = + getXcclReduceOpV1(reduceOp, input, xcclDataType, *comm.cclComm); ccl::allreduce( input.data_ptr(), output.data_ptr(), @@ -72,7 +74,8 @@ void onecclReduce( at::xpu::XPUStream& stream) { if (isCCLV2EnabledCached()) { auto xcclDataType = getXcclDataTypeV2(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOpV2(reduceOp, input); + auto xcclReduceOp = + getXcclReduceOpV2(reduceOp, input, xcclDataType, comm.onecclComm); onecclReduce( input.data_ptr(), output.data_ptr(), @@ -84,7 +87,8 @@ void onecclReduce( &stream.queue()); } else { auto xcclDataType = getXcclDataTypeV1(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOpV1(reduceOp, input); + auto xcclReduceOp = + getXcclReduceOpV1(reduceOp, input, xcclDataType, *comm.cclComm); ccl::reduce( input.data_ptr(), output.data_ptr(), @@ -138,7 +142,8 @@ void onecclReduceScatter( at::xpu::XPUStream& stream) { if (isCCLV2EnabledCached()) { auto xcclDataType = getXcclDataTypeV2(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOpV2(reduceOp, input); + auto xcclReduceOp = + getXcclReduceOpV2(reduceOp, input, xcclDataType, comm.onecclComm); onecclReduceScatter( input.data_ptr(), output.data_ptr(), @@ -149,7 +154,8 @@ void onecclReduceScatter( &stream.queue()); } else { auto xcclDataType = getXcclDataTypeV1(input.scalar_type(), true); - auto xcclReduceOp = getXcclReduceOpV1(reduceOp, input); + auto xcclReduceOp = + getXcclReduceOpV1(reduceOp, input, xcclDataType, *comm.cclComm); ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), diff --git a/src/xccl/xccl.h b/src/xccl/xccl.h index 889fa1f282..84eb8ead01 100644 --- a/src/xccl/xccl.h +++ b/src/xccl/xccl.h @@ -36,6 +36,12 @@ #define XCCL_HAS_AVG 1 #endif // oneCCL version >= 2021.15 +#if defined(CCL_MAJOR_VERSION) && \ + ((CCL_MAJOR_VERSION > 2021) || \ + (CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 17)) +#define ENABLE_XCCL_PREMUL_SUM_SUPPORT +#endif // oneCCL version >= 2021.17 + inline std::string reduceOpToString(c10d::ReduceOp op) { switch (op) { case c10d::ReduceOp::SUM: @@ -174,6 +180,117 @@ inline const std::map xcclDatatypesV1 = { namespace { +struct XCCLTraitsV1 { + using OpType = ccl::reduction; + using CommType = const ccl::communicator*; + +#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT) + static void destroyOp(OpType op, CommType comm) { + ccl::reduction_destroy(op, *comm); + } +#endif +}; + +struct XCCLTraitsV2 { + using OpType = onecclRedOp_t; + using CommType = onecclComm_t; + +#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT) + static void destroyOp(OpType op, CommType comm) { + onecclRedOpDestroy(op, comm); + } +#endif +}; + +template +struct xcclRedOpRAII { + using OpType = typename Traits::OpType; + using CommType = typename Traits::CommType; + + xcclRedOpRAII() = default; + xcclRedOpRAII(OpType op) : op_(op) {} + xcclRedOpRAII(OpType op, CommType comm) + : op_(op), comm_(comm), premul_sum_(true) {} + + xcclRedOpRAII(const xcclRedOpRAII&) = delete; + xcclRedOpRAII& operator=(const xcclRedOpRAII&) = delete; + + xcclRedOpRAII(xcclRedOpRAII&& tmp) noexcept : xcclRedOpRAII() { + std::swap(tmp.op_, this->op_); + std::swap(tmp.comm_, this->comm_); + std::swap(tmp.premul_sum_, this->premul_sum_); + } + +#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT) + ~xcclRedOpRAII() { + if (premul_sum_ && comm_) { + Traits::destroyOp(op_, comm_); + } + } +#endif + + operator OpType() const { + return op_; + } + + OpType op_{}; + CommType comm_{}; + bool premul_sum_ = false; +}; + +using xcclRedOpRAIIV1 = xcclRedOpRAII; +using xcclRedOpRAIIV2 = xcclRedOpRAII; + +#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT +template +inline xcclRedOpRAIIV1 unpackPreMulSumV1( + const ReduceOp& reduceOp, + const ccl::communicator& comm) { + const auto* preMulSupplement = + reinterpret_cast(reduceOp.supplement_.get()); + ccl::reduction preMulSum{}; + bool has_tensor = preMulSupplement->tensor_factor.defined(); + auto residence = has_tensor + ? ccl::scalar_residence_type::scalar_device + : ccl::scalar_residence_type::scalar_host_immediate; + const T* ptr_factor = has_tensor + ? preMulSupplement->tensor_factor.const_data_ptr() + : nullptr; + T scalar_factor = T(preMulSupplement->double_factor); + ccl::reduction_create_pre_mul_sum( + &preMulSum, + /*scalar=*/has_tensor ? const_cast(ptr_factor) : &scalar_factor, + dataType, + residence, + comm); + return xcclRedOpRAIIV1(preMulSum, &comm); +} + +template +inline xcclRedOpRAIIV2 unpackPreMulSumV2( + const ReduceOp& reduceOp, + onecclComm_t comm) { + const auto* preMulSupplement = + reinterpret_cast(reduceOp.supplement_.get()); + onecclRedOp_t preMulSum{}; + bool has_tensor = preMulSupplement->tensor_factor.defined(); + auto residence = has_tensor + ? onecclScalarResidence_t::onecclScalarDevice + : onecclScalarResidence_t::onecclScalarHostImmediate; + const T* ptr_factor = has_tensor + ? preMulSupplement->tensor_factor.const_data_ptr() + : nullptr; + T scalar_factor = T(preMulSupplement->double_factor); + onecclRedOpCreatePreMulSum( + &preMulSum, + /*scalar=*/has_tensor ? const_cast(ptr_factor) : &scalar_factor, + dataType, + residence, + comm); + return xcclRedOpRAIIV2(preMulSum, comm); +} +#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT + inline ccl::datatype getXcclDataTypeV1( at::ScalarType type, bool is_reduction_op = false) { @@ -191,7 +308,6 @@ inline ccl::datatype getXcclDataTypeV1( return it->second; } -// V2 specific function to avoid variant overhead inline onecclDataType_t getXcclDataTypeV2( at::ScalarType type, bool is_reduction_op = false) { @@ -209,13 +325,15 @@ inline onecclDataType_t getXcclDataTypeV2( return it->second; } -inline ccl::reduction getXcclReduceOpV1( +inline xcclRedOpRAIIV1 getXcclReduceOpV1( const ReduceOp& reduceOp, - at::Tensor& input) { + at::Tensor& input, + const ccl::datatype& dataType, + const ccl::communicator& comm) { try { if (input.scalar_type() == at::kBool) { if (reduceOp == ReduceOp::SUM) { - return ccl::reduction::max; + return xcclRedOpRAIIV1(ccl::reduction::max); } #ifdef XCCL_HAS_AVG if (reduceOp == ReduceOp::AVG) { @@ -227,10 +345,34 @@ inline ccl::reduction getXcclReduceOpV1( #if !defined(XCCL_HAS_AVG) if (reduceOp == ReduceOp::AVG) { LOG(INFO) << "[Reduce] Use sum emulation for avg"; - return ccl::reduction::sum; + return xcclRedOpRAIIV1(ccl::reduction::sum); } #endif - return xcclOpsV1.at(reduceOp); + if (reduceOp == ReduceOp::PREMUL_SUM) { +#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT + switch (dataType) { + case ccl::datatype::float16: + return unpackPreMulSumV1( + reduceOp, comm); + case ccl::datatype::float32: + return unpackPreMulSumV1( + reduceOp, comm); + case ccl::datatype::bfloat16: + return unpackPreMulSumV1( + reduceOp, comm); + case ccl::datatype::float64: + return unpackPreMulSumV1( + reduceOp, comm); + default: + C10_THROW_ERROR( + TypeError, + "PreMulSum Data type must be half, float, bfloat16 or double"); + } +#else + C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL>=2021.17"); +#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT + } + return xcclRedOpRAIIV1(xcclOpsV1.at(reduceOp)); } catch (const std::out_of_range&) { C10_THROW_ERROR( ValueError, @@ -238,20 +380,46 @@ inline ccl::reduction getXcclReduceOpV1( } } -inline onecclRedOp_t getXcclReduceOpV2( +inline xcclRedOpRAIIV2 getXcclReduceOpV2( const ReduceOp& reduceOp, - at::Tensor& input) { + at::Tensor& input, + const onecclDataType_t& dataType, + onecclComm_t comm) { try { if (input.scalar_type() == at::kBool) { if (reduceOp == ReduceOp::SUM) { - return onecclRedOp_t::onecclMax; + return xcclRedOpRAIIV2(onecclRedOp_t::onecclMax); } if (reduceOp == ReduceOp::AVG) { C10_THROW_ERROR( TypeError, "Cannot use ReduceOp.AVG with boolean inputs"); } } - return xcclOpsV2.at(reduceOp); + if (reduceOp == ReduceOp::PREMUL_SUM) { +#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT + switch (dataType) { + case onecclDataType_t::onecclFloat16: + return unpackPreMulSumV2( + reduceOp, comm); + case onecclDataType_t::onecclFloat32: + return unpackPreMulSumV2( + reduceOp, comm); + case onecclDataType_t::onecclBfloat16: + return unpackPreMulSumV2( + reduceOp, comm); + case onecclDataType_t::onecclFloat64: + return unpackPreMulSumV2( + reduceOp, comm); + default: + C10_THROW_ERROR( + TypeError, + "PreMulSum Data type must be half, float, bfloat16 or double"); + } +#else + C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL>=2021.17"); +#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT + } + return xcclRedOpRAIIV2(xcclOpsV2.at(reduceOp)); } catch (const std::out_of_range&) { C10_THROW_ERROR( ValueError, From c508c6a2d1c30f677fc5cb1ed9ae1b50cac6aa73 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 25 Mar 2026 10:25:05 +0800 Subject: [PATCH 2/3] change api --- src/xccl/xccl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xccl/xccl.h b/src/xccl/xccl.h index 84eb8ead01..fca9b27b06 100644 --- a/src/xccl/xccl.h +++ b/src/xccl/xccl.h @@ -247,7 +247,7 @@ inline xcclRedOpRAIIV1 unpackPreMulSumV1( const ReduceOp& reduceOp, const ccl::communicator& comm) { const auto* preMulSupplement = - reinterpret_cast(reduceOp.supplement_.get()); + reinterpret_cast(reduceOp.supplement_.get()); ccl::reduction preMulSum{}; bool has_tensor = preMulSupplement->tensor_factor.defined(); auto residence = has_tensor @@ -271,7 +271,7 @@ inline xcclRedOpRAIIV2 unpackPreMulSumV2( const ReduceOp& reduceOp, onecclComm_t comm) { const auto* preMulSupplement = - reinterpret_cast(reduceOp.supplement_.get()); + reinterpret_cast(reduceOp.supplement_.get()); onecclRedOp_t preMulSum{}; bool has_tensor = preMulSupplement->tensor_factor.defined(); auto residence = has_tensor From a9bfefc212c7a475da7626a7aa6bf8c24c77d20a Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 25 Mar 2026 10:28:45 +0800 Subject: [PATCH 3/3] change version --- src/xccl/xccl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xccl/xccl.h b/src/xccl/xccl.h index fca9b27b06..3a68ceec3d 100644 --- a/src/xccl/xccl.h +++ b/src/xccl/xccl.h @@ -37,10 +37,10 @@ #endif // oneCCL version >= 2021.15 #if defined(CCL_MAJOR_VERSION) && \ - ((CCL_MAJOR_VERSION > 2021) || \ - (CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 17)) + ((CCL_MAJOR_VERSION > 2022) || \ + (CCL_MAJOR_VERSION == 2022) && (CCL_MINOR_VERSION >= 0)) #define ENABLE_XCCL_PREMUL_SUM_SUPPORT -#endif // oneCCL version >= 2021.17 +#endif // oneCCL version >= 2022.0 inline std::string reduceOpToString(c10d::ReduceOp op) { switch (op) {