Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/xccl/xccl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ void onecclAllReduce(
at::xpu::XPUStream& stream) {
if (isCCLV2EnabledCached()) {
auto xcclDataType = getXcclDataTypeV2(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOpV2(reduceOp, input);
auto xcclReduceOp =
getXcclReduceOpV2(reduceOp, input, xcclDataType, comm.onecclComm);
onecclAllReduce(
input.data_ptr(),
output.data_ptr(),
Expand All @@ -49,7 +50,8 @@ void onecclAllReduce(
&stream.queue());
} else {
auto xcclDataType = getXcclDataTypeV1(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOpV1(reduceOp, input);
auto xcclReduceOp =
getXcclReduceOpV1(reduceOp, input, xcclDataType, *comm.cclComm);
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
Expand All @@ -72,7 +74,8 @@ void onecclReduce(
at::xpu::XPUStream& stream) {
if (isCCLV2EnabledCached()) {
auto xcclDataType = getXcclDataTypeV2(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOpV2(reduceOp, input);
auto xcclReduceOp =
getXcclReduceOpV2(reduceOp, input, xcclDataType, comm.onecclComm);
onecclReduce(
input.data_ptr(),
output.data_ptr(),
Expand All @@ -84,7 +87,8 @@ void onecclReduce(
&stream.queue());
} else {
auto xcclDataType = getXcclDataTypeV1(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOpV1(reduceOp, input);
auto xcclReduceOp =
getXcclReduceOpV1(reduceOp, input, xcclDataType, *comm.cclComm);
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -138,7 +142,8 @@ void onecclReduceScatter(
at::xpu::XPUStream& stream) {
if (isCCLV2EnabledCached()) {
auto xcclDataType = getXcclDataTypeV2(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOpV2(reduceOp, input);
auto xcclReduceOp =
getXcclReduceOpV2(reduceOp, input, xcclDataType, comm.onecclComm);
onecclReduceScatter(
input.data_ptr(),
output.data_ptr(),
Expand All @@ -149,7 +154,8 @@ void onecclReduceScatter(
&stream.queue());
} else {
auto xcclDataType = getXcclDataTypeV1(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOpV1(reduceOp, input);
auto xcclReduceOp =
getXcclReduceOpV1(reduceOp, input, xcclDataType, *comm.cclComm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down
188 changes: 178 additions & 10 deletions src/xccl/xccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
#define XCCL_HAS_AVG 1
#endif // oneCCL version >= 2021.15

// Pre-mul-sum reductions require oneCCL >= 2022.0. Parenthesize the
// `||`/`&&` combination explicitly: relying on precedence here triggers
// -Wparentheses and is easy to misread, even though the evaluation is the
// same.
#if defined(CCL_MAJOR_VERSION) &&                  \
    ((CCL_MAJOR_VERSION > 2022) ||                 \
     ((CCL_MAJOR_VERSION == 2022) && (CCL_MINOR_VERSION >= 0)))
#define ENABLE_XCCL_PREMUL_SUM_SUPPORT
#endif // oneCCL version >= 2022.0

inline std::string reduceOpToString(c10d::ReduceOp op) {
switch (op) {
case c10d::ReduceOp::SUM:
Expand Down Expand Up @@ -174,6 +180,117 @@ inline const std::map<at::ScalarType, ccl::datatype> xcclDatatypesV1 = {

namespace {

// Trait bundle for the oneCCL v1 (ccl::) backend: names the reduction-op and
// communicator types and, when pre-mul-sum support is compiled in, provides
// the matching op destroyer used by xcclRedOpRAII.
struct XCCLTraitsV1 {
  using OpType = ccl::reduction;
  // Non-owning pointer; the communicator outlives any RAII op wrapper.
  using CommType = const ccl::communicator*;

#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
  // Destroys a dynamically created reduction op (e.g. pre-mul-sum) on `comm`.
  static void destroyOp(OpType op, CommType comm) {
    ccl::reduction_destroy(op, *comm);
  }
#endif
};

// Trait bundle for the oneCCL v2 (onecclX) backend: names the reduction-op
// and communicator types and, when pre-mul-sum support is compiled in,
// provides the matching op destroyer used by xcclRedOpRAII.
struct XCCLTraitsV2 {
  using OpType = onecclRedOp_t;
  // onecclComm_t is already a handle type; stored by value, non-owning.
  using CommType = onecclComm_t;

#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
  // Destroys a dynamically created reduction op (e.g. pre-mul-sum) on `comm`.
  static void destroyOp(OpType op, CommType comm) {
    onecclRedOpDestroy(op, comm);
  }
#endif
};

// RAII wrapper around a backend reduction op. For built-in ops
// (sum/max/...) it is a trivial holder; for pre-mul-sum ops created at
// runtime it also records the owning communicator and destroys the op in
// its destructor via Traits::destroyOp.
//
// Move-only: copying would let two wrappers destroy the same op. Both the
// move constructor and move assignment are defined (Rule of Five); the
// original version deleted copy assignment but omitted move assignment,
// which made this move-only type surprisingly not move-assignable.
template <typename Traits>
struct xcclRedOpRAII {
  using OpType = typename Traits::OpType;
  using CommType = typename Traits::CommType;

  xcclRedOpRAII() = default;
  // Wraps a built-in reduction op; nothing to destroy.
  xcclRedOpRAII(OpType op) : op_(op) {}
  // Wraps a dynamically created (pre-mul-sum) op owned via `comm`.
  xcclRedOpRAII(OpType op, CommType comm)
      : op_(op), comm_(comm), premul_sum_(true) {}

  xcclRedOpRAII(const xcclRedOpRAII&) = delete;
  xcclRedOpRAII& operator=(const xcclRedOpRAII&) = delete;

  // Move transfers ownership; the moved-from wrapper is left empty.
  xcclRedOpRAII(xcclRedOpRAII&& tmp) noexcept : xcclRedOpRAII() {
    std::swap(tmp.op_, this->op_);
    std::swap(tmp.comm_, this->comm_);
    std::swap(tmp.premul_sum_, this->premul_sum_);
  }

  // Swap-based move assignment: our previous op (if any) migrates into
  // `tmp` and is destroyed when `tmp` goes out of scope. Safe on self-move.
  xcclRedOpRAII& operator=(xcclRedOpRAII&& tmp) noexcept {
    std::swap(tmp.op_, this->op_);
    std::swap(tmp.comm_, this->comm_);
    std::swap(tmp.premul_sum_, this->premul_sum_);
    return *this;
  }

#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
  ~xcclRedOpRAII() {
    // Only runtime-created ops need explicit destruction.
    if (premul_sum_ && comm_) {
      Traits::destroyOp(op_, comm_);
    }
  }
#endif

  // Implicit conversion so the wrapper can be passed straight to collective
  // calls expecting the raw op type.
  operator OpType() const {
    return op_;
  }

  OpType op_{};
  CommType comm_{};
  bool premul_sum_ = false;
};

// Convenience aliases for the oneCCL v1 (ccl::) and v2 (onecclX) backends.
using xcclRedOpRAIIV1 = xcclRedOpRAII<XCCLTraitsV1>;
using xcclRedOpRAIIV2 = xcclRedOpRAII<XCCLTraitsV2>;

#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
// Builds a oneCCL (v1 API) pre-multiply-sum reduction op from the factor
// carried in reduceOp.supplement_ and returns it wrapped in RAII so
// ccl::reduction_destroy runs when the wrapper is destroyed.
template <typename T, ccl::datatype dataType>
inline xcclRedOpRAIIV1 unpackPreMulSumV1(
    const ReduceOp& reduceOp,
    const ccl::communicator& comm) {
  const auto* supplement =
      reinterpret_cast<PreMulSumSupplement*>(reduceOp.supplement_.get());
  const bool factor_is_tensor = supplement->tensor_factor.defined();
  // Host-side copy of the scalar factor, used when no tensor factor is set.
  T host_factor = T(supplement->double_factor);
  T* factor_ptr = nullptr;
  ccl::scalar_residence_type residence;
  if (factor_is_tensor) {
    // Factor lives in device memory owned by the supplied tensor.
    factor_ptr =
        const_cast<T*>(supplement->tensor_factor.const_data_ptr<T>());
    residence = ccl::scalar_residence_type::scalar_device;
  } else {
    // "Immediate" host residence — presumably consumed at create time;
    // the local's lifetime ends when this function returns.
    factor_ptr = &host_factor;
    residence = ccl::scalar_residence_type::scalar_host_immediate;
  }
  ccl::reduction preMulSum{};
  ccl::reduction_create_pre_mul_sum(
      &preMulSum, factor_ptr, dataType, residence, comm);
  return xcclRedOpRAIIV1(preMulSum, &comm);
}

// Builds a oneCCL (v2 API) pre-multiply-sum reduction op from the factor
// carried in reduceOp.supplement_ and returns it wrapped in RAII so
// onecclRedOpDestroy runs when the wrapper is destroyed.
template <typename T, onecclDataType_t dataType>
inline xcclRedOpRAIIV2 unpackPreMulSumV2(
    const ReduceOp& reduceOp,
    onecclComm_t comm) {
  const auto* supplement =
      reinterpret_cast<PreMulSumSupplement*>(reduceOp.supplement_.get());
  const bool factor_is_tensor = supplement->tensor_factor.defined();
  // Host-side copy of the scalar factor, used when no tensor factor is set.
  T host_factor = T(supplement->double_factor);
  T* factor_ptr = nullptr;
  onecclScalarResidence_t residence;
  if (factor_is_tensor) {
    // Factor lives in device memory owned by the supplied tensor.
    factor_ptr =
        const_cast<T*>(supplement->tensor_factor.const_data_ptr<T>());
    residence = onecclScalarResidence_t::onecclScalarDevice;
  } else {
    // "Immediate" host residence — presumably consumed at create time;
    // the local's lifetime ends when this function returns.
    factor_ptr = &host_factor;
    residence = onecclScalarResidence_t::onecclScalarHostImmediate;
  }
  onecclRedOp_t preMulSum{};
  onecclRedOpCreatePreMulSum(
      &preMulSum, factor_ptr, dataType, residence, comm);
  return xcclRedOpRAIIV2(preMulSum, comm);
}
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT

inline ccl::datatype getXcclDataTypeV1(
at::ScalarType type,
bool is_reduction_op = false) {
Expand All @@ -191,7 +308,6 @@ inline ccl::datatype getXcclDataTypeV1(
return it->second;
}

// V2 specific function to avoid variant overhead
inline onecclDataType_t getXcclDataTypeV2(
at::ScalarType type,
bool is_reduction_op = false) {
Expand All @@ -209,13 +325,15 @@ inline onecclDataType_t getXcclDataTypeV2(
return it->second;
}

inline ccl::reduction getXcclReduceOpV1(
inline xcclRedOpRAIIV1 getXcclReduceOpV1(
const ReduceOp& reduceOp,
at::Tensor& input) {
at::Tensor& input,
const ccl::datatype& dataType,
const ccl::communicator& comm) {
try {
if (input.scalar_type() == at::kBool) {
if (reduceOp == ReduceOp::SUM) {
return ccl::reduction::max;
return xcclRedOpRAIIV1(ccl::reduction::max);
}
#ifdef XCCL_HAS_AVG
if (reduceOp == ReduceOp::AVG) {
Expand All @@ -227,31 +345,81 @@ inline ccl::reduction getXcclReduceOpV1(
#if !defined(XCCL_HAS_AVG)
if (reduceOp == ReduceOp::AVG) {
LOG(INFO) << "[Reduce] Use sum emulation for avg";
return ccl::reduction::sum;
return xcclRedOpRAIIV1(ccl::reduction::sum);
}
#endif
return xcclOpsV1.at(reduceOp);
if (reduceOp == ReduceOp::PREMUL_SUM) {
#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
switch (dataType) {
case ccl::datatype::float16:
return unpackPreMulSumV1<at::Half, ccl::datatype::float16>(
reduceOp, comm);
case ccl::datatype::float32:
return unpackPreMulSumV1<float, ccl::datatype::float32>(
reduceOp, comm);
case ccl::datatype::bfloat16:
return unpackPreMulSumV1<float, ccl::datatype::bfloat16>(
reduceOp, comm);
case ccl::datatype::float64:
return unpackPreMulSumV1<double, ccl::datatype::float64>(
reduceOp, comm);
default:
C10_THROW_ERROR(
TypeError,
"PreMulSum Data type must be half, float, bfloat16 or double");
}
#else
C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL>=2021.17");
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
}
return xcclRedOpRAIIV1(xcclOpsV1.at(reduceOp));
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
ValueError,
"Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL");
}
}

inline onecclRedOp_t getXcclReduceOpV2(
inline xcclRedOpRAIIV2 getXcclReduceOpV2(
const ReduceOp& reduceOp,
at::Tensor& input) {
at::Tensor& input,
const onecclDataType_t& dataType,
onecclComm_t comm) {
try {
if (input.scalar_type() == at::kBool) {
if (reduceOp == ReduceOp::SUM) {
return onecclRedOp_t::onecclMax;
return xcclRedOpRAIIV2(onecclRedOp_t::onecclMax);
}
if (reduceOp == ReduceOp::AVG) {
C10_THROW_ERROR(
TypeError, "Cannot use ReduceOp.AVG with boolean inputs");
}
}
return xcclOpsV2.at(reduceOp);
if (reduceOp == ReduceOp::PREMUL_SUM) {
#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
switch (dataType) {
case onecclDataType_t::onecclFloat16:
return unpackPreMulSumV2<at::Half, onecclDataType_t::onecclFloat16>(
reduceOp, comm);
case onecclDataType_t::onecclFloat32:
return unpackPreMulSumV2<float, onecclDataType_t::onecclFloat32>(
reduceOp, comm);
case onecclDataType_t::onecclBfloat16:
return unpackPreMulSumV2<float, onecclDataType_t::onecclBfloat16>(
reduceOp, comm);
case onecclDataType_t::onecclFloat64:
return unpackPreMulSumV2<double, onecclDataType_t::onecclFloat64>(
reduceOp, comm);
default:
C10_THROW_ERROR(
TypeError,
"PreMulSum Data type must be half, float, bfloat16 or double");
}
#else
C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL>=2021.17");
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
}
return xcclRedOpRAIIV2(xcclOpsV2.at(reduceOp));
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
ValueError,
Expand Down
Loading