diff --git a/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp new file mode 100644 index 00000000000..8d6b219b148 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/fused_apply_rotary_emb.cpp @@ -0,0 +1,144 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct FusedApplyRotaryEmbCaptureState : public AutoGradCaptureState { + bool requires_grad = false; + std::string x_layout{}; + std::string output_layout{}; + std::string mode{}; + int64_t tensor_index{}; + int64_t k_size{}; + float base = 0.0f; + int64_t rotary_size{}; +}; + +class FusedApplyRotaryEmb : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override; + Maybe Capture(FusedApplyRotaryEmbCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const FusedApplyRotaryEmbCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; +}; + +Maybe FusedApplyRotaryEmb::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); +} + +Maybe FusedApplyRotaryEmb::Capture(FusedApplyRotaryEmbCaptureState* ctx, + const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const { + CHECK_OR_RETURN((inputs.size() >= 1) && (inputs.size() <= 4)) + << Error::RuntimeError() << "the inputs size of fusedapplyrotaryembgrad\ + should between 1 and 4"; + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->SaveTensorForBackward(inputs.at(0)); + if (inputs.size() == 2) // position_ids + ctx->SaveTensorForBackward(inputs.at(1)); + + if (inputs.size() == 3) { // cos, sin + ctx->SaveTensorForBackward(inputs.at(1)); + ctx->SaveTensorForBackward(inputs.at(2)); + } + + if (inputs.size() == 4) { // cos, sin, position_ids; + ctx->SaveTensorForBackward(inputs.at(1)); + ctx->SaveTensorForBackward(inputs.at(2)); + ctx->SaveTensorForBackward(inputs.at(3)); + } + + ctx->x_layout = JUST(composed_attrs.GetAttr("x_layout")); + ctx->output_layout = JUST(composed_attrs.GetAttr("output_layout")); + ctx->mode = JUST(composed_attrs.GetAttr("mode")); + ctx->tensor_index = JUST(composed_attrs.GetAttr("tensor_index")); + ctx->k_size = JUST(composed_attrs.GetAttr("k_size")); + ctx->base = JUST(composed_attrs.GetAttr("base")); + ctx->rotary_size = JUST(composed_attrs.GetAttr("rotary_size")); + + return Maybe::Ok(); +} + +Maybe FusedApplyRotaryEmb::Apply(const FusedApplyRotaryEmbCaptureState* ctx, + const TensorTuple& out_grads, TensorTuple* in_grads) const { + CHECK_EQ_OR_RETURN(out_grads.size(), 1) + << Error::RuntimeError() << "fusedapplyrotaryembgrad outgrad size should be 1"; + const auto& saved_tensors = ctx->SavedTensors(); + + CHECK_OR_RETURN((saved_tensors.size() >= 1) && (saved_tensors.size() <= 4)) + << Error::RuntimeError() << "the saved_tensors of fusedapplyrotaryembgrad\ + should between 1 and 4"; + + if (ctx->requires_grad) { + if (saved_tensors.size() == 1) { // x + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad( + x, out_grads.at(0), NullOpt /*cos*/, NullOpt /*sin*/, NullOpt /*position_ids*/, + ctx->x_layout, ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, + ctx->rotary_size)); + } + + if (saved_tensors.size() == 2) { // x, position_ids + const auto& x = ctx->SavedTensors().at(0); + const auto& position_ids = ctx->SavedTensors().at(1); + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad( + x, out_grads.at(0), NullOpt /*cos*/, NullOpt /*sin*/, position_ids, ctx->x_layout, + ctx->output_layout, ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, + ctx->rotary_size)); + } + + if (saved_tensors.size() == 3) { // x, cos, sin, position_ids + const auto& x = ctx->SavedTensors().at(0); + const auto& cos = ctx->SavedTensors().at(1); + const auto& sin = ctx->SavedTensors().at(2); + + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad( + x, out_grads.at(0), cos, sin, NullOpt /*position_ids*/, ctx->x_layout, ctx->output_layout, + ctx->mode, ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size)); + } + + if (saved_tensors.size() == 4) { + const auto& x = ctx->SavedTensors().at(0); + const auto& cos = ctx->SavedTensors().at(1); + const auto& sin = ctx->SavedTensors().at(2); + const auto& position_ids = ctx->SavedTensors().at(3); + in_grads->at(0) = JUST(functional::FusedApplyRotaryEmbGrad( + x, out_grads.at(0), cos, sin, position_ids, ctx->x_layout, ctx->output_layout, ctx->mode, + ctx->tensor_index, ctx->k_size, ctx->base, ctx->rotary_size)); + } + } + + return Maybe::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("fused_apply_rotary_emb", FusedApplyRotaryEmb); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 8b05bf73a44..391eb3e8522 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1279,6 +1279,10 @@ signature: 'Tensor (Tensor x, *, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmb' bind_python: True +- name: "fused_apply_rotary_emb_grad" + signature: 'Tensor (Tensor x, Tensor dy, Tensor cos=None, Tensor sin=None, Tensor position_ids=None, String x_layout="BHMK", String output_layout=None, String mode="plane", Int64 tensor_index=None, Int64 k_size=None, Float base=1e4, Int64 rotary_size=None) => FusedApplyRotaryEmbGrad' + bind_python: False + - name: "fused_relu_dropout_grad" signature: "Tensor (Tensor dy, Tensor mask, Float scale) => FusedReluDropoutGrad" bind_python: False diff --git a/oneflow/core/functional/impl/fused_attention_functor.cpp b/oneflow/core/functional/impl/fused_attention_functor.cpp index 413caf69e5f..4e57253c229 100644 --- a/oneflow/core/functional/impl/fused_attention_functor.cpp +++ b/oneflow/core/functional/impl/fused_attention_functor.cpp @@ -733,6 +733,131 @@ class FusedApplyRotaryEmbFunctor { std::shared_ptr op_without_position_sinuous_; }; +class FusedApplyRotaryEmbGradFunctor { + public: + FusedApplyRotaryEmbGradFunctor() { + op_with_position_sinuous_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") + .Input("x") + .Input("dy") + .Input("cos") + .Input("sin") + .Input("position_ids") + .Output("dx") + .Build()); + op_with_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") + .Input("x") + .Input("dy") + .Input("position_ids") + .Output("dx") + .Build()); + op_without_position_ = CHECK_JUST(one::OpBuilder("fused_apply_rotary_emb_grad") + .Input("x") + .Input("dy") + .Input("cos") + .Input("sin") + .Output("dx") + .Build()); + op_without_position_sinuous_ = CHECK_JUST( + one::OpBuilder("fused_apply_rotary_emb_grad").Input("x").Input("dy").Output("dx").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& dy, const Optional& cos, + const Optional& sin, + const Optional& position_ids, const std::string& x_layout, + const Optional& output_layout, const std::string& mode, + const Optional& tensor_index, const Optional& k_size, + const float base, const Optional& rotary_size) const { + int64_t b = 0, m = 0, h = 0, k = 0; + + if (tensor_index) { + CHECK_OR_RETURN((JUST(tensor_index) >= 0) && (JUST(tensor_index) <= 2)) + << "tensor_index should be set between [0, 2]"; + } + CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) + << "mode should be \"intervel\" or \"plane\""; + + JUST(ParseDims("x", *x->shape(), x_layout, Optional(), k_size, &b, &m, &h, &k)); + + if (k_size) { + CHECK_EQ_OR_RETURN(JUST(k_size), k) + << "k_size if given should be equal to K of cos, sin and x."; + } + if (rotary_size) { + CHECK_LE_OR_RETURN(JUST(rotary_size), k) << "rotary_size should be no more than k."; + } + + int64_t rotary_emd_dim = 1; + + if (position_ids) { + CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->NumAxes(), 3) + << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; + CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(0), b) + << "1st dim of position_ids should be equal to B."; + CHECK_EQ_OR_RETURN(JUST(position_ids)->shape()->At(2), m) + << "3rd dim of position_ids should be equal to M."; + rotary_emd_dim = JUST(position_ids)->shape()->At(1); + CHECK_OR_RETURN(rotary_emd_dim == 1 || rotary_emd_dim == 2) + << "2nd dim of position_ids should be 1 or 2."; + } + + const int64_t actual_rotary_size = rotary_size.value_or(k) / rotary_emd_dim; + CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) + << "k ,or rotary_size if given, should be a multiple of 2 * rotary_encoding_dim."; + + if (cos && sin) { + CHECK_EQ_OR_RETURN(JUST(cos)->shape()->NumAxes(), 2) + << "The number of dimensions of cos should be equal to 2."; + CHECK_OR_RETURN(JUST(cos)->shape() == JUST(sin)->shape()) + << "Each dimension of cos & sin should be the same."; + CHECK_EQ_OR_RETURN(JUST(cos)->shape()->At(1), actual_rotary_size) + << "The 1st dimension of cos & sin should equal to rotary_size // " + "rotary_embedding_dimension."; + } else if (!cos && !sin) { + // do nothing + } else { + UNIMPLEMENTED_THEN_RETURN() << "cos & sin should both be given or not given."; + } + + if (!position_ids) { + if (cos && sin) { + CHECK_GE_OR_RETURN(JUST(cos)->shape()->At(0), m) + << "M of cos & sin should be to no less than " + "M of x when position_ids is not " + "given."; // K of cos & sin is checked + // inside ParseDims + } + } + + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("x_layout", "output_layout", "mode", + "tensor_index", "k_size", "base", "rotary_size"); + attrs.SetAllAttrs(x_layout, output_layout.value_or(x_layout), mode, tensor_index.value_or(0), + k_size.value_or(k), base, rotary_size.value_or(k)); + + if (position_ids) { + if (cos && sin) { + return OpInterpUtil::Dispatch( + *op_with_position_sinuous_, {x, dy, JUST(cos), JUST(sin), JUST(position_ids)}, attrs); + } else { + return OpInterpUtil::Dispatch(*op_with_position_, {x, dy, JUST(position_ids)}, + attrs); + } + } else { + if (cos && sin) { + return OpInterpUtil::Dispatch(*op_without_position_, {x, dy, JUST(cos), JUST(sin)}, + attrs); + } else { + return OpInterpUtil::Dispatch(*op_without_position_sinuous_, {x, dy}, attrs); + } + } + } + + private: + std::shared_ptr op_with_position_; + std::shared_ptr op_with_position_sinuous_; + std::shared_ptr op_without_position_; + std::shared_ptr op_without_position_sinuous_; +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -741,6 +866,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { "FusedMultiHeadAttentionInferenceV2"); m.add_functor("FusedAttentionConcatPastKeyValue"); m.add_functor("FusedApplyRotaryEmb"); + m.add_functor("FusedApplyRotaryEmbGrad"); } } // namespace functional diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index b50a1ceceab..f2a91b4b032 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -3979,6 +3979,32 @@ def OneFlow_FusedApplyRotaryEmbOp : OneFlow_BaseOp<"fused_apply_rotary_emb", [No let has_data_type_infer_fn = 1; } +def OneFlow_FusedApplyRotaryEmbGradOp : OneFlow_BaseOp<"fused_apply_rotary_emb_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy, + Optional:$cos, + Optional:$sin, + Optional:$position_ids + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$x_layout, + DefaultValuedAttr:$output_layout, + DefaultValuedAttr:$mode, + DefaultValuedAttr:$tensor_index, + DefaultValuedAttr:$base, + DefaultValuedAttr:$k_size, + DefaultValuedAttr:$rotary_size + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_EmbeddingGradOp : OneFlow_BaseOp<"embedding_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index d432a1f30a8..4e4a4f396c6 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -24,6 +24,7 @@ limitations under the License. #include "cutlass/gemm/warp/mma.h" #include "kernel_forward.h" #include "oneflow/core/kernel/cuda_graph_support.h" +#include "oneflow/core/kernel/kernel_util.h" #include "trt_flash_attention/fmha.h" #include "trt_flash_attention/fmha_flash_attention.h" @@ -1017,6 +1018,7 @@ struct FusedApplyRotaryEmbParam { IndexType num_elements; const IndexType k; const IndexType x_offset; + const IndexType out_offset; IndexType ref_stride[num_dims]; // b, m, h, k IndexType out_stride[num_dims]; // ordered descendingly by stride @@ -1032,7 +1034,7 @@ struct FusedApplyRotaryEmbParam { const IndexType actual_rotary_size, const IndexType rotary_size, const IndexType rotate_stride, const IndexType num_elements, const IndexType k, const IndexType k0, const IndexType k1, - const IndexType x_offset) + const IndexType x_offset, const IndexType out_offset) : x(x), cos(cos), sin(sin), @@ -1047,7 +1049,8 @@ struct FusedApplyRotaryEmbParam { k(k), k0(k0), k1(k1), - x_offset(x_offset) {} + x_offset(x_offset), + out_offset(out_offset) {} }; template +__global__ void IntervalGradKernel( + FusedApplyRotaryEmbParam param) { + for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x; + packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) { + using LoadPack = cuda::elementwise::Packed; + IndexType offset = packed_offset * PackSize; + IndexType index[num_dims]; // b, m, h, k + + IndexType temp_offset = offset; + + for (int i = 0; i < num_dims - 1; i++) { + IndexType ref_stride = param.ref_stride[i]; + IndexType idx = temp_offset / ref_stride; + index[i] = idx; + temp_offset = temp_offset - idx * ref_stride; + } + index[num_dims - 1] = temp_offset; + + IndexType x_offset = param.x_offset; + IndexType out_offset = param.out_offset; +#pragma unroll + for (int i = 0; i < num_dims; i++) { + x_offset = x_offset + param.x_stride[i] * index[i]; + out_offset = out_offset + param.out_stride[i] * index[i]; + } + const LoadPack x_vec = *reinterpret_cast(param.x + x_offset); + + const IndexType k_index = index[num_dims - 1]; + if (k_index < param.rotary_size) { + const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; + const IndexType b_index = index[0], m_index = index[1]; + const IndexType position_id_offset = b_index * param.position_b_stride + + position_rotate_index * param.position_rotate_stride + + m_index; + + const PositionType position = + param.position_ids ? param.position_ids[position_id_offset] : m_index; + const IndexType actual_k_index = k_index % param.actual_rotary_size; + const IndexType sinuous_offset = position * param.sinuous_m_stride + actual_k_index; + + LoadPack cos_vec, sin_vec, out_vec; + + if (param.cos && param.sin) { + cos_vec = *reinterpret_cast(param.cos + sinuous_offset); + sin_vec = *reinterpret_cast(param.sin + sinuous_offset); + } else { + const IndexType actual_ndim = param.rotary_size / rotary_emb_dim; +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + T val = position + * expf(2.0f * static_cast(((actual_k_index >> 1) + i)) + * param.inv_actual_rotary_size * logf(param.theta)); + T cos_val = cosf(val); + T sin_val = sinf(val); + cos_vec.elem[i * 2] = cos_val; + cos_vec.elem[i * 2 + 1] = cos_val; + sin_vec.elem[i * 2] = sin_val; + sin_vec.elem[i * 2 + 1] = sin_val; + } + } + +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + out_vec.elem[i * 2] = x_vec.elem[i * 2] * cos_vec.elem[i * 2] + + x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2 + 1]; + out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1] + - x_vec.elem[i * 2] * sin_vec.elem[i * 2]; + } + + *(reinterpret_cast(param.out + out_offset)) = out_vec; + } else { + *(reinterpret_cast(param.out + out_offset)) = x_vec; + } + } +} + template __global__ void PlaneKernel( @@ -1172,7 +1253,7 @@ __global__ void PlaneKernel( LoadPack x_vec; IndexType x_offset = param.x_offset; - IndexType out_offset = 0; + IndexType out_offset = param.out_offset; #pragma unroll for (int i = 0; i < num_dims; i++) { x_offset = x_offset + param.x_stride[i] * index[i]; @@ -1199,6 +1280,100 @@ __global__ void PlaneKernel( } } +template +__global__ void PlaneGradKernel( + FusedApplyRotaryEmbParam param) { + for (IndexType offset = threadIdx.x + blockIdx.x * blockDim.x; offset < param.num_elements; + offset += blockDim.x * gridDim.x) { + using LoadPack = cuda::elementwise::Packed; + IndexType temp_offset = offset; + IndexType index[num_dims]; +#pragma unroll + for (int i = 0; i < num_dims - 1; i++) { + IndexType ref_stride = param.ref_stride[i]; + IndexType idx = temp_offset / ref_stride; + index[i] = idx; + temp_offset = temp_offset - idx * ref_stride; + } + index[num_dims - 1] = temp_offset; + + const IndexType b_index = index[0], m_index = index[1], k_index = index[num_dims - 1]; + const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; + const IndexType position_id_offset = b_index * param.position_b_stride + + position_rotate_index * param.position_rotate_stride + + m_index; + + const PositionType position = + param.position_ids ? param.position_ids[position_id_offset] : m_index; + const IndexType actual_k_index = k_index % param.actual_rotary_size; + const IndexType sinuous_offset = position * param.k + actual_k_index; + + T cos_val, sin_val, out_val; + + if (param.cos && param.sin) { + cos_val = *(param.cos + sinuous_offset); + IndexType offset_; // 针对grad, sin_val需要有 size / 2的偏移; + if (k_index < param.k0) { + offset_ = + (param.k0 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } else if (k_index < param.k1) { + offset_ = + (param.k1 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } + + sin_val = *(param.sin + sinuous_offset + offset_); + } else { + T val_cos = position + * expf(2.0f * static_cast(k_index % (param.actual_rotary_size >> 1)) + * param.inv_actual_rotary_size * logf(param.theta)); + + IndexType offset_; // 针对grad, sin_val需要有 size / 2的偏移; + if (k_index < param.k0) { + offset_ = + (param.k0 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } else if (k_index < param.k1) { + offset_ = + (param.k1 - k_index > param.rotate_stride) ? param.rotate_stride : -param.rotate_stride; + } + T val_sin = + position + * expf(2.0f * static_cast((k_index + offset_) % (param.actual_rotary_size >> 1)) + * param.inv_actual_rotary_size * logf(param.theta)); + + cos_val = cosf(val_cos); + sin_val = sinf(val_sin); + } + + LoadPack x_vec; + IndexType x_offset = param.x_offset; + IndexType out_offset = param.out_offset; +#pragma unroll + for (int i = 0; i < num_dims; i++) { + x_offset = x_offset + param.x_stride[i] * index[i]; + out_offset = out_offset + param.out_stride[i] * index[i]; + } + + if (k_index < param.k0) { + x_vec.elem[0] = *(param.x + x_offset); + x_vec.elem[1] = (param.k0 - k_index > param.rotate_stride) + ? *(param.x + x_offset + param.rotate_stride) + : static_cast(-*(param.x + x_offset - param.rotate_stride)); + out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; + } else if (k_index < param.k1) { + x_vec.elem[0] = *(param.x + x_offset); + x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride) + ? *(param.x + x_offset + param.rotate_stride) + : static_cast(-*(param.x + x_offset - param.rotate_stride)); + out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; + } else { + out_val = *(param.x + x_offset); + } + + *(param.out + out_offset) = out_val; + } +} + template void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin, @@ -1209,7 +1384,8 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin const IndexType x_b_stride, const IndexType x_m_stride, const IndexType x_h_stride, const IndexType x_offset, const IndexType out_b_stride, const IndexType out_m_stride, - const IndexType out_h_stride, IndexType num_elements) { + const IndexType out_h_stride, const IndexType out_offset, IndexType num_elements, + const bool is_forward) { const IndexType k0 = rotary_size / rotary_emb_dim, k1 = rotary_size; // TODO: this only support 1d, 2d, rotary postional encoding @@ -1220,7 +1396,7 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin struct FusedApplyRotaryEmbParam param( x, cos, sin, position_ids, out, theta, inv_actual_rotary_size, actual_rotary_size, - rotary_size, rotate_stride, num_elements, k, k0, k1, x_offset); + rotary_size, rotate_stride, num_elements, k, k0, k1, x_offset, out_offset); const IndexType ref_strides[num_dims] = {m * h * k, h * k, k, 1}; const IndexType out_strides[num_dims] = {out_b_stride, out_m_stride, out_h_stride, 1}; @@ -1243,15 +1419,28 @@ void LaunchKernel(ep::CudaStream* stream, const T* x, const T* cos, const T* sin constexpr size_t blk_size = 128; - if (mode == "plane") { - param.num_elements = param.num_elements * PackSize; - PlaneKernel - <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( - param); + if (is_forward) { + if (mode == "plane") { + param.num_elements = param.num_elements * PackSize; + PlaneKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); + } else { + IntervalKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); + } } else { - IntervalKernel - <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( - param); + if (mode == "plane") { + param.num_elements = param.num_elements * PackSize; + PlaneGradKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); + } else { + IntervalGradKernel + <<<(param.num_elements + blk_size - 1) / blk_size, blk_size, 0, stream->cuda_stream()>>>( + param); + } } } @@ -1265,7 +1454,8 @@ void DispatchPackSize(ep::CudaStream* stream, const T* x, const T* cos, const T* const IndexType x_b_stride, const IndexType x_m_stride, const IndexType x_h_stride, const IndexType x_offset, const IndexType out_b_stride, const IndexType out_m_stride, - const IndexType out_h_stride, IndexType num_elements) { + const IndexType out_h_stride, const IndexType out_offset, + IndexType num_elements, const bool is_forward) { const auto CheckPackSize = [&](const size_t PackSize) { bool r = (((reinterpret_cast(x) % (sizeof(T) * PackSize)) == 0) && (((rotary_size / rotary_emb_dim) % PackSize) == 0) @@ -1278,19 +1468,19 @@ void DispatchPackSize(ep::CudaStream* stream, const T* x, const T* cos, const T* LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, out_offset, num_elements, is_forward); } else if (CheckPackSize(4)) { num_elements /= 4; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, out_offset, num_elements, is_forward); } else { num_elements /= 2; LaunchKernel( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, out_offset, num_elements, is_forward); } } @@ -1302,8 +1492,9 @@ void DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* si const int64_t b, const int64_t m, const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, const int64_t out_m_stride, - const int64_t out_h_stride) { + const int64_t out_h_stride, const int64_t out_offset, const bool is_forward) { int64_t num_elements = b * m * h * k; + if (num_elements < (1 << 30)) { DispatchPackSize( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, @@ -1312,12 +1503,12 @@ void DispatchIndex(ep::CudaStream* stream, const T* x, const T* cos, const T* si static_cast(x_m_stride), static_cast(x_h_stride), static_cast(x_offset), static_cast(out_b_stride), static_cast(out_m_stride), static_cast(out_h_stride), - static_cast(num_elements)); + static_cast(out_offset), static_cast(num_elements), is_forward); } else { DispatchPackSize( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride, num_elements); + out_m_stride, out_h_stride, out_offset, num_elements, is_forward); } } @@ -1331,17 +1522,18 @@ void DispatchRotaryEmbeddingDimension(ep::CudaStream* stream, const T* x, const const int64_t h, const int64_t k, const int64_t x_b_stride, const int64_t x_m_stride, const int64_t x_h_stride, const int64_t x_offset, const int64_t out_b_stride, - const int64_t out_m_stride, const int64_t out_h_stride) { + const int64_t out_m_stride, const int64_t out_h_stride, + const int64_t out_offset, bool is_forward) { if (rotary_emb_dim == 1) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride); + out_m_stride, out_h_stride, out_offset, is_forward); } else if (rotary_emb_dim == 2) { DispatchIndex( stream, x, cos, sin, position_ids, out, position_shape, x_layout, output_layout, mode, theta, rotary_size, b, m, h, k, x_b_stride, x_m_stride, x_h_stride, x_offset, out_b_stride, - out_m_stride, out_h_stride); + out_m_stride, out_h_stride, out_offset, is_forward); } } @@ -1384,12 +1576,11 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { int64_t k = 0; int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0; int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0; - ParseDims(out->shape_view(), output_layout, Optional(), k_size, 0, &b, &m, &h, &k, &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); ParseDims(x->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); - + bool is_forward = true; // TODO: hard code num_dims & seems redundant template problem... DispatchRotaryEmbeddingDimension( ctx->stream()->As(), reinterpret_cast(x->dptr()), @@ -1399,7 +1590,71 @@ class FusedApplyRotaryEmbKernel final : public user_op::OpKernel { reinterpret_cast(out->mut_dptr()), position_ids ? position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode, static_cast(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride, - x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride); + x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, out_offset, is_forward); + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template +class FusedApplyRotaryEmbGradKernel final : public user_op::OpKernel { + public: + FusedApplyRotaryEmbGradKernel() = default; + ~FusedApplyRotaryEmbGradKernel() override = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* cos = nullptr; + user_op::Tensor* sin = nullptr; + user_op::Tensor* position_ids = nullptr; + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + const std::string& x_layout = ctx->Attr("x_layout"); + const std::string& output_layout = ctx->Attr("output_layout"); + const std::string& mode = ctx->Attr("mode"); + const int64_t tensor_index = ctx->Attr("tensor_index"); + const int64_t k_size = ctx->Attr("k_size"); + const int64_t rotary_size = ctx->Attr("rotary_size"); + const float theta = 1.0f / ctx->Attr("base"); + int rotary_emb_dim = 1; + + size_t dx_byte_size = dx->shape_view().elem_cnt() * sizeof(T); + Memset(ctx->stream(), dx->mut_dptr(), 0, dx_byte_size); + + if (ctx->has_input("cos", 0)) { cos = ctx->Tensor4ArgNameAndIndex("cos", 0); } + + if (ctx->has_input("sin", 0)) { sin = ctx->Tensor4ArgNameAndIndex("sin", 0); } + + if (ctx->has_input("position_ids", 0)) { + position_ids = ctx->Tensor4ArgNameAndIndex("position_ids", 0); + rotary_emb_dim = position_ids->shape_view().At(1); + } + + constexpr size_t ndims = 4; + int64_t b = 0; + int64_t m = 0; + int64_t h = 0; + int64_t k = 0; + int64_t out_b_stride = 0, out_m_stride = 0, out_h_stride = 0, out_offset = 0; + int64_t x_b_stride = 0, x_m_stride = 0, x_h_stride = 0, x_offset = 0; + + ParseDims(dx->shape_view(), x_layout, Optional(), k_size, tensor_index, &b, &m, &h, &k, + &out_b_stride, &out_m_stride, &out_h_stride, &out_offset); + ParseDims(dy->shape_view(), output_layout, Optional(), k_size, 0, &b, &m, &h, &k, + &x_b_stride, &x_m_stride, &x_h_stride, &x_offset); + bool is_forward = false; + // TODO: hard code num_dims & seems redundant template problem... + DispatchRotaryEmbeddingDimension( + ctx->stream()->As(), reinterpret_cast(dy->dptr()), + cos ? reinterpret_cast(cos->dptr()) : nullptr, + sin ? reinterpret_cast(sin->dptr()) : nullptr, + position_ids ? reinterpret_cast(position_ids->dptr()) : nullptr, + reinterpret_cast(dx->mut_dptr()), + position_ids ? position_ids->shape_view().data() : nullptr, x_layout, output_layout, mode, + static_cast(theta), rotary_size, rotary_emb_dim, b, m, h, k, x_b_stride, x_m_stride, + x_h_stride, x_offset, out_b_stride, out_m_stride, out_h_stride, out_offset, is_forward); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -1429,6 +1684,30 @@ REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(half); REGISTER_FUSED_APPLY_ROTARY_EMB_GPU_DTYPE(nv_bfloat16); #endif // CUDA_VERSION >= 11000 +#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, position_type) \ + REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("dx", 0) == GetDataType::value) \ + && (user_op::HobInputSize("position_ids") == 1) \ + && (user_op::HobDataType("position_ids", 0) == GetDataType::value)); + +#define REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(dtype) \ + REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int64_t); \ + REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU(dtype, int32_t); \ + REGISTER_USER_KERNEL("fused_apply_rotary_emb_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("dx", 0) == GetDataType::value) \ + && (user_op::HobInputSize("position_ids") == 0)); + +REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(float); +REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(half); +#if CUDA_VERSION >= 11000 +REGISTER_FUSED_APPLY_ROTARY_EMB_GRAD_GPU_DTYPE(nv_bfloat16); +#endif // CUDA_VERSION >= 11000 + } // namespace } // namespace user_op diff --git a/oneflow/user/ops/fused_attention_ops.cpp b/oneflow/user/ops/fused_attention_ops.cpp index 123c09e16fc..d0b0687689b 100644 --- a/oneflow/user/ops/fused_attention_ops.cpp +++ b/oneflow/user/ops/fused_attention_ops.cpp @@ -806,4 +806,189 @@ Maybe ParseSplitAxis(const std::string& layout, bool can_hk_split, int64_t return Maybe::Ok(); } +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const std::string& x_layout = ctx->Attr("x_layout"); + const std::string& output_layout = ctx->Attr("output_layout"); + const std::string& mode = ctx->Attr("mode"); + const int64_t rotary_size = ctx->Attr("rotary_size"); + const int64_t k_size = ctx->Attr("k_size"); + const int64_t tensor_index = ctx->Attr("tensor_index"); + + auto Inputs = ctx->inputs(); + + CHECK_OR_RETURN((tensor_index >= 0) && (tensor_index <= 2)) + << "tensor_index should be in range [0, 2]."; + CHECK_OR_RETURN((mode == "interval") || (mode == "plane")) + << "mode should be either \"interval\" or \"plane\"."; + + CHECK_OR_RETURN(output_layout != "BM(H2K)" && output_layout != "BM(H3K)" + && output_layout != "MB(H2K)" && output_layout != "MB(H3K)") + << "output_layout should not be \"BM(H2k)\", \"BM(H3K)\", \"MB(H2K)\", \"MB(H3K)\"."; + + int64_t b = 0, m = 0, h = 0, k = 0; + + JUST(ParseDims(x_desc.shape(), x_layout, Optional(), Optional(k_size), &b, &m, + &h, &k)); + + CHECK_LE_OR_RETURN(rotary_size, k) << "rotary_size should be no more than K of input x."; + + int64_t rotary_emb_dim = 1; + + if (ctx->has_input("position_ids", 0)) { + const Shape& position_id_shape = ctx->InputShape("position_ids", 0); + CHECK_EQ_OR_RETURN(position_id_shape.NumAxes(), 3) + << "ndims of position_ids should be equal to 3, either in form of B1M or B2M."; + CHECK_EQ_OR_RETURN(position_id_shape.At(0), b) + << "1st dim of position_ids should be equal to B."; + CHECK_EQ_OR_RETURN(position_id_shape.At(2), m) + << "3rd dim of position_ids should be equal to M."; + rotary_emb_dim = position_id_shape.At(1); + CHECK_OR_RETURN(rotary_emb_dim == 1 || rotary_emb_dim == 2) + << "2nd dim of position_ids should be 1 or 2."; + } + + const int64_t actual_rotary_size = rotary_size / rotary_emb_dim; + CHECK_EQ_OR_RETURN(actual_rotary_size % 2, 0) + << "rotary_size should be a multiple of 2 * rotary_encoding_dim."; + + bool has_cos = ctx->has_input("cos", 0); + bool has_sin = ctx->has_input("sin", 0); + // TODO: fused_apply_rotary_emb_grad have same logic no matter name + if (has_cos && has_sin) { + const Shape& cos_shape = ctx->InputShape("cos", 0); + const Shape& sin_shape = ctx->InputShape("sin", 0); + CHECK_EQ_OR_RETURN(cos_shape.NumAxes(), 2) + << "The number of dimensions of cos should be equal to 2."; + + CHECK_OR_RETURN(cos_shape == sin_shape) << "The dimensions of cos & sin should be the same."; + CHECK_EQ_OR_RETURN(cos_shape.At(1), actual_rotary_size) + << "The 1st dimension of cos & sin should equal to rotary_size // " + "rotary_embedding_dimension."; + } else if (!has_cos && !has_sin) { + // Do nothing + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + + if (!ctx->has_input("position_ids", 0)) { + if (has_cos && has_sin) { + const Shape& cos_shape = ctx->InputShape("cos", 0); + CHECK_GE_OR_RETURN(cos_shape.At(0), m) + << "M of cos should be no less than M of x if position_ids is not given."; + // K of cos & sin is checked inside ParseDims + } + } + + Shape out_shape = *JUST(LayoutToShape(b, m, h, k, x_layout)); + ctx->SetOutputShape("dx", 0, out_shape); + return Maybe::Ok(); +} + +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe FusedApplyRotaryEmbGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + int num_heads = -1; + const int64_t k_size = ctx->Attr("k_size"); + const std::string& x_layout = ctx->Attr("x_layout"); + const std::string& output_layout = ctx->Attr("output_layout"); + + if (x_desc.shape().NumAxes() == 2) { + if (x_layout == "(BM)(HK)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % k_size, 0) + << "shape can not be divided by head dimension size."; + num_heads = x_desc.shape().At(1) / k_size; + } else if (x_layout == "(BM)(H3K)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(1) % (k_size * 3), 0) + << "shape can not be divided by head dimension size."; + num_heads = x_desc.shape().At(1) / (k_size * 3); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } else if (x_desc.shape().NumAxes() == 3) { + if (x_layout == "BM(HK)" || x_layout == "MB(HK)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % k_size, 0) + << "shape can not be divided by head dimension size."; + num_heads = x_desc.shape().At(2) / k_size; + } else if (x_layout == "BM(H3K)" || x_layout == "MB(H3K)") { + CHECK_EQ_OR_RETURN(x_desc.shape().At(2) % (k_size * 3), 0) + << "shape can not be divided by head dimension size."; + num_heads = x_desc.shape().At(2) / (k_size * 3); + } else if (x_layout == "(BM)HK") { + num_heads = x_desc.shape().At(1); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } else if (x_desc.shape().NumAxes() == 4) { + if (x_layout == "BMHK") { + num_heads = x_desc.shape().At(2); + } else if (x_layout == "BHMK") { + num_heads = x_desc.shape().At(1); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + + const bool can_hk_split = num_heads % ctx->parallel_num() == 0; + int64_t x_b_split_axis = -1; + int64_t x_h_split_axis = -1; + JUST(ParseSplitAxis(x_layout, can_hk_split, &x_b_split_axis, &x_h_split_axis)); + int64_t o_b_split_axis = -1; + int64_t o_h_split_axis = -1; + JUST(ParseSplitAxis(output_layout, can_hk_split, &o_b_split_axis, &o_h_split_axis)); + + if (x_b_split_axis >= 0 && o_b_split_axis >= 0) { + auto builder = ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), o_b_split_axis) + .Split(user_op::OpArg("dx", 0), x_b_split_axis); + if (ctx->user_op_conf().has_input("cos", 0)) + builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); + if (ctx->user_op_conf().has_input("position_ids", 0)) + builder = builder.Split(user_op::OpArg("position_ids", 0), 0); + builder.Build(); + } + if (x_h_split_axis >= 0 && o_h_split_axis >= 0) { + auto builder = ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), o_h_split_axis) + .Split(user_op::OpArg("dx", 0), x_h_split_axis); + if (ctx->user_op_conf().has_input("cos", 0)) + builder = builder.Broadcast(user_op::OpArg("cos", 0)).Broadcast(user_op::OpArg("sin", 0)); + if (ctx->user_op_conf().has_input("position_ids", 0)) + builder = builder.Broadcast(user_op::OpArg("position_ids", 0)); + builder.Build(); + } + + return Maybe::Ok(); +} + +/* static */ Maybe FusedApplyRotaryEmbGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("x", 0); + + bool has_sinuous = ctx->has_input("cos", 0); + + if (has_sinuous) { + const user_op::TensorDesc& cos_desc = ctx->InputTensorDesc("cos", 0); + const user_op::TensorDesc& sin_desc = ctx->InputTensorDesc("sin", 0); + + CHECK_EQ_OR_RETURN(cos_desc.data_type(), first_in_desc.data_type()) + << "InferDataType Failed. Expected " << DataType_Name(first_in_desc.data_type()) + << ", but got " << DataType_Name(cos_desc.data_type()); + CHECK_EQ_OR_RETURN(sin_desc.data_type(), first_in_desc.data_type()) + << "InferDataType Failed. Expected " << DataType_Name(first_in_desc.data_type()) + << ", but got " << DataType_Name(sin_desc.data_type()); + } + + user_op::TensorDesc* out_desc = ctx->MutOutputTensorDesc("dx", 0); + out_desc->set_data_type(first_in_desc.data_type()); + + return Maybe::Ok(); +} + } // namespace oneflow diff --git a/python/oneflow/test/modules/test_fused_rotary_embedding.py b/python/oneflow/test/modules/test_fused_rotary_embedding.py index 68e8838fddb..dc70cb24ec9 100644 --- a/python/oneflow/test/modules/test_fused_rotary_embedding.py +++ b/python/oneflow/test/modules/test_fused_rotary_embedding.py @@ -22,7 +22,123 @@ import numpy as np import math +# tensor version: +def plane_shuffle_tensor(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return flow.cat((-x2, x1), dim=-1) + + +def shuffle_adjacent_two_elem_tensor(x): + y = x.clone() + for i in range(x.shape[-1] // 2): + y[..., 2 * i] = -x[..., 2 * i + 1] + y[..., 2 * i + 1] = x[..., 2 * i] + return y + + +def parseDims_tensor(dims, x_layout): + B, M, H, K = 1, 1, 1, 1 + merged_dims = dims + if x_layout == "BHMK": + B, H, M, K = dims + elif x_layout == "BMHK": + B, M, H, K = dims + elif x_layout == "MBHK": + M, B, H, K = dims + elif x_layout == "BM(HK)": + B, M, H, K = dims + merged_dims = [dims[0], dims[1], dims[2] * dims[3]] # merge H and K + elif x_layout == "MB(HK)": + M, B, H, K = dims + merged_dims = [dims[0], dims[1], dims[2] * dims[3]] + elif x_layout == "BM(H3K)": + B, M, H, K = dims + merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]] # merge and scale + elif x_layout == "MB(H3K)": + M, B, H, K = dims + merged_dims = [dims[0], dims[1], 3 * dims[2] * dims[3]] + return B, M, H, K, merged_dims + + +def naive_embedding_tensor( + x, + cos, + sin, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, +): + naive_out = None + if mode == "plane": + if rotary_ndims == 2: + y1 = plane_shuffle_tensor(x[..., : rotary_size // 2]) + y2 = plane_shuffle_tensor(x[..., rotary_size // 2 : rotary_size]) + y3 = x[..., rotary_size:] + y = flow.cat((y1, y2, y3), dim=-1) + else: + y1 = plane_shuffle_tensor(x[..., :rotary_size]) + y2 = x[..., rotary_size:] + y = flow.cat((y1, y2), dim=-1) + else: + y = shuffle_adjacent_two_elem_tensor(x) + + if x_layout == "BHMK": + naive_out = x * cos + y * sin + elif x_layout == "BMHK": + naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape( + dims + ) * sin.reshape( + [B, M, 1, K] + ) # un-merge + elif x_layout == "MBHK" or x_layout == "MB(HK)": + naive_out = x.reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y.reshape(dims) * sin.permute([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) # un-merge + elif x_layout == "BM(HK)": + naive_out = x.reshape(dims) * cos.reshape([B, M, 1, K]) + y.reshape( + dims + ) * sin.reshape( + [B, M, 1, K] + ) # un-merge + elif x_layout == "BM(H3K)": + out0 = x[..., 0, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ + ..., 0, : + ].reshape(dims) * sin.reshape([B, M, 1, K]) + out1 = x[..., 1, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ + ..., 1, : + ].reshape(dims) * sin.reshape([B, M, 1, K]) + out2 = x[..., 2, :].reshape(dims) * cos.reshape([B, M, 1, K]) + y[ + ..., 2, : + ].reshape(dims) * sin.reshape([B, M, 1, K]) + + naive_out = flow.cat((out0, out1, out2), dim=-1) + elif x_layout == "MB(H3K)": + out0 = x[..., 0, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y[..., 0, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape([M, B, 1, K]) + out1 = x[..., 1, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y[..., 1, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape([M, B, 1, K]) + out2 = x[..., 2, :].reshape(dims) * cos.permute([2, 0, 1, 3]).reshape( + [M, B, 1, K] + ) + y[..., 2, :].reshape(dims) * sin.permute([2, 0, 1, 3]).reshape([M, B, 1, K]) + + naive_out = flow.cat((out0, out1, out2), dim=-1) + + return naive_out + + +# numpy version: def plane_shuffle(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return np.concatenate((-x2, x1), axis=-1) @@ -251,6 +367,42 @@ def _test_without_position( mode, ) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # check naive_out_tensor and naive_out; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # get naive_out_grad + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + fused_cos = np.array( [ [ @@ -289,7 +441,7 @@ def _test_without_position( for m in range(M) ] ).reshape(M, rotary_size // rotary_ndims) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device) fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device) @@ -334,7 +486,7 @@ def _test_without_position( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -346,12 +498,28 @@ def _test_without_position( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + # get fused_out_grad + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_x.grad + + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) @@ -434,8 +602,43 @@ def _test_without_position_sinuous( rotary_ndims, mode, ) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + # check naive_out_tensor and naive_out; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # get naive_out_grad + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) if x_layout == "BM(H3K)": out0 = flow._C.fused_apply_rotary_emb( @@ -478,7 +681,7 @@ def _test_without_position_sinuous( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -490,12 +693,27 @@ def _test_without_position_sinuous( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + # get fused_out_grad + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_x.grad + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) @@ -588,6 +806,42 @@ def _test_with_position_sinuous( mode, ) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # check naive_out_tensor and naive_out; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # get naive_out_grad; + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + fused_cos = np.array( [ [ @@ -627,7 +881,7 @@ def _test_with_position_sinuous( ] ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_cos = flow.tensor(fused_cos, dtype=dtype, device=device) fused_sin = flow.tensor(fused_sin, dtype=dtype, device=device) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) @@ -673,7 +927,7 @@ def _test_with_position_sinuous( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -685,12 +939,27 @@ def _test_with_position_sinuous( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + # get fused_out_grad; + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_x.grad + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) @@ -778,7 +1047,42 @@ def _test_with_position( mode, ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # check naive_out_tensor and naive_out; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + # get naive_out_grad + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "BM(H3K)": @@ -822,7 +1126,7 @@ def _test_with_position( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -834,18 +1138,33 @@ def _test_with_position( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() - + ) + # get fused_out_grad + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_x.grad + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) ) +# TODO: with cos & sin test def _test_plane( test_case, x_layout, mode, base, rotary_size, dims, rotary_ndims, dtype, device ): @@ -935,7 +1254,42 @@ def _test_plane( mode, ) - fused_x = flow.tensor(x, dtype=dtype, device=device) + naive_x_tensor = flow.tensor( + naive_x, dtype=dtype, device=device, requires_grad=True + ) + naive_cos_tensor = flow.tensor(naive_cos, dtype=dtype, device=device) # 不用grad + naive_sin_tensor = flow.tensor(naive_sin, dtype=dtype, device=device) + naive_out_tensor = naive_embedding_tensor( + naive_x_tensor, + naive_cos_tensor, + naive_sin_tensor, + x_layout, + B, + M, + H, + K, + dims, + merged_dims, + rotary_size, + rotary_ndims, + mode, + ) + + # check naive_out_tensor and naive_out; + test_case.assertTrue( + np.allclose( + naive_out.reshape(merged_dims), + naive_out_tensor.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + naive_out_backward = naive_out_tensor.sum() + naive_out_backward.backward() + naive_out_grad = naive_x_tensor.grad + + fused_x = flow.tensor(x, dtype=dtype, device=device, requires_grad=True) fused_position_ids = flow.tensor(position_ids, dtype=flow.int32, device=device) if x_layout == "MB(H3K)": @@ -979,7 +1333,7 @@ def _test_plane( tensor_index=2, ) - fused_out = np.concatenate((out0, out1, out2), axis=-1) + fused_out = flow.cat((out0, out1, out2), dim=-1) else: fused_out = flow._C.fused_apply_rotary_emb( fused_x, @@ -991,12 +1345,26 @@ def _test_plane( base=base, rotary_size=rotary_size, mode=mode, - ).numpy() + ) + fused_out_backward = fused_out.sum() + fused_out_backward.backward() + fused_out_grad = fused_x.grad + # test forward test_case.assertTrue( np.allclose( naive_out.reshape(merged_dims), - fused_out.reshape(merged_dims), + fused_out.numpy().reshape(merged_dims), + atol=5e-2, + rtol=5e-3, + ) + ) + + # test backward + test_case.assertTrue( + np.allclose( + naive_out_grad.numpy().reshape(merged_dims), + fused_out_grad.numpy().reshape(merged_dims), atol=5e-2, rtol=5e-3, ) @@ -1014,10 +1382,11 @@ def _test_plane( @flow.unittest.skip_unless_1n1d() class TestFusedRotaryEmbedding(flow.unittest.TestCase): # because rule no.2, kernels without cos&sin cannot work under specific x_layout + def test_fused_rotary_embedding_op_plane(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_plane] - args_dict["x_layout"] = ["MB(H3K)", "MB(HK)"] + args_dict["x_layout"] = ["MB(HK)", "MB(H3K)"] args_dict["mode"] = ["plane"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4, 8] @@ -1034,7 +1403,7 @@ def test_fused_rotary_embedding_op_plane(test_case): def test_fused_rotary_embedding_op_interval_2d(test_case): args_dict = OrderedDict() args_dict["test_fun"] = [_test_with_position, _test_with_position_sinuous] - args_dict["x_layout"] = ["BMHK"] + args_dict["x_layout"] = ["BMHK", "BM(H3K)"] args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4] @@ -1056,7 +1425,7 @@ def test_fused_rotary_embedding_op_interval_1d(test_case): _test_with_position, _test_with_position_sinuous, ] - args_dict["x_layout"] = ["BMHK"] + args_dict["x_layout"] = ["BMHK", "BM(H3K)"] args_dict["mode"] = ["interval"] args_dict["base"] = [1e1] args_dict["rotary_size"] = [4]