diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 6cd7b3d9f62..31047c03b20 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1631,6 +1631,13 @@ struct ConvInvscale e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float / scale_in_ / scale_wei_ / scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1656,6 +1663,13 @@ struct ConvScale e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float * scale_in_ * scale_wei_ * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1683,6 +1697,15 @@ struct ConvScaleRelu e = type_convert(x * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + float x; + Relu{}.template operator()(x, c_float * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp index aecf519c10c..5210265cefc 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp @@ -10,49 +10,55 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward data convolution kernel working with packed (contiguous) tensors -// Computes gradients w.r.t. input from output gradients and weights -// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], -// output[G][N][K][spatial] +// Optimized backward data convolution kernel working with packed (contiguous) tensors with +// multi-ABD support Computes gradients w.r.t. 
input from output gradients and weights Assumes +// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial] template -__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - const OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in, + const WeiDataType* const* __restrict__ p_weis, + const OutDataType* const* __restrict__ p_outs, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t x = 0; x < X; ++x) { @@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t wo = w_tmp / stride_x; if(wo >= 0 && wo < Wo) { - const OutDataType* out_gnk = out_gn; - const WeiDataType* wei_gkc = wei_g + c * wei_stride_c; + // Pointers at current filter position + const OutDataType* output_grad_g_n_k = output_grad_g_n; + const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c; for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnk[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkc[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_g_n_k, + p_outs + 1, + g * out_stride_g + n * out_stride_n, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_g_k_c, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } } } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val; } } @@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % 
N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t y = 0; y < Y; ++y) { @@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkh = out_gn + ho * out_stride_h; - const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h; + const WeiDataType* weight_at_c_y = + weight_g + c * wei_stride_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnkh[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra + // A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_outs + 1, + g * out_stride_g + n * out_stride_n + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op(in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + hi * p_d_strides[0][3] + + wi * p_d_strides[0][4]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] = in_val; } @@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t z = 0; z < Z; ++z) { @@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t do_idx = d_tmp / stride_z; if(do_idx >= 0 && do_idx < Do) { - const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d; - const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z; + // Pointers at current spatial depth + const OutDataType* output_grad_at_d = + output_grad_g_n + do_idx * out_stride_d; + const WeiDataType* weight_at_c_z = + weight_g + c * wei_stride_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { @@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h; - const WeiDataType* 
wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; + const WeiDataType* weight_at_c_z_y = + weight_at_c_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, - out_gnkdh[k * out_stride_k + wo]); - wei_op(wei_val, - wei_gkczy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation + // with extra A tensors + detail::apply_multi_tensor_elementwise_op< + NumAExtra>(out_val, + out_op, + output_grad_at_d_h, + p_outs + 1, + g * out_stride_g + + n * out_stride_n + + do_idx * out_stride_d + + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with + // extra B tensors + detail::apply_multi_tensor_elementwise_op< + NumBExtra>( + wei_val, + wei_op, + weight_at_c_z_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d + hi * in_stride_h + wi] = in_val; } } } -// GPU reference backward data convolution - takes ConvParam directly -template -void naive_conv_bwd_data(TIn* p_in, - const TWei* p_wei, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility +void naive_conv_bwd_data_multi_abd( + TIn* p_in, + const std::array& p_weis, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in, // Allocate packed buffers SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); - TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + + std::vector out_packed_bufs; + out_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= 
NumAElementwise; ++i) + { + out_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_outs_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_outs_packed[i] = static_cast(out_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in, // Pack output and weight tensors to contiguous layout (inputs to bwd data) constexpr int block_size = 256; - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TOut** d_outs_ptrs = static_cast(outs_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs, + p_outs_packed.data(), + (NumAElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in, 
if(ndim == 1) { - naive_conv_bwd_data_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_data_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in, } else // 3D { - naive_conv_bwd_data_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_data - now a zero-overhead wrapper +template +inline void naive_conv_bwd_data(TIn* p_in, + const TWei* p_wei, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_weis = {p_wei}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in, + p_weis, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp index f46b072baa3..8cee2e2b770 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp @@ -10,49 +10,58 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward weight convolution kernel working with packed (contiguous) tensors +// Optimized backward weight convolution kernel working with packed (contiguous) tensors with +// multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial], // weight_grad[G][K][C][filter] // 
Computes gradient with respect to weights template -__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in, - WeiDataType* __restrict__ p_wei_grad, - const OutDataType* __restrict__ p_out_grad, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void +naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins, + WeiDataType* __restrict__ p_wei_grad, + const OutDataType* const* __restrict__ p_out_grads, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gn[wi]); - out_op(out_val, out_gn_k[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_n_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle output gradient element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_n_k, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val; } } @@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 
0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h; + // Pointers at current spatial height + const InDataType* input_at_h = input_at_n_c + hi * in_stride_h; + const OutDataType* output_grad_at_h = + output_grad_at_n_k + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - out_op(out_val, out_gn_kh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra B + // tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } @@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op(wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + y * p_d_strides[0][3] + + x * p_d_strides[0][4]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y + x] = wei_val; } @@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t do_idx = 0; do_idx < Do; ++do_idx) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const OutDataType* out_gn_kd = out_gn_k + do_idx 
* out_stride_d; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_n_c + di * in_stride_d; + const OutDataType* output_grad_at_d = + output_grad_at_n_k + do_idx * out_stride_d; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - out_op(out_val, out_gn_kdh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra + // B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_d_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + do_idx * out_stride_d + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } @@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z + y * wei_stride_y + x] = wei_val; } } } -// GPU reference backward weight convolution - takes ConvParam directly -template -void naive_conv_bwd_weight(const TIn* p_in, - TWei* p_wei_grad, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility +void naive_conv_bwd_weight_multi_abd( + const std::array& p_ins, + TWei* p_wei_grad, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in, out_total *= l; // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem 
out_grad_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); + std::vector out_grad_packed_bufs; + out_grad_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } + TWei* p_wei_grad_packed = static_cast(wei_grad_packed_buf.GetDeviceBuffer()); - TOut* p_out_grad_packed = static_cast(out_grad_packed_buf.GetDeviceBuffer()); + + std::array p_out_grads_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_out_grads_packed[i] = static_cast(out_grad_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in, // Pack input and output_grad tensors to contiguous layout (inputs to bwd weight) constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], + p_out_grads_packed[i], + d_out_lengths, + d_out_strides, + dim_count, + out_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TOut** d_out_grads_ptrs = static_cast(out_grads_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs, + p_out_grads_packed.data(), + (NumBElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + 
HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in, if(ndim == 1) { - naive_conv_bwd_weight_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_weight_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else // 3D { - naive_conv_bwd_weight_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_weight - now a zero-overhead wrapper +template +inline void +naive_conv_bwd_weight(const TIn* p_in, + TWei* p_wei_grad, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_wei_grad, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp index 131b632a255..7bf9b499980 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp @@ -10,48 +10,56 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include 
"ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized convolution kernel working with packed (contiguous) tensors +// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], // output[G][N][K][spatial] template -__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_fwd_packed_multi_abd( + const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra) + const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra) + const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers + const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays + OutDataType* __restrict__ p_out, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gc = in_g + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gc[wi]); - wei_op(wei_val, wei_gkc[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } } - 
OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val; } } @@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const InDataType* input_at_h = input_at_c + hi * in_stride_h; + const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - wei_op(wei_val, wei_gkcy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } @@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op(out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + ho * p_d_strides[0][3] + + wo * p_d_strides[0][4]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] = out_val; } @@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* 
weight_at_c = weight_g_k + c * wei_stride_c; for(index_t z = 0; z < Z; ++z) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_c + di * in_stride_d; + const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - wei_op(wei_val, wei_gkczy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_z_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } @@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d + ho * out_stride_h + wo] = out_val; } } } -// GPU reference convolution - takes ConvParam directly -template -void naive_conv_fwd(const TIn* p_in, - const TWei* p_wei, - TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility +void naive_conv_fwd_multi_abd( + const std::array& p_ins, + const std::array& p_weis, + const std::array& p_ds, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in, for(auto l : out_lengths) out_total *= l; - // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); + // Allocate packed buffers for all A and B tensors + // 
Use separate allocations to avoid copy assignment issues with RAII wrapper + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); + // Get packed buffer pointers + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); // Compute strides and allocate device arrays for pack/unpack @@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in, // Pack input and weight tensors to contiguous layout constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + // Pack all A tensors + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + // Pack all B tensors + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + // NOTE: D tensors are NOT packed - they are used directly with their original strides + // to support broadcasting (e.g., BiasGK layout with zero strides) + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + // Allocate and copy strides to device + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 
1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + // D tensors use original pointers (not packed) to support broadcasting + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in, if(ndim == 1) { - naive_conv_fwd_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in, } else if(ndim == 2) { - naive_conv_fwd_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in, } else // 3D { - naive_conv_fwd_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_fwd - now a zero-overhead wrapper +template +inline void naive_conv_fwd(const TIn* p_in, + const TWei* p_wei, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_weis = {p_wei}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_weis, + p_ds, + p_out, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp index 0a7b58b3102..50b65357a24 100644 --- 
a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp @@ -22,9 +22,39 @@ struct SimpleDeviceMem HIP_CHECK_ERROR(hipMalloc(static_cast(&p_mem_), mem_size)); } + // Delete copy operations (resource should not be copied) + SimpleDeviceMem(const SimpleDeviceMem&) = delete; + SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete; + + // Define move operations + SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_) + { + other.p_mem_ = nullptr; + } + + SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept + { + if(this != &other) + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + p_mem_ = other.p_mem_; + other.p_mem_ = nullptr; + } + return *this; + } + void* GetDeviceBuffer() { return p_mem_; } - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + ~SimpleDeviceMem() + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + } void* p_mem_; }; @@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src, } } +namespace detail { + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result, + Op&& element_op, + const DataType* const* tensor_ptrs, + long_index_t element_offset, + std::index_sequence) +{ + element_op(result, tensor_ptrs[Is][element_offset]...); +} + +// Generic helper for A and B tensors (works in all directions) +template +__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result, + Op&& element_op, + const DataType* primary_ptr, + const DataType* const* extra_ptrs, + long_index_t extra_base_offset, + long_index_t element_offset) +{ + const DataType* tensor_ptrs[NumExtraTensors + 1]; + tensor_ptrs[0] = primary_ptr; + + static_for<1, NumExtraTensors + 1, 1>{}( + [&](auto i) { tensor_ptrs[i] = extra_ptrs[i - 1] + extra_base_offset; }); + + apply_multi_tensor_impl(result, + element_op, + tensor_ptrs, + element_offset, + std::make_index_sequence{}); +} + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out, + Op&& element_op, + float computed_value, + const float* d_values, + std::index_sequence) +{ + float temp_out; + element_op(temp_out, computed_value, d_values[Is]...); + result_out = type_convert(temp_out); +} + +// Specialized helper for D tensors with stride calculations and float conversion +template +__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out, + Op&& element_op, + float computed_value, + const DDataType* const* p_ds, + const index_t* const* p_d_strides, + index_t g, + index_t n, + index_t c_or_k, + long_index_t spatial_linear_index) +{ + if constexpr(NumDTensors == 0) + { + element_op(result_out, computed_value); + } + else + { + float d_values[NumDTensors]; + + // Compute all D tensor indices and convert to float + static_for<0, NumDTensors, 1>{}([&](auto i) { + const long_index_t d_idx = g * p_d_strides[i][0] + n * p_d_strides[i][1] + + c_or_k * p_d_strides[i][2] + spatial_linear_index; + d_values[i] = type_convert(p_ds[i][d_idx]); + }); + + apply_d_tensor_impl(result_out, + element_op, + computed_value, + d_values, + std::make_index_sequence{}); + } +} + +} // namespace detail + } // namespace ref } // namespace ck diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index a0f9b9ac256..bf5ffcb5d21 100644 --- 
a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -17,6 +17,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" namespace ck { namespace profiler { @@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification, out_device_buf.ToDevice(output.mData.data()); wei_device_buf.ToDevice(weight.mData.data()); - if(do_verification) + // profile device Conv instances + bool pass = true; + + if(do_verification == 1) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData gpu_ref_input(in_g_n_c_wis_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * + input_device_result.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization + + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data()); + } + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData gpu_ref_output(out_g_n_k_wos_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data()); + } using DeviceOp = ck::tensor_operation::device::DeviceConvFwd(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 50cd58eec37..2a282edbc83 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, bias_device_buf.ToDevice(bias.mData.data()); // run reference op - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 
0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + if constexpr(BiasGK) + { + // For GK bias layout: G*K, zero strides for N and spatial dimensions + d_strides_vec[0] = K; + d_strides_vec[1] = 0; + d_strides_vec[2] = 1; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_strides_vec[3 + i] = 0; + } + } + else + { + // Full GNKHW layout - same as output + ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin()); + } + + std::array d_ptrs = { + reinterpret_cast(bias_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp index 3f4905c1105..b439428cda7 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp @@ -22,6 +22,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl( wei_device_buf.ToDevice(weight.mData.data()); d_device_buf.ToDevice(d_tensor.mData.data()); - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< NDimSpatial, InDataType, @@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl( host_output(idx) = ck::type_convert(out_val); }); } + else if(do_verification == 2) + { + // GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + // D tensor has same layout as output + ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin()); + + std::array d_ptrs = { + reinterpret_cast(d_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + 
DDataType>( // Explicitly specify D tensor type + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + bilinear_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index acdc937a333..9444996c254 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -7,6 +7,7 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "profiler/common.hpp" @@ -150,7 +151,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, std::cout << "scale_out: " << scale_out << std::endl; // run reference op - if(do_verification) + if(do_verification == 1) { std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; @@ -200,6 +201,57 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, } }); } + else if(do_verification == 2) + { + // GPU reference + // WORKAROUND: For int8_t with Scale, use CPU post-processing to match CPU reference + // Pure GPU approach fails int8 test (see 2026-01-07-int8-scale-debugging.md) + if constexpr(std::is_same_v && + std::is_same_v) + { + // Compute conv to CShuffleDataType (float), then post-process on CPU + DeviceMem gpu_ref_c_dev(sizeof(CShuffleDataType) * c.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_c_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + PassThrough{}); + + ck::hip_check_error(hipDeviceSynchronize()); + + Tensor gpu_c(out_g_n_k_wos_desc); + gpu_ref_c_dev.FromDevice(gpu_c.mData.data()); + + // Post-process on CPU to match CPU reference behavior + host_output.ForEach([&](auto&, auto idx) { + const auto conv_shuffle = ck::type_convert(gpu_c(idx)); + const auto conv_val = ck::type_convert(conv_shuffle); + out_element_op(host_output(idx), conv_val); + }); + } + else + { + // Normal path for non-int8 or non-Scale cases + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * + device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(host_output.mData.data()); + } + } std::string best_op_name; float best_avg_time = 0; @@ -239,7 +291,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, best_gb_per_sec = gb_per_sec; } - if(do_verification) + if(do_verification == 1) { out_device_buf.FromDevice(device_output.mData.data()); @@ -259,6 +311,27 @@ bool 
profile_grouped_conv_fwd_outelementop_impl(int do_verification, << std::endl; } } + else if(do_verification == 2) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = + pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and GPU ref results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "gpu_ref_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp index 98f466a2b3d..3e4eb07a64e 100644 --- a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp @@ -46,7 +46,7 @@ class TestConvndBwdData : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/convnd_fwd/convnd_fwd_xdl.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp index a2fdcaf8709..0377b01bb2c 100644 --- a/test/convnd_fwd/convnd_fwd_xdl.cpp +++ b/test/convnd_fwd/convnd_fwd_xdl.cpp @@ -47,7 +47,7 @@ class TestConvndFwd : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/gpu_reference/CMakeLists.txt b/test/gpu_reference/CMakeLists.txt index 443818feb3a..d1c39088499 100644 --- a/test/gpu_reference/CMakeLists.txt +++ b/test/gpu_reference/CMakeLists.txt @@ -4,6 +4,9 @@ add_gtest_executable(test_gpu_reference_conv_fwd test_gpu_reference_conv_fwd.cpp) target_link_libraries(test_gpu_reference_conv_fwd PRIVATE utility) +add_gtest_executable(test_gpu_reference_conv_fwd_multi_abd test_gpu_reference_conv_fwd_multi_abd.cpp) +target_link_libraries(test_gpu_reference_conv_fwd_multi_abd PRIVATE utility) + add_gtest_executable(test_gpu_reference_conv_bwd_data test_gpu_reference_conv_bwd_data.cpp) target_link_libraries(test_gpu_reference_conv_bwd_data PRIVATE utility) diff --git a/test/gpu_reference/gpu_reference_utils.hpp b/test/gpu_reference/gpu_reference_utils.hpp index fc017c8734f..88306d51a49 100644 --- a/test/gpu_reference/gpu_reference_utils.hpp +++ b/test/gpu_reference/gpu_reference_utils.hpp @@ -381,5 +381,230 @@ bool test_conv_gpu_ref(const ck::utils::conv::ConvParam& params, ConvKernelType } } +// Forward convolution with D tensor support +template +bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& d_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& d_dev, + DeviceMem& output_dev, + OutElementOp out_element_op) +{ + using InElementOp = tensor_operation::element_wise::PassThrough; + using WeiElementOp = tensor_operation::element_wise::PassThrough; + + // Create D tensor lengths and strides for GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + d_lengths_vec[0] = params.G_; + d_lengths_vec[1] = params.N_; + d_lengths_vec[2] = params.K_; + for(index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = 
static_cast(params.output_spatial_lengths_[i]); + } + + std::vector d_strides_vec = + ref::compute_conv_tensor_strides(d_lengths_vec, params.num_dim_spatial_); + + std::array d_ptrs = { + reinterpret_cast(d_dev.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + // Call GPU reference with D tensor + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer())}; + + ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> d_tensors_ref = {d_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + InElementOp{}, + WeiElementOp{}, + out_element_op, + {}, // A tensors + {}, // B tensors + d_tensors_ref); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + +// Forward convolution with multiple A/B tensor support +template +bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& a_extra_cpu, + const Tensor& b_extra_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& a_extra_dev, + DeviceMem& b_extra_dev, + DeviceMem& output_dev, + InElementOp in_element_op, + WeiElementOp wei_element_op) +{ + using OutElementOp = tensor_operation::element_wise::PassThrough; + + // Call GPU reference with extra A and B tensors + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer()), + reinterpret_cast(a_extra_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer()), + reinterpret_cast(b_extra_dev.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + OutElementOp{}); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + 
params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> a_tensors_ref = {a_extra_cpu}; + std::array, 1> b_tensors_ref = {b_extra_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + in_element_op, + wei_element_op, + OutElementOp{}, + a_tensors_ref, + b_tensors_ref, + {}); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + } // namespace test } // namespace ck diff --git a/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp new file mode 100644 index 00000000000..ebe1e9695c3 --- /dev/null +++ b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp @@ -0,0 +1,319 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "gpu_reference_utils.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using namespace ck; +using ck::test::ConvKernelType; + +// ==================== D Tensor (Bias) Tests ==================== + +template +bool test_conv_gpu_ref_with_bias(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::AddClamp; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor bias(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem bias_dev(bias.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(bias, bias_dev); + + // Test with AddClamp (bias operation with clamping) + AddClamp out_element_op(0.0f, 6.0f); // Clamp between 0 and 6 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, bias, input_dev, weight_dev, bias_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bias) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bias) +{ + 
auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv3DFP32Bias) +{ + auto params = test::conv_test_shapes::get_3d_small(); + bool result = test_conv_gpu_ref_with_bias<3, + float, + float, + float, + tensor_layout::convolution::GNCDHW, + tensor_layout::convolution::GKCZYX, + tensor_layout::convolution::GNKDHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32GroupedG4Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g4(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== D Tensor (Bilinear) Tests ==================== + +template +bool test_conv_gpu_ref_with_bilinear(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::Bilinear; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor d_tensor(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem d_dev(d_tensor.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(d_tensor, d_dev); + + // Test with Bilinear: y = alpha * conv_result + beta * d_tensor + Bilinear out_element_op(1.5f, 0.5f); // alpha=1.5, beta=0.5 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, d_tensor, input_dev, weight_dev, d_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bilinear<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + 
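Editorial note on the per-element contract these Bilinear tests exercise: with one D tensor attached, the GPU reference accumulates the convolution in float and then hands that accumulator, together with the gathered D element, to the output element-wise op (see apply_d_tensor_elementwise_op in naive_conv_utils.hpp above). A minimal host-side sketch of that contract follows; BilinearLike and apply_out_op_with_one_d are illustrative stand-ins, not the CK operator or kernel code itself.

struct BilinearLike
{
    float alpha;
    float beta;
    // e = alpha * c + beta * d, mirroring "y = alpha * conv_result + beta * d_tensor" above
    void operator()(float& e, const float& c, const float& d) const { e = alpha * c + beta * d; }
};

// What the D-tensor path effectively computes per output element:
// acc is the float convolution accumulator, d_val the D element gathered via its strides.
inline float apply_out_op_with_one_d(float acc, float d_val)
{
    BilinearLike op{1.5f, 0.5f}; // matches Bilinear(1.5f, 0.5f) used in the tests here
    float e = 0.0f;
    op(e, acc, d_val);
    return e;
}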
+TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== Multiple A/B (ScaleAdd) Tests ==================== + +template +bool test_conv_gpu_ref_with_scaleadd(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::ScaleAdd; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor a_extra(in_g_n_c_wis_desc); // Extra A tensor (same shape as input) + Tensor b_extra(wei_g_k_c_xs_desc); // Extra B tensor (same shape as weight) + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem a_extra_dev(a_extra.mData.size() * sizeof(InDataType)); + DeviceMem b_extra_dev(b_extra.mData.size() * sizeof(WeiDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(a_extra, a_extra_dev); + test::initialize_and_copy_tensor(b_extra, b_extra_dev); + + // Test with ScaleAdd: in_out = scale * in_0 + in_1, wei_out = scale * wei_0 + wei_1 + ScaleAdd in_element_op(2.0f); // scale factor for input + ScaleAdd wei_element_op(1.5f); // scale factor for weight + + return test::test_conv_fwd_with_multi_ab_impl(params, + input, + weight, + a_extra, + b_extra, + input_dev, + weight_dev, + a_extra_dev, + b_extra_dev, + output_dev, + in_element_op, + wei_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp index b45f204b40d..ea7289d6bf6 100644 --- 
a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -63,37 +63,62 @@ class TestGroupedConvndBwdData : public ::testing::Test Tensor& out, Tensor& d) { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(in_g_n_c_wis_desc, d_lengths, d_strides); + + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; + + DeviceMem d_device_buf(sizeof(InDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); + + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out.mData.data()); - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData(); - - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - Bilinear{alpha, beta}, - WeiElementOp{}, - OutElementOp{}, - {}, - {}, - d_tensors); - - ref_invoker.Run(ref_argument); + ck::ref::naive_conv_bwd_data_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + InDataType>( + static_cast(in_device_buf.GetDeviceBuffer()), + {static_cast(wei_device_buf.GetDeviceBuffer())}, + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{alpha, beta}, + WeiElementOp{}, + OutElementOp{}); + + in_device_buf.FromDevice(in_host.mData.data()); } bool PerformConvDataBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp index 84d013bca7d..f1f985883c6 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include 
"ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -55,38 +55,24 @@ class TestGroupedConvndBwdData : public ::testing::Test void RunReference(ck::utils::conv::ConvParam& conv_param, Tensor& in_host, - Tensor& wei, - Tensor& out) + DeviceMem& wei_device_buf, + DeviceMem& out_device_buf) { - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData /*Num D Elementwise - Tensors*/ - {}; - - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{alpha}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); + // GPU reference + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization + + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + InElementOp{alpha}, + WeiElementOp{}, + OutElementOp{}); + + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(in_host.mData.data()); } bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) @@ -121,10 +107,11 @@ class TestGroupedConvndBwdData : public ::testing::Test DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); - in_device_buf.ToDevice(in_device.mData.data()); out_device_buf.ToDevice(out.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); + RunReference(conv_param, in_host, wei_device_buf, out_device_buf); + std::array out_lengths{}; std::array out_strides{}; std::array wei_lengths{}; @@ -149,8 +136,6 @@ class TestGroupedConvndBwdData : public ::testing::Test copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - RunReference(conv_param, in_host, wei, out); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD& out, Tensor& d) { - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdWeight{}; - - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in, - wei_host, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{}, - WeiElementOp{alpha, beta}, - OutElementOp{}, - {}, - {}, - d_tensors); - - ref_invoker.Run(ref_argument); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(wei_g_k_c_xs_desc, d_lengths, d_strides); + + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; + + DeviceMem d_device_buf(sizeof(WeiDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); + + std::array p_ds = { + 
static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_host.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_weight_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + WeiDataType>( + {static_cast(in_device_buf.GetDeviceBuffer())}, + static_cast(wei_device_buf.GetDeviceBuffer()), + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{}, + WeiElementOp{alpha, beta}, + OutElementOp{}); + + wei_device_buf.FromDevice(wei_host.mData.data()); } bool PerformConvWeightBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp index 1b37f5eb4ee..645aab01510 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp @@ -66,10 +66,10 @@ class TestGroupedConvndFwdBilinear : public ::testing::Test OutDataType, AComputeType, BComputeType, - IndexType>(true, // do_verification + IndexType>(2, // do_verification 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param, bilinear_op); } diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp index 199a50f0fdf..e78e61f7076 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -24,6 +24,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" using I8 = int8_t; using F16 = ck::half_t; @@ -131,39 +132,34 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification, wei_device_buf.ToDevice(weight.mData.data()); wei_bias_device_buf.ToDevice(weight_bias.mData.data()); - // Run reference op + // Run GPU reference if(do_verification) { - const std::array, NumAs - 1> elementwise_a_tensors = {input_bias}; - const std::array, NumBs - 1> elementwise_b_tensors = {weight_bias}; - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); - - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weight, - host_output, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op, - elementwise_a_tensors, - elementwise_b_tensors); - - // init host output to zero - host_output.SetZero(); - - ref_invoker.Run(ref_argument); + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer()), + reinterpret_cast(in_bias_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer()), + reinterpret_cast(wei_bias_device_buf.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + 
std::array, 0> d_strides = {}; + + ck::ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); } std::string best_op_name; diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index d1706d4cec6..68a8b016e35 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, false /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp index fef485a9506..2c04b52b4fa 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp @@ -50,7 +50,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, Clamp>( - true, // do_verification + 2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp index a78a17cbf49..78cfe126a30 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, true /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp index b4179cae627..b2a9cff231e 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp @@ -58,10 +58,10 @@ class TestGroupedConvndFwdScale : public ::testing::Test OutDataType, ck::tensor_operation::element_wise::Scale, InDataType, - InDataType>(true, // do_verification + InDataType>(2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param); } EXPECT_TRUE(pass);
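Editorial note: the profiler and test updates above converge on an integer verification-mode convention, where do_verification selects which reference path runs. Below is a minimal sketch of that dispatch, assuming 0 still means "skip verification" as before; verify_with_reference, run_cpu_reference, and run_gpu_reference are placeholder names standing in for the host ReferenceConv* path and the ck::ref::naive_conv_* path shown in the diffs, not functions added by this change.

// Hypothetical helper mirroring the do_verification branches in the profilers:
// 0 = no verification, 1 = CPU reference, 2 = GPU naive reference.
template <typename CpuRef, typename GpuRef>
bool verify_with_reference(int do_verification, CpuRef&& run_cpu_reference, GpuRef&& run_gpu_reference)
{
    if(do_verification == 1)
    {
        return run_cpu_reference();
    }
    if(do_verification == 2)
    {
        return run_gpu_reference();
    }
    return true; // verification skipped
}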