diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp
index fdbfa7c4e1..359b12c4a3 100644
--- a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp
+++ b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp
@@ -7,43 +7,52 @@
 #pragma once

 #include "ck_tile/builder/reflect/conv_description.hpp"
-#include "ck_tile/builder/reflect/conv_traits.hpp"
+#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp"

 namespace ck_tile::reflect {

-/// @brief Factory function to create ConvDescription from a convolution instance type
-/// @tparam Instance The convolution instance type (must have ConvTraits)
-/// @return A ConvDescription object populated with the instance's configuration details
-template
+/// @brief Concept to check if an Instance type has conv traits
+template
+concept HasConvTraits = requires {
+    { conv::instance_to_conv_traits() };
+};
+
+/// @brief Factory function to create ConvDescription from a convolution instance type
+/// @tparam Instance The convolution instance type
+/// @return A ConvDescription object populated with the instance's configuration details
+///
+/// TODO: Fix ConvDescription to just use the ConvTraits directly.
+template
+    requires HasConvTraits
 conv::ConvDescription describe()
 {
-    using Traits = conv::ConvTraits;
+    const auto traits = conv::instance_to_conv_traits();

     return conv::ConvDescription(
         conv::ConvSignatureInfo{
-            .spatial_dim = Traits::spatial_dim,
-            .direction = Traits::direction,
-            .input_layout = Traits::layout[0],
-            .weight_layout = Traits::layout[1],
-            .output_layout = Traits::layout[2],
-            .data_type = Traits::data_type,
-            .input_element_op = Traits::input_element_op,
-            .weight_element_op = Traits::weight_element_op,
-            .output_element_op = Traits::output_element_op,
+            .spatial_dim = traits.spatial_dim,
+            .direction = traits.direction,
+            .input_layout = traits.layout[0],
+            .weight_layout = traits.layout[1],
+            .output_layout = traits.layout[2],
+            .data_type = traits.data_type,
+            .input_element_op = traits.input_element_op,
+            .weight_element_op = traits.weight_element_op,
+            .output_element_op = traits.output_element_op,
         },
         conv::GemmAlgorithmInfo{
-            .thread_block_size = Traits::thread_block_size,
-            .tile_dims = Traits::tile_dims,
-            .warp_gemm = Traits::warp_gemm,
-            .a_tile_transfer = Traits::a_tile_transfer,
-            .b_tile_transfer = Traits::b_tile_transfer,
-            .c_tile_transfer = Traits::c_tile_transfer,
-            .pipeline_version = Traits::pipeline_version,
-            .pipeline_scheduler = Traits::pipeline_scheduler,
-            .conv_specialization = Traits::conv_specialization,
-            .padding = Traits::gemm_padding,
+            .thread_block_size = traits.thread_block_size,
+            .tile_dims = traits.tile_dims,
+            .warp_gemm = traits.warp_gemm,
+            .a_tile_transfer = traits.a_tile_transfer,
+            .b_tile_transfer = traits.b_tile_transfer,
+            .c_tile_transfer = traits.c_tile_transfer,
+            .pipeline_version = traits.pipeline_version,
+            .pipeline_scheduler = traits.pipeline_scheduler,
+            .conv_specialization = traits.conv_specialization,
+            .padding = traits.gemm_padding,
         },
-        []() { return reflect::instance_string(); });
+        []() { return reflect::instance_string(); });
 }

 } // namespace ck_tile::reflect
diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp
index 8caa11618e..451a74be34 100644
--- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp
+++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp
@@ -1,664 +1,109 @@
 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 // SPDX-License-Identifier: MIT
+// Runtime-accessible convolution kernel configuration data structure
+//
+// This file defines ConvTraits, a pure data structure that captures the complete
+// configuration of a convolution kernel in a domain-specific abstraction, without
+// requiring knowledge of the underlying kernel instance implementation details.
+//
+// ## Purpose and Design
+//
+// ConvTraits provides type erasure for convolution kernel configurations, allowing
+// for reflection of convolution kernel objects. The struct represents kernel
+// traits in terms of convolution-specific concepts for AMD GPUs rather than raw
+// template parameters.
+//
+// ## Architecture and Usage
+//
+// ConvTraits sits at the center of the reflection system:
+//
+// 1. **Population**: Values are created by `instance_to_conv_traits()` template
+//    specializations that extract configuration from compile-time InstanceTraits
+//
+// 2. **Consumption**: Used by ConvDescription to provide human-readable descriptions
+//    of kernel configurations for debugging, logging, and documentation
+//
+// ## Structure Organization
+//
+// The struct separates kernel configuration into two logical categories:
+//
+// - **Signature Information**: Defines what the kernel computes (direction, layouts,
+//   data types, elementwise operations, specializations)
+//
+// - **Algorithm Information**: Defines how the kernel computes (thread block size,
+//   tile dimensions, memory access patterns, pipeline configuration)
+//
+// ## Evolution and Extensibility
+//
+// ConvTraits is designed to evolve through composition (not inheritance):
+//
+// - Currently supports XDL forward convolution kernels
+// - Will be extended to the other forward convolution kernels
+// - Will be extended to cover backward data and backward weight convolutions
+// - Will incorporate fusion operations and additional specializations
+// - Uses std::optional and std::variant for optional/variant fields
+// - Eventually will generalize to KernelTraits for GEMM, flash attention, etc.
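To make the population/consumption flow above concrete, here is a minimal, hypothetical sketch (the field values are invented and do not come from any real kernel instance) showing that a ConvTraits value is an ordinary runtime aggregate that can be built, stored, and inspected like any other struct:

```cpp
// Hypothetical usage sketch: invented values, not a real kernel configuration.
#include <iostream>

#include "ck_tile/builder/reflect/conv_traits.hpp"

int main()
{
    namespace builder = ck_tile::builder;

    // Plain aggregate: no template parameters, no static members.
    const ck_tile::reflect::conv::ConvTraits traits{
        .spatial_dim       = 2,
        .direction         = builder::ConvDirection::FORWARD,
        .layout            = {builder::TensorLayout::NHWGC,
                              builder::TensorLayout::GKYXC,
                              builder::TensorLayout::NHWGK},
        .data_type         = builder::DataType::FP16,
        .thread_block_size = 256,
        .tile_dims         = {.m = 128, .n = 128, .k = 32},
    };

    // Runtime inspection, e.g. for logging or debugging.
    std::cout << traits.spatial_dim << "D conv, block size " << traits.thread_block_size
              << ", tile " << traits.tile_dims.m << 'x' << traits.tile_dims.n << 'x'
              << traits.tile_dims.k << '\n';
}
```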
+ #pragma once -#include -#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" -#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/utility/pipeline_enum.hpp" -#include "ck/utility/scheduler_enum.hpp" -#include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/reflect/conv_types.hpp" -#include "ck_tile/builder/reflect/instance_traits.hpp" -#include "ck_tile/builder/reflect/instance_traits_util.hpp" #include "ck_tile/builder/types.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/grouped_convolution.hpp" namespace ck_tile::reflect::conv { -// Forward convolution layout concept - checks for A/B/E layout types -template -concept HasFwdConvLayouts = requires { - typename T::ALayout; - typename T::BLayout; - typename T::ELayout; -}; - -// GEMM specialization concept - checks for kGemmSpecialization member -template -concept HasGemmSpec = requires { - { - T::kGemmSpecialization - } -> std::convertible_to; -}; - -// Data types concept - checks for ADataType member -template -concept HasDataTypes = requires { typename T::ADataType; }; - -// Elementwise operations concept - checks for A/B/CDE elementwise operation types -template -concept HasElementwiseOps = requires { - typename T::AElementwiseOperation; - typename T::BElementwiseOperation; - typename T::CDEElementwiseOperation; -}; - -// Tile parameters concept - checks for tile dimension and transfer members -template -concept HasTileParams = requires { - { T::kKPerBlock } -> std::convertible_to; - { T::kMPerBlock } -> std::convertible_to; - { T::kNPerBlock } -> std::convertible_to; - { T::kAK1 } -> std::convertible_to; - { T::kBK1 } -> std::convertible_to; - T::kCThreadClusterLengths; -}; - -// Comprehensive concept that checks if an instance has all XDL forward convolution traits -// This concept is used to constrain ConvTraits specialization that expect XDL forward convolutions -template -concept IsXdlFwdConv = HasFwdConvLayouts && HasGemmSpec && HasDataTypes && - HasElementwiseOps && HasTileParams; - -// Primary concept for checking if a type can be described -// Currently only forward convolutions are supported, but this can be extended -// in the future to include backward data and backward weight convolutions -template -concept HasConvTraits = IsXdlFwdConv>; - -// Helper metafunctions to convert from ck enums to builder enums - -/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. -/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. -/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). -/// @details This function maps CK's block GEMM pipeline version identifiers to the -/// builder framework's standardized pipeline version enum. The pipeline version -/// determines the strategy used for data movement and computation overlap in the -/// GEMM kernel's main loop. 
-template -constexpr auto convert_pipeline_version() -{ - using enum ck::BlockGemmPipelineVersion; - using enum builder::PipelineVersion; - - switch(ck_ver) - { - case v1: return V1; - case v2: return V2; - case v3: return V3; - case v4: return V4; - case v5: return V5; - } -} - -/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. -/// @tparam ck_ver The CK PipelineVersion enum value to convert. -/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). -/// @details This function maps CK's general pipeline version identifiers to the -/// builder framework's standardized pipeline version enum. Note that this overload -/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion -/// variant, including support for specialized weight-only pipelines. -template -constexpr auto convert_pipeline_version() +// Runtime data structure representing a convolution kernel's complete configuration +// +// This pure data struct (no template parameters, no static members) provides +// type erasure for convolution kernel configurations. It can hold the configuration +// from any convolution kernel instance, enabling runtime storage, comparison, and +// manipulation of kernel properties. +// +// The struct is populated by `instance_to_conv_traits()` template specializations +// that extract compile-time configuration from InstanceTraits and convert it to +// this standardized runtime representation. +// +// Members are organized into two categories: +// - **Signature Information**: Defines the computational interface (what to compute) +// - **Algorithm Information**: Defines the implementation strategy (how to compute) +// +// Note: This struct will evolve to support additional convolution variants and +// eventually generalize to other kernel types through composition. +// +// There is a lot we still need to do: +// +// TODO: Generalize type support for all tensors and accumulator. +// TODO: Describe all tensros. +// TODO: Include the full generalization of the signature from the input schema. +// TODO: Include the full generalization of the algorithm from the input schema. +struct ConvTraits { - using enum ck::PipelineVersion; - using enum builder::PipelineVersion; - - switch(ck_ver) - { - case v1: return V1; - case v2: return V2; - case v4: return V4; - case weight_only: return WEIGHT_ONLY; - } -} - -/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. -/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. -/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). -/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the -/// builder framework's standardized scheduler enum. The scheduler determines how work -/// is distributed and synchronized within and across wavefronts during pipeline execution. -/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates -/// across multiple wavefronts. -template -constexpr auto convert_pipeline_scheduler() -{ - using enum ck::BlockGemmPipelineScheduler; - using enum builder::PipelineScheduler; - - switch(ck_sched) - { - case Intrawave: return INTRAWAVE; - case Interwave: return INTERWAVE; - } -} - -/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. -/// @tparam ck_sched The CK LoopScheduler enum value to convert. 
-/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). -/// @details This function maps CK's loop scheduler identifiers to the builder framework's -/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of -/// the main computational loop are scheduled across threads. DEFAULT uses the standard -/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved -/// performance in certain scenarios. -template -constexpr auto convert_pipeline_scheduler() -{ - using enum ck::LoopScheduler; - using enum builder::PipelineScheduler; - - switch(ck_sched) - { - case Default: return DEFAULT; - case Interwave: return INTERWAVE; - } -} - -// Helper metafunctions to derive signature information from Instance types - -/// @brief Helper function to report unsupported convolution direction with a clear error message. -template -[[noreturn]] consteval void report_unsupported_conv_direction_error() -{ - throw "Unsupported convolution direction detected!\n" - "The kernel instance does not have a recognized convolution specialization.\n" - "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " - "kConvBwdWeightSpecialization.\n" - "Please verify that your kernel instance is properly configured."; -} - -/// @brief Derives the convolution direction from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). -template -constexpr builder::ConvDirection conv_direction() -{ - using InstTraits = InstanceTraits; - - if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) - return builder::ConvDirection::FORWARD; - else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) - return builder::ConvDirection::BACKWARD_DATA; - else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) - return builder::ConvDirection::BACKWARD_WEIGHT; - else - { - report_unsupported_conv_direction_error(); - return builder::ConvDirection::FORWARD; // Unreachable - } -} - -/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return A `builder::ConvSpecialization` enum value. 
-template -constexpr auto conv_spec() -{ - using InstTraits = InstanceTraits; - using enum builder::ConvSpecialization; - - if constexpr(requires { InstTraits::kConvForwardSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; - switch(InstTraits::kConvForwardSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Pad0: return FILTER_1X1_PAD0; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - case Filter3x3: return FILTER_3x3; - case OddC: return ODD_C; - } - } - else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; - switch(InstTraits::kConvBwdDataSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - } - } - else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) - { - using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; - switch(InstTraits::kConvBwdWeightSpecialization) - { - case Default: return DEFAULT; - case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; - case Filter1x1Pad0: return FILTER_1X1_PAD0; - case OddC: return ODD_C; - } - } -} - -// Helper variable template to check if CK layout enums match -template -inline constexpr bool layouts_are = - std::is_same_v && std::is_same_v && std::is_same_v; - -/// @brief Helper function to report unsupported layout combinations with a clear error message. -/// @details This consteval function is designed to fail at compile time with a descriptive -/// error message when an unsupported layout combination is encountered. -template -[[noreturn]] consteval void report_unsupported_layout_error() -{ - // This will produce a compile-time error with the exception message - throw "Unsupported convolution layout combination detected!\n" - "The combination of ALayout, BLayout, and ELayout template parameters\n" - "is not recognized for the given spatial dimension.\n" - "Please verify that your convolution instance uses a supported layout configuration.\n" - "Check the conv_layout() function for the list of supported layout combinations."; -} - -/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array corresponding to the tensor layouts: -/// index 0 -> Input layout -/// index 1 -> Weight layout -/// index 2 -> Output layout -template -constexpr auto conv_layout() - requires HasFwdConvLayouts> -{ - // Helper lambda to construct layout array - auto layouts = [](auto... 
Ls) { return std::array{Ls...}; }; - - using A = typename InstanceTraits::ALayout; - using B = typename InstanceTraits::BLayout; - using E = typename InstanceTraits::ELayout; - namespace ctl = ck::tensor_layout::convolution; - using enum builder::TensorLayout; - - switch(InstanceTraits::kSpatialDim) - { - case 1: - if constexpr(layouts_are) - return layouts(GNWC, GKXC, GNWK); - if constexpr(layouts_are) - return layouts(GNWC, GKXC, GNWK); - if constexpr(layouts_are) - return layouts(NWGC, GKXC, NWGK); - if constexpr(layouts_are) - return layouts(NGCW, GKXC, NGKW); - if constexpr(layouts_are) - return layouts(NGCW, GKCX, NGKW); - break; - case 2: - if constexpr(layouts_are) - return layouts(GNHWC, GKYXC, GNHWK); - if constexpr(layouts_are) - return layouts(GNHWC, GKYXC, GNHWK); - if constexpr(layouts_are) - return layouts(NHWGC, GKYXC, NHWGK); - if constexpr(layouts_are) - return layouts(NHWGC, GKYXC, NHWGK); - if constexpr(layouts_are) - return layouts(NGCHW, GKYXC, NGKHW); - if constexpr(layouts_are) - return layouts(NGCHW, GKCYX, NGKHW); - break; - case 3: - if constexpr(layouts_are) - return layouts(GNDHWC, GKZYXC, GNDHWK); - if constexpr(layouts_are) - return layouts(GNDHWC, GKZYXC, GNDHWK); - if constexpr(layouts_are) - return layouts(NDHWGC, GKZYXC, NDHWGK); - if constexpr(layouts_are) - return layouts(NGCDHW, GKZYXC, NGKDHW); - if constexpr(layouts_are) - return layouts(NGCDHW, GKCZYX, NGKDHW); - break; - } - - // If we reach here, the layout combination is not supported - // Call consteval function to trigger a compile-time error with a clear message - report_unsupported_layout_error::kSpatialDim>(); - - // This return is unreachable but needed to satisfy the compiler - return layouts(GNHWC, GKYXC, GNHWK); -} - -/// @brief Helper function to report unsupported data type with a clear error message. -template -[[noreturn]] consteval void report_unsupported_data_type_error() -{ - throw "Unsupported data type detected!\n" - "The ADataType is not recognized.\n" - "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " - "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " - "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " - "(BF8), " - "int8_t (I8), ck::Tuple (I8_I8), uint8_t (U8).\n" - "Please verify that your kernel instance uses a supported data type."; -} - -/// @brief Derives the data type from a device kernel `Instance` type. -/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). 
-template -constexpr builder::DataType conv_data_type() - requires HasDataTypes> -{ - using InstTraits = InstanceTraits; - using ADataType = typename InstTraits::ADataType; - using enum builder::DataType; - - if constexpr(std::is_same_v) - return FP16; - else if constexpr(std::is_same_v>) - return FP16_FP16; - else if constexpr(std::is_same_v) - return BF16; - else if constexpr(std::is_same_v>) - return BF16_BF16; - else if constexpr(std::is_same_v) - return FP32; - else if constexpr(std::is_same_v>) - return FP32_FP32; - else if constexpr(std::is_same_v) - return FP64; - else if constexpr(std::is_same_v) - return FP8; - else if constexpr(std::is_same_v) - return BF8; - else if constexpr(std::is_same_v) - return BF8; - else if constexpr(std::is_same_v) - return I8; - else if constexpr(std::is_same_v>) - return I8_I8; - else if constexpr(std::is_same_v) - return U8; - else - { - report_unsupported_data_type_error(); - return FP32; // Unreachable - } -} - -/// @brief Helper function to report unsupported elementwise operation with a clear error message. -template -[[noreturn]] consteval void report_unsupported_elementwise_op_error() -{ - throw "Unsupported elementwise operation detected!\n" - "The elementwise operation type is not recognized.\n" - "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, " - "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, " - "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, " - "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, " - "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, " - "UnaryConvert.\n" - "Please verify that your kernel instance uses a supported elementwise operation."; -} - -/// @brief Derives the elementwise operation from op type. -/// @tparam ElementwiseOp Elementwise operation functor type. -/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. 
-template -constexpr builder::ElementwiseOperation elementwise_op() -{ - using enum builder::ElementwiseOperation; - constexpr std::string_view name = detail::elementwise_op_name(); - - if constexpr(detail::case_insensitive_equal(name, "AddClamp")) - return ADD_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd")) - return ADD_RELU_ADD; - else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) - return BIAS_BNORM_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) - return BILINEAR; - else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp")) - return BIAS_BNORM_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Clamp")) - return CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale")) - return CONV_INVSCALE; - else if constexpr(detail::case_insensitive_equal(name, "ConvScale")) - return CONV_SCALE; - else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd")) - return CONV_SCALE_ADD; - else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu")) - return CONV_SCALE_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Scale")) - return SCALE; - else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd")) - return SCALE_ADD; - else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) - return PASS_THROUGH; - else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu")) - return SCALEADD_SCALEADD_RELU; - else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp")) - return DYNAMIC_UNARY_OP; - else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp")) - return UNARY_COMBINED_OP; - else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp")) - return ACTIVATION_MUL2_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp")) - return ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp")) - return ADD_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp")) - return ADD_ACTIVATION_MUL2_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp")) - return ADD_MUL_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp")) - return ADD_MUL2_ACTIVATION_MUL_CLAMP; - else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert")) - return UNARY_CONVERT; - else if constexpr(detail::case_insensitive_equal(name, "Logistic")) - return LOGISTIC; - else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu")) - return CLIPPED_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Swish")) - return SWISH; - else if constexpr(detail::case_insensitive_equal(name, "Elu")) - return ELU; - else if constexpr(detail::case_insensitive_equal(name, "Power")) - return POWER; - else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu")) - return LEAKY_RELU; - else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs")) - return UNARY_ABS; - else if constexpr(detail::case_insensitive_equal(name, "Relu")) - return RELU; - else if constexpr(detail::case_insensitive_equal(name, "SoftRelu")) - return SOFT_RELU; - else if constexpr(detail::case_insensitive_equal(name, "Sigmoid")) - return SIGMOID; - else if constexpr(detail::case_insensitive_equal(name, "TanH")) - return TANH; - else if 
constexpr(detail::case_insensitive_equal(name, "Gelu")) - return GELU; - else if constexpr(detail::case_insensitive_equal(name, "Silu")) - return SILU; - else - { - report_unsupported_elementwise_op_error(); - return PASS_THROUGH; // Unreachable - } -} - -/// @brief Derives a gemm padding from a kernel instance type. -/// @tparam Instance - A Device Kernel object type. -/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. -template -constexpr builder::GemmPadding gemm_spec() - requires HasGemmSpec> -{ - using InstTraits = InstanceTraits; - using enum builder::GemmPadding; - using enum ck::tensor_operation::device::GemmSpecialization; - - constexpr auto gemm_spec = InstTraits::kGemmSpecialization; - - switch(gemm_spec) - { - case Default: return DEFAULT; - case MPadding: return M_PADDING; - case NPadding: return N_PADDING; - case KPadding: return K_PADDING; - case MNPadding: return MN_PADDING; - case MKPadding: return MK_PADDING; - case NKPadding: return NK_PADDING; - case MNKPadding: return MNK_PADDING; - case OPadding: return O_PADDING; - case MOPadding: return MO_PADDING; - case NOPadding: return NO_PADDING; - case KOPadding: return KO_PADDING; - case MNOPadding: return MNO_PADDING; - case MKOPadding: return MKO_PADDING; - case NKOPadding: return NKO_PADDING; - case MNKOPadding: return MNKO_PADDING; - } -} - -/// @brief Primary template for extracting convolution traits. -/// @details This struct is the main entry point for reflecting on a convolution -/// kernel's properties. It is specialized to handle different kinds of input types. -template -struct ConvTraits; - -/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. -/// @details This is the primary specialization used to extract a comprehensive -/// set of traits directly from a fully-formed device kernel `Instance` type. -/// It uses `InstanceTraits` to access the kernel's template parameters. -template - requires IsXdlFwdConv> -struct ConvTraits -{ - using InstTraits = InstanceTraits; - // --- Signature Information --- - /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). - static constexpr int spatial_dim = InstTraits::kSpatialDim; - /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). - static constexpr builder::ConvDirection direction = conv_direction(); - /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). - static constexpr auto layout = conv_layout(); - /// @brief The primary data type used in the computation (e.g., FP16, FP32). - static constexpr builder::DataType data_type = conv_data_type(); + int spatial_dim; + builder::ConvDirection direction; + std::array layout; // [input, weight, output] + builder::DataType data_type; - static constexpr builder::ElementwiseOperation input_element_op = - elementwise_op(); - static constexpr builder::ElementwiseOperation weight_element_op = - elementwise_op(); - static constexpr builder::ElementwiseOperation output_element_op = - elementwise_op(); + builder::ElementwiseOperation input_element_op; + builder::ElementwiseOperation weight_element_op; + builder::ElementwiseOperation output_element_op; - /// @brief The GEMM specialization used by the kernel - padding - static constexpr auto gemm_padding = gemm_spec(); - /// @brief The convolution-specific specialization (e.g., Default, 1x1). 
- static constexpr auto conv_specialization = conv_spec(); + builder::GemmPadding gemm_padding; + builder::ConvSpecialization conv_specialization; // --- Algorithm Information --- - /// @brief The total number of threads in a thread block (workgroup). - static constexpr int thread_block_size = InstTraits::kBlockSize; - /// @brief The dimensions of the data tile processed by the thread block. - static constexpr DataTileInfo tile_dims = { - .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; - - /// @brief Configuration for the A-matrix (input) tile transfer. - static constexpr InputTileTransferInfo a_tile_transfer = { - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; - - /// @brief Configuration for the B-matrix (weights) tile transfer. - static constexpr InputTileTransferInfo b_tile_transfer = { - .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; - - /// @brief Parameters for the warp-level GEMM computation. - static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .m_iter = InstTraits::kMXdlPerWave, - .n_iter = InstTraits::kNXdlPerWave}; - - /// @brief Configuration for the C-matrix (output) tile transfer. - static constexpr OutputTileTransferInfo c_tile_transfer = { - .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, - .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + int thread_block_size; + DataTileInfo tile_dims; - /// @brief Helper to safely get the pipeline version. - /// @details This is only available for some convolutions (e.g., forward). - /// If not present in `InstanceTraits`, it returns a default value. 
- template - static constexpr auto get_pipeline_version() - { - if constexpr(requires { T::kPipelineVersion; }) - { - return convert_pipeline_version(); - } - else - { - // Return a default or indicate not available - return builder::PipelineVersion::V1; - } - } + InputTileTransferInfo a_tile_transfer; + InputTileTransferInfo b_tile_transfer; - /// @brief The block GEMM pipeline version used by the kernel. - static constexpr auto pipeline_version = get_pipeline_version(); + WarpGemmParams warp_gemm; - /// @brief Helper to safely get the pipeline scheduler. - /// @details This is only available for some convolutions. If not present - /// in `InstanceTraits`, it returns a default value. - template - static constexpr auto get_pipeline_scheduler() - { - if constexpr(requires { T::kPipelineScheduler; }) - { - return convert_pipeline_scheduler(); - } - else if constexpr(requires { T::kLoopScheduler; }) - { - return convert_pipeline_scheduler(); - } - else - { - // Return a default or indicate not available - return builder::PipelineScheduler::DEFAULT; - } - } + OutputTileTransferInfo c_tile_transfer; - /// @brief The pipeline scheduler used by the kernel. - static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); + builder::PipelineVersion pipeline_version; + builder::PipelineScheduler pipeline_scheduler; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000..cdd238f36a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 
100644 index 0000000000..28c43c342f --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv 
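Each of these per-kernel-family headers follows the same tag-dispatch pattern: one constrained `instance_to_conv_traits()` overload, selected through the instance's `device_kernel_tag`. As a rough usage sketch (the `Instance` parameter is a placeholder for any concrete CK device kernel type covered by one of the overloads; nothing below names a real instantiation), a consumer can reduce such a kernel type to a `ConvTraits` value and query it generically:

```cpp
// Usage sketch: generic consumer of the tag-dispatched overloads above.
// `Instance` stands for any concrete device kernel type with InstanceTraits support.
#include <iostream>

#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"

template <typename Instance>
void log_kernel_shape()
{
    // The requires-clause on each overload picks the matching kernel family.
    constexpr auto traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    std::cout << "spatial dim: " << traits.spatial_dim
              << ", block size: " << traits.thread_block_size
              << ", tile m/n/k: " << traits.tile_dims.m << '/' << traits.tile_dims.n
              << '/' << traits.tile_dims.k << '\n';
}
```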
diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100644 index 0000000000..c4bed850eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = {.m = InstTraits::kMPerBlock, + .n = InstTraits::kNPerBlock, + .k = InstTraits::kKPerBlock}, + .a_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, + .b_tile_transfer = + {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = + InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, + .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}, + .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = + InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = + InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + 
InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp new file mode 100644 index 0000000000..46c196e95a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -0,0 +1,739 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/pipeline_enum.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/reflect/conv_types.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_util.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +/// @file conv_traits_helpers.hpp +/// @brief Helper utilities for extracting convolution traits from kernel instances +/// +/// This file provides compile-time reflection utilities to extract configuration +/// information from CK convolution kernel instances and convert them to the builder +/// framework's standardized representation. +/// +/// ## Organization +/// +/// The file is organized into the following sections: +/// +/// 1. **Enum Conversions**: Functions to convert CK enums to builder enums +/// - Pipeline version conversions (BlockGemmPipelineVersion, PipelineVersion) +/// - Pipeline scheduler conversions (BlockGemmPipelineScheduler, LoopScheduler) +/// +/// 2. **Signature Derivation**: Functions to extract signature information from instances +/// - Convolution direction (conv_direction) +/// - Convolution specialization (conv_spec) +/// - Tensor layouts (conv_layout) +/// - Data types (conv_data_type) +/// - Elementwise operations (elementwise_op) +/// - GEMM padding (gemm_spec) +/// +/// 3. **Pipeline Configuration Helpers**: Safe extraction of pipeline parameters +/// - Pipeline version extraction (get_pipeline_version) +/// - Pipeline scheduler extraction (get_pipeline_scheduler) +/// +/// ## Error Handling Strategy +/// +/// This file uses a specific error handling pattern for compile-time errors: +/// - **consteval functions with throw**: Used for error reporting to ensure SFINAE doesn't +/// silently ignore errors. The thrown string becomes part of the compiler error message, +/// providing clear context to developers. +/// - **DO NOT replace with static_assert**: static_assert is silently ignored during SFINAE, +/// which would hide errors instead of reporting them clearly. 
+/// +/// @example +/// ```cpp +/// using Instance = +/// ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<...>; +/// +/// // Extract convolution direction +/// constexpr auto dir = conv_direction(); +/// +/// // Extract data type +/// constexpr auto dtype = conv_data_type(); +/// +/// // Extract layout configuration +/// constexpr auto layouts = conv_layout(); +/// ``` + +namespace ck_tile::reflect::conv { + +// ============================================================================ +// SECTION 1: ENUM CONVERSIONS +// ============================================================================ + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value. +/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +/// +/// Supported mappings: +/// - v1 -> V1 +/// - v2 -> V2 +/// - v3 -> V3 +/// - v4 -> V4 +/// - v5 -> V5 +template +constexpr builder::PipelineVersion convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v3: return V3; + case v4: return V4; + case v5: return V5; + } +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value. +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. +/// +/// Supported mappings: +/// - v1 -> V1 +/// - v2 -> V2 +/// - v4 -> V4 +/// - weight_only -> WEIGHT_ONLY +template +constexpr builder::PipelineVersion convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + + switch(ck_ver) + { + case v1: return V1; + case v2: return V2; + case v4: return V4; + case weight_only: return WEIGHT_ONLY; + } +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value. +/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// +/// Supported mappings: +/// - Intrawave -> INTRAWAVE: Scheduling within a single wavefront +/// - Interwave -> INTERWAVE: Coordination across multiple wavefronts +template +constexpr builder::PipelineScheduler convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + + switch(ck_sched) + { + case Intrawave: return INTRAWAVE; + case Interwave: return INTERWAVE; + } +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. 
+/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value. +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. +/// +/// Supported mappings: +/// - Default -> DEFAULT: Standard scheduling strategy +/// - Interwave -> INTERWAVE: Cross-wavefront coordination for improved performance +template +constexpr builder::PipelineScheduler convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + + switch(ck_sched) + { + case Default: return DEFAULT; + case Interwave: return INTERWAVE; + } +} + +// ============================================================================ +// SECTION 2: SIGNATURE DERIVATION FUNCTIONS +// ============================================================================ + +// ---------------------------------------------------------------------------- +// Convolution Direction +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported convolution direction with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_conv_direction_error() +{ + throw "Unsupported convolution direction detected!\n" + "The kernel instance does not have a recognized convolution specialization.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + +/// @brief Derives the convolution direction from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return A builder::ConvDirection enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). +/// @details This function inspects the Instance's InstanceTraits to determine which +/// convolution specialization field is present, and returns the corresponding direction. +/// +/// The function checks for the presence of: +/// - kConvForwardSpecialization -> FORWARD +/// - kConvBwdDataSpecialization -> BACKWARD_DATA +/// - kConvBwdWeightSpecialization -> BACKWARD_WEIGHT +/// +/// @note Compilation will fail with a clear error message if the instance does not +/// have a recognized convolution specialization field. 
+template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + return builder::ConvDirection::FORWARD; + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + return builder::ConvDirection::BACKWARD_DATA; + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + return builder::ConvDirection::BACKWARD_WEIGHT; + else + { + report_unsupported_conv_direction_error(); + return builder::ConvDirection::FORWARD; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Convolution Specialization +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported convolution specialization with a clear error +/// message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_conv_spec_error() +{ + throw "Unsupported convolution specialization detected!\n" + "The kernel instance does not have a recognized convolution specialization field.\n" + "Expected one of: kConvForwardSpecialization, kConvBwdDataSpecialization, or " + "kConvBwdWeightSpecialization.\n" + "Please verify that your kernel instance is properly configured."; +} + +/// @brief Derives the convolution-specific specialization from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return A builder::ConvSpecialization enum value. +/// @details This function extracts the specialization enum from the Instance's InstanceTraits +/// and converts it to the corresponding builder framework enum. 
+/// +/// For forward convolutions, supported specializations include: +/// - Default, Filter1x1Pad0, Filter1x1Stride1Pad0, Filter3x3, OddC +/// +/// For backward data convolutions: +/// - Default, Filter1x1Stride1Pad0 +/// +/// For backward weight convolutions: +/// - Default, Filter1x1Stride1Pad0, Filter1x1Pad0, OddC +template +constexpr builder::ConvSpecialization conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvForwardSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter3x3: return FILTER_3x3; + case OddC: return ODD_C; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvBwdDataSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + using enum builder::ConvSpecialization; + + switch(InstTraits::kConvBwdWeightSpecialization) + { + case Default: return DEFAULT; + case Filter1x1Stride1Pad0: return FILTER_1X1_STRIDE1_PAD0; + case Filter1x1Pad0: return FILTER_1X1_PAD0; + case OddC: return ODD_C; + } + } + else + { + report_unsupported_conv_spec_error(); + return builder::ConvSpecialization::DEFAULT; // Unreachable + } +} + +// ---------------------------------------------------------------------------- +// Tensor Layouts +// ---------------------------------------------------------------------------- + +/// @brief Helper function to report unsupported layout combinations with a clear error message. +/// @details This consteval function uses throw (not static_assert) to ensure the error is not +/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +template +[[noreturn]] consteval void report_unsupported_layout_error() +{ + throw "Unsupported convolution layout combination detected!\n" + "The combination of ALayout, BLayout, and ELayout template parameters\n" + "is not recognized for the given spatial dimension.\n" + "Please verify that your convolution instance uses a supported layout configuration.\n" + "Check the conv_layout() function for the list of supported layout combinations."; +} + +/// @brief Derives the grouped convolution layout from a device kernel Instance type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array containing the layouts for: +/// - [0] Input tensor layout +/// - [1] Weight tensor layout +/// - [2] Output tensor layout +/// @details This function examines the Instance's ALayout, BLayout, and ELayout types +/// along with the spatial dimension to determine the appropriate layout configuration. +/// +/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions). +/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants. +/// +/// @note Compilation will fail with a clear error message if the layout combination +/// is not supported for the given spatial dimension. 
+
+// ----------------------------------------------------------------------------
+// Tensor Layouts
+// ----------------------------------------------------------------------------
+
+/// @brief Helper function to report unsupported layout combinations with a clear error message.
+/// @details This consteval function uses throw (not static_assert) to ensure the error is not
+/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
+template <typename Instance>
+[[noreturn]] consteval void report_unsupported_layout_error()
+{
+    throw "Unsupported convolution layout combination detected!\n"
+          "The combination of ALayout, BLayout, and ELayout template parameters\n"
+          "is not recognized for the given spatial dimension.\n"
+          "Please verify that your convolution instance uses a supported layout configuration.\n"
+          "Check the conv_layout() function for the list of supported layout combinations.";
+}
+
+/// @brief Derives the grouped convolution layout from a device kernel Instance type.
+/// @tparam Instance The device kernel instance type.
+/// @return A std::array containing the layouts for:
+///         - [0] Input tensor layout
+///         - [1] Weight tensor layout
+///         - [2] Output tensor layout
+/// @details This function examines the Instance's ALayout, BLayout, and ELayout types
+/// along with the spatial dimension to determine the appropriate layout configuration.
+///
+/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions).
+/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants.
+///
+/// @note Compilation will fail with a clear error message if the layout combination
+/// is not supported for the given spatial dimension.
+///
+/// TODO: If we don't check for supported layouts, this function can be simplified.
+template <typename Instance>
+constexpr std::array<builder::TensorLayout, 3> conv_layout()
+{
+    using InstTraits = InstanceTraits<Instance>;
+    using A          = typename InstTraits::ALayout;
+    using B          = typename InstTraits::BLayout;
+    using E          = typename InstTraits::ELayout;
+    namespace ctl    = ck::tensor_layout::convolution;
+    using enum builder::TensorLayout;
+
+    // Helper to check if layouts match expected types
+    constexpr auto layouts_match = []() {
+        return std::is_same_v && std::is_same_v && std::is_same_v;
+    };
+
+    // Helper to construct layout array
+    constexpr auto make_layouts = [](auto in, auto weight, auto out) {
+        return std::array{in, weight, out};
+    };
+
+    constexpr int spatial_dim = InstTraits::kSpatialDim;
+
+    if constexpr(spatial_dim == 1)
+    {
+        if constexpr(layouts_match.template operator()())
+            return make_layouts(GNWC, GKXC, GNWK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(GNWC, GKXC, GNWK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NWGC, GKXC, NWGK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NGCW, GKXC, NGKW);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NGCW, GKCX, NGKW);
+        else
+        {
+            report_unsupported_layout_error<Instance>();
+            return make_layouts(GNWC, GKXC, GNWK); // Unreachable
+        }
+    }
+    else if constexpr(spatial_dim == 2)
+    {
+        if constexpr(layouts_match.template operator()())
+            return make_layouts(GNHWC, GKYXC, GNHWK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(GNHWC, GKYXC, GNHWK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NHWGC, GKYXC, NHWGK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NHWGC, GKYXC, NHWGK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NGCHW, GKYXC, NGKHW);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NGCHW, GKCYX, NGKHW);
+        else
+        {
+            report_unsupported_layout_error<Instance>();
+            return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
+        }
+    }
+    else if constexpr(spatial_dim == 3)
+    {
+        if constexpr(layouts_match.template operator()())
+            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(GNDHWC, GKZYXC, GNDHWK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NDHWGC, GKZYXC, NDHWGK);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NGCDHW, GKZYXC, NGKDHW);
+        else if constexpr(layouts_match.template operator()())
+            return make_layouts(NGCDHW, GKCZYX, NGKDHW);
+        else
+        {
+            report_unsupported_layout_error<Instance>();
+            return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable
+        }
+    }
+    else
+    {
+        report_unsupported_layout_error<Instance>();
+        return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable
+    }
+}
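conv_layout()'s layouts_match helper is a C++20 templated lambda invoked with explicit template arguments. A self-contained sketch of that calling pattern, using illustrative layout tag types rather than the real ck::tensor_layout::convolution types:

```cpp
#include <array>
#include <type_traits>

// Illustrative layout tags; stand-ins for ck::tensor_layout::convolution types.
struct GNHWC {};
struct GKYXC {};
struct GNHWK {};

enum class Layout { GNHWC, GKYXC, GNHWK };

template <typename A, typename B, typename E>
constexpr std::array<Layout, 3> pick_layouts()
{
    // Same explicit-template-argument lambda call shape as layouts_match above.
    constexpr auto match = []<typename In, typename Wei, typename Out>() {
        return std::is_same_v<A, In> && std::is_same_v<B, Wei> && std::is_same_v<E, Out>;
    };

    if constexpr(match.template operator()<GNHWC, GKYXC, GNHWK>())
        return {Layout::GNHWC, Layout::GKYXC, Layout::GNHWK};
    else
        return {Layout::GNHWK, Layout::GNHWK, Layout::GNHWK}; // fallback for the sketch only
}

static_assert(pick_layouts<GNHWC, GKYXC, GNHWK>()[0] == Layout::GNHWC);
static_assert(pick_layouts<GNHWC, GKYXC, GNHWK>()[2] == Layout::GNHWK);
```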
+
+// ----------------------------------------------------------------------------
+// Data Types
+// ----------------------------------------------------------------------------
+
+/// @brief Helper function to report an unsupported data type with a clear error message.
+/// @details This consteval function uses throw (not static_assert) to ensure the error is not
+/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
+template <typename ADataType>
+[[noreturn]] consteval void report_unsupported_data_type_error()
+{
+    throw "Unsupported data type detected!\n"
+          "The ADataType is not recognized.\n"
+          "Supported types are: ck::half_t (FP16), ck::Tuple<ck::half_t, ck::half_t> (FP16_FP16), "
+          "ck::bhalf_t (BF16), ck::Tuple<ck::bhalf_t, ck::bhalf_t> (BF16_BF16), float (FP32), "
+          "ck::Tuple<float, float> (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t "
+          "(BF8), "
+          "int8_t (I8), ck::Tuple<int8_t, int8_t> (I8_I8), uint8_t (U8).\n"
+          "Please verify that your kernel instance uses a supported data type.";
+}
+
+/// @brief Derives the data type from a device kernel Instance type.
+/// @tparam Instance The device kernel instance type.
+/// @return A builder::DataType enum value representing the input data type.
+/// @details This function examines the Instance's ADataType to determine the data type
+/// used for the input tensor. The function supports various floating-point and integer
+/// types, including tuple types for mixed-precision operations.
+///
+/// Supported data types include:
+/// - FP16 (ck::half_t)
+/// - FP16_FP16 (ck::Tuple<ck::half_t, ck::half_t>)
+/// - BF16 (ck::bhalf_t)
+/// - BF16_BF16 (ck::Tuple<ck::bhalf_t, ck::bhalf_t>)
+/// - FP32 (float)
+/// - FP32_FP32 (ck::Tuple<float, float>)
+/// - FP64 (double)
+/// - FP8 (ck::f8_t)
+/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t)
+/// - I8 (int8_t)
+/// - I8_I8 (ck::Tuple<int8_t, int8_t>)
+/// - U8 (uint8_t)
+template <typename Instance>
+constexpr builder::DataType conv_data_type()
+{
+    using InstTraits = InstanceTraits<Instance>;
+    using ADataType  = typename InstTraits::ADataType;
+    using enum builder::DataType;
+
+    if constexpr(std::is_same_v<ADataType, ck::half_t>)
+        return FP16;
+    else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::half_t, ck::half_t>>)
+        return FP16_FP16;
+    else if constexpr(std::is_same_v<ADataType, ck::bhalf_t>)
+        return BF16;
+    else if constexpr(std::is_same_v<ADataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>>)
+        return BF16_BF16;
+    else if constexpr(std::is_same_v<ADataType, float>)
+        return FP32;
+    else if constexpr(std::is_same_v<ADataType, ck::Tuple<float, float>>)
+        return FP32_FP32;
+    else if constexpr(std::is_same_v<ADataType, double>)
+        return FP64;
+    else if constexpr(std::is_same_v<ADataType, ck::f8_t>)
+        return FP8;
+    else if constexpr(std::is_same_v<ADataType, ck::bf8_fnuz_t>)
+        return BF8;
+    else if constexpr(std::is_same_v<ADataType, ck::bf8_ocp_t>)
+        return BF8;
+    else if constexpr(std::is_same_v<ADataType, int8_t>)
+        return I8;
+    else if constexpr(std::is_same_v<ADataType, ck::Tuple<int8_t, int8_t>>)
+        return I8_I8;
+    else if constexpr(std::is_same_v<ADataType, uint8_t>)
+        return U8;
+    else
+    {
+        report_unsupported_data_type_error<ADataType>();
+        return FP32; // Unreachable
+    }
+}
+
+// ----------------------------------------------------------------------------
+// Elementwise Operations
+// ----------------------------------------------------------------------------
+
+/// @brief Helper function to report an unsupported elementwise operation with a clear error message.
+/// @details This consteval function uses throw (not static_assert) to ensure the error is not
+/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message.
+template <typename ElementwiseOp>
+[[noreturn]] consteval void report_unsupported_elementwise_op_error()
+{
+    throw "Unsupported elementwise operation detected!\n"
+          "The elementwise operation type is not recognized.\n"
+          "Supported operations are: AddClamp, AddReluAdd, BiasBnormClamp, Bilinear, "
+          "BiasNormalizeInInferClamp, Clamp, ConvInvscale, ConvScale, ConvScaleAdd, "
+          "ConvScaleRelu, Scale, ScaleAdd, PassThrough, ScaleAddScaleAddRelu, DynamicUnaryOp, "
+          "UnaryCombinedOp, Activation_Mul2_Clamp, Activation_Mul_Clamp, Add_Activation_Mul_Clamp, "
+          "Add_Activation_Mul2_Clamp, Add_Mul_Activation_Mul_Clamp, Add_Mul2_Activation_Mul_Clamp, "
+          "UnaryConvert.\n"
+          "Please verify that your kernel instance uses a supported elementwise operation.";
+}
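conv_data_type() above is a plain if-constexpr dispatch on std::is_same_v, with tuple types covering multi-operand inputs. A compressed, self-contained sketch of the same dispatch shape (TupleLike and the enum values are illustrative stand-ins, not CK or builder types):

```cpp
#include <cstdint>
#include <type_traits>

// Illustrative stand-ins; not the real ck::Tuple or builder::DataType.
template <typename... Ts>
struct TupleLike
{
};

enum class DataType { FP32, FP32_FP32, I8, UNKNOWN };

template <typename T>
constexpr DataType data_type_of()
{
    if constexpr(std::is_same_v<T, float>)
        return DataType::FP32;
    else if constexpr(std::is_same_v<T, TupleLike<float, float>>) // multi-operand input
        return DataType::FP32_FP32;
    else if constexpr(std::is_same_v<T, std::int8_t>)
        return DataType::I8;
    else
        return DataType::UNKNOWN;
}

static_assert(data_type_of<float>() == DataType::FP32);
static_assert(data_type_of<TupleLike<float, float>>() == DataType::FP32_FP32);
static_assert(data_type_of<unsigned>() == DataType::UNKNOWN);
```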
+
+/// @brief Derives the elementwise operation from an operation functor type.
+/// @tparam ElementwiseOp Elementwise operation functor type.
+/// @return A builder::ElementwiseOperation enum value corresponding to the operation.
+/// @details This function uses the operation's type name to determine which elementwise
+/// operation is being used. The comparison is case-insensitive.
+///
+/// Supported operations include:
+/// - Activation operations: Relu, Sigmoid, Tanh, Gelu, Silu, Elu, Swish, etc.
+/// - Scaling operations: Scale, ScaleAdd, ConvScale, ConvScaleAdd, etc.
+/// - Clamping operations: Clamp, AddClamp, etc.
+/// - Combined operations: Add_Activation_Mul_Clamp, etc.
+/// - Utility operations: PassThrough, UnaryConvert, etc.
+///
+/// TODO: Consider changing this to direct checks on the types, not strings.
+template <typename ElementwiseOp>
+constexpr builder::ElementwiseOperation elementwise_op()
+{
+    using enum builder::ElementwiseOperation;
+    constexpr std::string_view name = detail::elementwise_op_name<ElementwiseOp>();
+
+    if constexpr(detail::case_insensitive_equal(name, "AddClamp"))
+        return ADD_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "AddReluAdd"))
+        return ADD_RELU_ADD;
+    else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp"))
+        return BIAS_BNORM_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Bilinear"))
+        return BILINEAR;
+    else if constexpr(detail::case_insensitive_equal(name, "BiasNormalizeInInferClamp"))
+        return BIAS_BNORM_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Clamp"))
+        return CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "ConvInvscale"))
+        return CONV_INVSCALE;
+    else if constexpr(detail::case_insensitive_equal(name, "ConvScale"))
+        return CONV_SCALE;
+    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleAdd"))
+        return CONV_SCALE_ADD;
+    else if constexpr(detail::case_insensitive_equal(name, "ConvScaleRelu"))
+        return CONV_SCALE_RELU;
+    else if constexpr(detail::case_insensitive_equal(name, "Scale"))
+        return SCALE;
+    else if constexpr(detail::case_insensitive_equal(name, "ScaleAdd"))
+        return SCALE_ADD;
+    else if constexpr(detail::case_insensitive_equal(name, "PassThrough"))
+        return PASS_THROUGH;
+    else if constexpr(detail::case_insensitive_equal(name, "ScaleAddScaleAddRelu"))
+        return SCALEADD_SCALEADD_RELU;
+    else if constexpr(detail::case_insensitive_equal(name, "DynamicUnaryOp"))
+        return DYNAMIC_UNARY_OP;
+    else if constexpr(detail::case_insensitive_equal(name, "UnaryCombinedOp"))
+        return UNARY_COMBINED_OP;
+    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul2_Clamp"))
+        return ACTIVATION_MUL2_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Activation_Mul_Clamp"))
+        return ACTIVATION_MUL_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul_Clamp"))
+        return ADD_ACTIVATION_MUL_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Add_Activation_Mul2_Clamp"))
+        return ADD_ACTIVATION_MUL2_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul_Activation_Mul_Clamp"))
+        return ADD_MUL_ACTIVATION_MUL_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "Add_Mul2_Activation_Mul_Clamp"))
+        return ADD_MUL2_ACTIVATION_MUL_CLAMP;
+    else if constexpr(detail::case_insensitive_equal(name, "UnaryConvert"))
+        return UNARY_CONVERT;
+    else if constexpr(detail::case_insensitive_equal(name, "Logistic"))
+        return LOGISTIC;
+    else if constexpr(detail::case_insensitive_equal(name, "ClippedRelu"))
+        return CLIPPED_RELU;
+    else if constexpr(detail::case_insensitive_equal(name, "Swish"))
+        return SWISH;
+    else if constexpr(detail::case_insensitive_equal(name, "Elu"))
+        return ELU;
+    else if constexpr(detail::case_insensitive_equal(name, "Power"))
+        return POWER;
+    else if constexpr(detail::case_insensitive_equal(name, "LeakyRelu"))
+        return LEAKY_RELU;
+    else if constexpr(detail::case_insensitive_equal(name, "UnaryAbs"))
+        return UNARY_ABS;
+    else if constexpr(detail::case_insensitive_equal(name, "Relu"))
+        return RELU;
+    else if constexpr(detail::case_insensitive_equal(name, "SoftRelu"))
+        return SOFT_RELU;
+    else if constexpr(detail::case_insensitive_equal(name, "Sigmoid"))
+        return SIGMOID;
+    else if constexpr(detail::case_insensitive_equal(name, "TanH"))
+        return TANH;
+    else if constexpr(detail::case_insensitive_equal(name, "Gelu"))
+        return GELU;
+    else if constexpr(detail::case_insensitive_equal(name, "Silu"))
+        return SILU;
+    else
+    {
+        report_unsupported_elementwise_op_error<ElementwiseOp>();
+        return PASS_THROUGH; // Unreachable
+    }
+}
+
+// ----------------------------------------------------------------------------
+// GEMM Padding
+// ----------------------------------------------------------------------------
+
+/// @brief Derives the GEMM padding specification from a kernel instance type.
+/// @tparam Instance A device kernel instance type.
+/// @return A builder::GemmPadding enum value corresponding to the kernel's padding configuration.
+/// @details This function extracts the GEMM specialization from the Instance's InstanceTraits
+/// and converts it to the builder framework's GemmPadding enum. The padding specification
+/// indicates which dimensions (M, N, K, O) are padded to handle non-aligned tensor sizes.
+///
+/// Supported padding configurations include:
+/// - DEFAULT: No padding
+/// - M_PADDING, N_PADDING, K_PADDING, O_PADDING: Single dimension padding
+/// - MN_PADDING, MK_PADDING, NK_PADDING, etc.: Two dimension padding
+/// - MNK_PADDING, MNO_PADDING, etc.: Three dimension padding
+/// - MNKO_PADDING: All dimensions padded
+template <typename Instance>
+constexpr builder::GemmPadding gemm_spec()
+{
+    using InstTraits = InstanceTraits<Instance>;
+    using enum builder::GemmPadding;
+    using enum ck::tensor_operation::device::GemmSpecialization;
+
+    constexpr auto spec = InstTraits::kGemmSpecialization;
+
+    switch(spec)
+    {
+    case Default: return DEFAULT;
+    case MPadding: return M_PADDING;
+    case NPadding: return N_PADDING;
+    case KPadding: return K_PADDING;
+    case MNPadding: return MN_PADDING;
+    case MKPadding: return MK_PADDING;
+    case NKPadding: return NK_PADDING;
+    case MNKPadding: return MNK_PADDING;
+    case OPadding: return O_PADDING;
+    case MOPadding: return MO_PADDING;
+    case NOPadding: return NO_PADDING;
+    case KOPadding: return KO_PADDING;
+    case MNOPadding: return MNO_PADDING;
+    case MKOPadding: return MKO_PADDING;
+    case NKOPadding: return NKO_PADDING;
+    case MNKOPadding: return MNKO_PADDING;
+    }
+}
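elementwise_op() above dispatches on the operation's type name via detail::case_insensitive_equal(). A self-contained sketch of what such a constexpr case-insensitive comparison can look like; the actual detail:: helper may be implemented differently:

```cpp
#include <cstddef>
#include <string_view>

// ASCII-only lowercase conversion usable in constant expressions.
constexpr char to_lower_ascii(char c)
{
    return (c >= 'A' && c <= 'Z') ? static_cast<char>(c - 'A' + 'a') : c;
}

// Compare two strings ignoring ASCII case, entirely at compile time.
constexpr bool ci_equal(std::string_view a, std::string_view b)
{
    if(a.size() != b.size())
        return false;
    for(std::size_t i = 0; i < a.size(); ++i)
    {
        if(to_lower_ascii(a[i]) != to_lower_ascii(b[i]))
            return false;
    }
    return true;
}

static_assert(ci_equal("PassThrough", "passthrough"));
static_assert(ci_equal("TanH", "Tanh"));
static_assert(!ci_equal("Relu", "Gelu"));
```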
+
+// ============================================================================
+// SECTION 3: PIPELINE CONFIGURATION HELPERS
+// ============================================================================
+
+/// @brief Safely extracts the pipeline version from InstanceTraits.
+/// @tparam InstTraits The InstanceTraits type to extract the pipeline version from.
+/// @return The pipeline version as a builder::PipelineVersion enum value.
+/// @details This helper function checks if the InstanceTraits has a kPipelineVersion
+/// field and extracts it if present. If not present, it returns a default value (V1).
+/// This is necessary because not all convolution types expose pipeline version information.
+template <typename InstTraits>
+constexpr builder::PipelineVersion get_pipeline_version()
+{
+    if constexpr(requires { InstTraits::kPipelineVersion; })
+    {
+        return convert_pipeline_version<InstTraits::kPipelineVersion>();
+    }
+    else
+    {
+        return builder::PipelineVersion::V1;
+    }
+}
+
+/// @brief Safely extracts the pipeline scheduler from InstanceTraits.
+/// @tparam InstTraits The InstanceTraits type to extract the pipeline scheduler from.
+/// @return The pipeline scheduler as a builder::PipelineScheduler enum value.
+/// @details This helper function checks if the InstanceTraits has a kPipelineScheduler
+/// or kLoopScheduler field and extracts it if present. If neither is present, it returns
+/// a default value (DEFAULT). This is necessary because different convolution types may
+/// expose scheduler information through different field names.
+template <typename InstTraits>
+constexpr builder::PipelineScheduler get_pipeline_scheduler()
+{
+    if constexpr(requires { InstTraits::kPipelineScheduler; })
+    {
+        return convert_pipeline_scheduler<InstTraits::kPipelineScheduler>();
+    }
+    else if constexpr(requires { InstTraits::kLoopScheduler; })
+    {
+        return convert_pipeline_scheduler<InstTraits::kLoopScheduler>();
+    }
+    else
+    {
+        return builder::PipelineScheduler::DEFAULT;
+    }
+}
+
+} // namespace ck_tile::reflect::conv
diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp
new file mode 100644
index 0000000000..00010e2d48
--- /dev/null
+++ b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp
@@ -0,0 +1,8 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
+#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
+#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
index f5f3df3159..71db59afb6 100644
--- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
@@ -74,6 +74,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
 namespace ck_tile::reflect {
 
+/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel
+struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag
+{
+};
+
 // Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 template >
 {
+    /// @brief Tag type identifying this device kernel variant
+    using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Tag;
+
     // Spatial dimension
     static constexpr int kSpatialDim = NDimSpatial;
diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
index ace1b09224..4549b76a3f 100644
---
a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -78,6 +78,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag +{ +}; + // Specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 09274d5acd..046e5c3078 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -73,6 +73,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; namespace ck_tile::reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor device kernel +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag +{ +}; + // Specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor template > { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Tag; + // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index ddcf8db476..9890563859 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -108,7 +108,8 @@ target_link_libraries(test_ckb_reference_execution PRIVATE utility) # Tests convolution trait selection and configuration add_ck_builder_test(test_ckb_conv_traits conv/ck/test_conv_traits.cpp - conv/ck/unit_instance_to_conv_traits.cpp) + conv/ck/unit_instance_to_conv_traits_features.cpp + conv/ck/unit_instance_to_conv_traits_instances.cpp) # Tests convolution problem description and parameter handling add_ck_builder_test(test_ckb_conv_description diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index b3a76e4e11..42235df2fe 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include @@ -86,72 +86,72 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) ck::half_t, // BComputeDataType false>; // DirectLoad - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, 
ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); // Verify A tile transfer info - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); - EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); - EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); - EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); - EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); // Verify B tile transfer info - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); - EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); - 
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); - EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); - EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); - EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); // Verify warp GEMM params - EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); - EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); - EXPECT_EQ(Traits::warp_gemm.m_iter, 4); - EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); // Verify output tile transfer info - EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); - EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); - EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); - EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); - EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1); + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle @@ -214,30 +214,30 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) ck::LoopScheduler::Default, // LoopSched 1>; // NumGroupsToMerge - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); 
- EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); } // Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) @@ -298,29 +298,29 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) ck::half_t, // BComputeDataType ck::LoopScheduler::Default>; // LoopSched - // Use ConvTraits to extract compile-time information - using Traits = ck_tile::reflect::conv::ConvTraits; + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); // Verify signature information - EXPECT_EQ(Traits::spatial_dim, 2); - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); - EXPECT_THAT(Traits::layout, + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); - EXPECT_EQ(Traits::data_type, DataType::FP16); - EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH); - EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); // Verify algorithm information - EXPECT_EQ(Traits::thread_block_size, 256); + EXPECT_EQ(traits.thread_block_size, 256); // Verify tile dimensions - EXPECT_EQ(Traits::tile_dims.m, 128); - EXPECT_EQ(Traits::tile_dims.n, 128); - EXPECT_EQ(Traits::tile_dims.k, 16); + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); } } // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp 
b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp deleted file mode 100644 index 9d6fab19d1..0000000000 --- a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits.cpp +++ /dev/null @@ -1,1127 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -// ============================================================================ -// Unit Tests for InstanceTraits to ConvTraits Conversion -// ============================================================================ -// -// PURPOSE: -// -------- -// These tests verify the conversion layer between InstanceTraits (low-level -// template parameter extraction) and ConvTraits (high-level semantic traits). -// The conversion transforms raw CK kernel parameters into builder-friendly -// enums and structures. -// -// DESIGN RATIONALE: -// ----------------- -// ConvTraits uses a single generic specialization that works with any Device -// class satisfying the IsXdlFwdConv concept. This use of concepts is fragile -// and introduces extra complexity. We want to refector to just use functions -// for this conversion. -// -// These tests are intentionally verbose and repetitive to provide maximum -// coverage during refactoring. Once the refactoring is complete and stable, -// they can be simplified or consolidated. -// -// TEST COVERAGE: -// -------------- -// 1. Enum conversion functions (pipeline version, scheduler, etc.) -// 2. Signature extraction (direction, specialization, layout, data type) -// 3. Full transformation verification for each XDL Device class template: -// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -// - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle -// - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor -// -// NOTE: WMMA and DL (Direct Load) variants are not covered as they don't -// satisfy the IsXdlFwdConv concept (different tile parameter structure). 
-// ============================================================================ - -#include "ck/utility/scheduler_enum.hpp" -#include "ck_tile/builder/types.hpp" -#include -#include - -#include -#include -#include -#include -#include - -namespace { - -using ck_tile::builder::ConvDirection; -using ck_tile::builder::DataType; -using ck_tile::builder::ElementwiseOperation; -using ck_tile::builder::GemmPadding; -using ck_tile::builder::PipelineScheduler; -using ck_tile::builder::PipelineVersion; -using ck_tile::builder::TensorLayout; -using ::testing::ElementsAre; - -// ============================================================================ -// Test Enum Conversion Functions -// ============================================================================ - -TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion) -{ - using ck_tile::reflect::conv::convert_pipeline_version; - using enum ::ck::BlockGemmPipelineVersion; - using enum ::ck_tile::builder::PipelineVersion; - - EXPECT_EQ(convert_pipeline_version(), V1); - EXPECT_EQ(convert_pipeline_version(), V2); - EXPECT_EQ(convert_pipeline_version(), V3); - EXPECT_EQ(convert_pipeline_version(), V4); - EXPECT_EQ(convert_pipeline_version(), V5); -} - -TEST(InstanceToConvTraits, ConvertsPipelineVersion) -{ - using ck_tile::reflect::conv::convert_pipeline_version; - using enum ck::PipelineVersion; - using enum PipelineVersion; - - EXPECT_EQ(convert_pipeline_version(), V1); - EXPECT_EQ(convert_pipeline_version(), V2); - EXPECT_EQ(convert_pipeline_version(), V4); - EXPECT_EQ(convert_pipeline_version(), WEIGHT_ONLY); -} - -TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler) -{ - using ck_tile::reflect::conv::convert_pipeline_scheduler; - using enum ck::BlockGemmPipelineScheduler; - using enum PipelineScheduler; - - EXPECT_EQ(convert_pipeline_scheduler(), INTRAWAVE); - EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); -} - -TEST(InstanceToConvTraits, ConvertsLoopScheduler) -{ - using ck_tile::reflect::conv::convert_pipeline_scheduler; - using enum ck::LoopScheduler; - using enum PipelineScheduler; - - EXPECT_EQ(convert_pipeline_scheduler(), DEFAULT); - EXPECT_EQ(convert_pipeline_scheduler(), INTERWAVE); -} - -// ============================================================================ -// Test Convolution Direction Detection -// ============================================================================ - -TEST(InstanceToConvTraits, DetectsForwardDirection) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 
1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::direction, ConvDirection::FORWARD); -} - -// ============================================================================ -// Test Convolution Specialization Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsDefaultSpecialization) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); -} - -TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0); -} - -// ============================================================================ -// Test Layout Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsGnhwcLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - 
ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); -} - -TEST(InstanceToConvTraits, ExtractsNhwgcLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::NHWGC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::NHWGK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK)); -} - -TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::NGCHW, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::NGKHW, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW)); -} - -TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout) -{ - using DeviceInstance = - 
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::NGCHW, - ck::tensor_layout::convolution::GKCYX, - ck::Tuple<>, - ck::tensor_layout::convolution::NGKHW, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_THAT(Traits::layout, - ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW)); -} - -// ============================================================================ -// Test Data Type Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsFp16DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::FP16); -} - -TEST(InstanceToConvTraits, ExtractsBf16DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::bhalf_t, - ck::bhalf_t, - float, - ck::bhalf_t, - ck::Tuple<>, - ck::bhalf_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - 
ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::bhalf_t, - ck::bhalf_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::BF16); -} - -TEST(InstanceToConvTraits, ExtractsFp32DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - float, - float, - float, - float, - ck::Tuple<>, - float, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - float, - float, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::FP32); -} - -TEST(InstanceToConvTraits, ExtractsI8DataType) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - int8_t, - int8_t, - int32_t, - int8_t, - ck::Tuple<>, - int8_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - int8_t, - int8_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::data_type, DataType::I8); -} - -// ============================================================================ -// Test GEMM Padding Detection -// ============================================================================ - -TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - 
ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::gemm_padding, GemmPadding::DEFAULT); -} - -TEST(InstanceToConvTraits, ExtractsMnkGemmPadding) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::Tuple<>, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - float, - ck::half_t, - ck::Tuple<>, - ck::half_t, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::MNKPadding, - 256, - 128, - 128, - 16, - 8, - 8, - 32, - 32, - 4, - 4, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - ck::Sequence<4, 64, 1>, - ck::Sequence<1, 0, 2>, - ck::Sequence<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - ck::Sequence<1, 32, 1, 8>, - 8, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::half_t, - ck::half_t, - false>; - - using Traits = ck_tile::reflect::conv::ConvTraits; - - EXPECT_EQ(Traits::gemm_padding, GemmPadding::MNK_PADDING); -} - -// ============================================================================ -// Comprehensive Transformation Tests - Per Device Class Template -// ============================================================================ -// These tests verify the complete InstanceTraits → ConvTraits transformation -// for each forward convolution Device class template. They are verbose to -// provide maximum safety during refactoring. 
-// ============================================================================ - -TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 4, // NXdlPerWave - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths - 8, // CDEBlockTransferScalarPerVector_NPerBlock - ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched - ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer - ck::half_t, // AComputeDataType - ck::half_t, // BComputeDataType - false>; // DirectLoad - - using InstTraits = ck_tile::reflect::InstanceTraits; - using ConvTraits = ck_tile::reflect::conv::ConvTraits; - - // Verify signature information - EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); - EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); - EXPECT_EQ(ConvTraits::data_type, DataType::FP16); - EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); - - // Verify tile dimensions - EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); - - // Verify pipeline configuration - EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::INTRAWAVE); - EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); -} - -TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - 
ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 1, // NumGemmKPrefetchStage - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 4, // NXdlPerWave - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths - 8, // CDEBlockTransferScalarPerVector_NPerBlock - ck::half_t, // AComputeDataType - ck::half_t, // BComputeDataType - ck::LoopScheduler::Default, // LoopSched - 1>; // NumGroupsToMerge - - using InstTraits = ck_tile::reflect::InstanceTraits; - using ConvTraits = ck_tile::reflect::conv::ConvTraits; - - // Verify signature information - EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); - EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); - EXPECT_EQ(ConvTraits::data_type, DataType::FP16); - EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); - - // Verify tile dimensions - EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); - - // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler) - EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::DEFAULT); - EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); -} - -TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor) -{ - using DeviceInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - 2, // NDimSpatial - ck::tensor_layout::convolution::GNHWC, // ALayout - ck::tensor_layout::convolution::GKYXC, // BLayout - ck::Tuple<>, // DsLayout - ck::tensor_layout::convolution::GNHWK, // ELayout - ck::half_t, // ADataType - ck::half_t, // BDataType - float, // AccDataType - ck::half_t, // CShuffleDataType - ck::Tuple<>, // DsDataType - ck::half_t, // EDataType - ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation - ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, - ck::tensor_operation::device::GemmSpecialization::Default, - 1, // NumGemmKPrefetchStage - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 4, // 
NXdlPerWave - ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths - ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths - 8, // CDEBlockTransferScalarPerVector_NPerBlock - ck::half_t, // AComputeDataType - ck::half_t, // BComputeDataType - ck::LoopScheduler::Default>; // LoopSched - - using InstTraits = ck_tile::reflect::InstanceTraits; - using ConvTraits = ck_tile::reflect::conv::ConvTraits; - - // Verify signature information - EXPECT_EQ(ConvTraits::spatial_dim, InstTraits::kSpatialDim); - EXPECT_EQ(ConvTraits::direction, ConvDirection::FORWARD); - EXPECT_EQ(ConvTraits::data_type, DataType::FP16); - EXPECT_EQ(ConvTraits::gemm_padding, GemmPadding::DEFAULT); - - // Verify tile dimensions - EXPECT_EQ(ConvTraits::tile_dims.m, InstTraits::kMPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.n, InstTraits::kNPerBlock); - EXPECT_EQ(ConvTraits::tile_dims.k, InstTraits::kKPerBlock); - - // Verify pipeline configuration - EXPECT_EQ(ConvTraits::pipeline_scheduler, PipelineScheduler::DEFAULT); - EXPECT_EQ(ConvTraits::pipeline_version, PipelineVersion::V1); -} - -} // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp new file mode 100644 index 0000000000..72269c38ac --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_features.cpp @@ -0,0 +1,800 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for Individual Conversion Functions +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify individual conversion and extraction functions that +// transform raw CK kernel parameters into semantic types. Each test +// focuses on a single conversion function to ensure it correctly maps +// CK types to builder enums and structures. +// +// TEST COVERAGE: +// -------------- +// 1. Enum Conversions: +// - Pipeline versions (BlockGemmPipelineVersion and PipelineVersion) +// - Pipeline schedulers (BlockGemmPipelineScheduler and LoopScheduler) +// +// 2. Elementwise Operations (14 operations): +// - PassThrough, Scale, Relu, Gelu, Sigmoid, Tanh, ScaleAdd +// - Silu, Swish, Elu, LeakyRelu, UnaryConvert, ConvScale, ConvScaleAdd +// +// 3. Convolution Properties: +// - Direction detection (Forward) +// - Specializations (Default, Filter1x1Pad0, Filter1x1Stride1Pad0, +// Filter3x3, OddC) +// +// 4. Layout Detection: +// - 1D layouts (GNWC, NWGC, NGCW) +// - 2D layouts (GNHWC, NHWGC, NGCHW with GKYXC/GKCYX) +// - 3D layouts (GNDHWC, NDHWGC, NGCDHW) +// +// 5. Data Type Detection: +// - FP16, BF16, FP32, I8 +// +// 6. 
Pipeline Configuration: +// - Pipeline versions (V2, V3) +// - Schedulers (Interwave) +// +// 7. GEMM Padding Variations (17 types): +// - Default, MNK, M, N, K, MN, MK, NK +// - O, MO, NO, KO, MNO, MKO, NKO, MNKO +// ============================================================================ + +#include "ck/utility/scheduler_enum.hpp" +#include "ck_tile/builder/types.hpp" +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +using ::ck_tile::builder::ConvDirection; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::ElementwiseOperation; +using ::ck_tile::builder::GemmPadding; +using ::ck_tile::builder::PipelineScheduler; +using ::ck_tile::builder::PipelineVersion; +using ::ck_tile::builder::TensorLayout; +using ::testing::ElementsAre; + +// ============================================================================ +// Test Helper Templates +// ============================================================================ +// These templates reduce boilerplate by providing sensible defaults for +// template parameters that don't vary in most tests. +// ============================================================================ + +namespace defaults { +// Default values used across most tests +static constexpr int kBlockSize = 256; +static constexpr int kMPerBlock = 128; +static constexpr int kNPerBlock = 128; +static constexpr int kKPerBlock = 16; +static constexpr int kAK1 = 8; +static constexpr int kBK1 = 8; +static constexpr int kMPerXDL = 32; +static constexpr int kNPerXDL = 32; +static constexpr int kMXdlPerWave = 4; +static constexpr int kNXdlPerWave = 4; +static constexpr int kABlockTransferSrcVectorDim = 2; +static constexpr int kABlockTransferSrcScalarPerVector = 8; +static constexpr int kABlockTransferDstScalarPerVector_AK1 = 8; +static constexpr int kABlockLdsExtraM = 1; +static constexpr int kBBlockTransferSrcVectorDim = 2; +static constexpr int kBBlockTransferSrcScalarPerVector = 8; +static constexpr int kBBlockTransferDstScalarPerVector_BK1 = 8; +static constexpr int kBBlockLdsExtraN = 1; +static constexpr int kCShuffleMXdlPerWavePerShuffle = 1; +static constexpr int kCShuffleNXdlPerWavePerShuffle = 1; +static constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8; +static constexpr bool kDirectLoad = false; + +using DefaultABlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; +using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; +using DefaultABlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>; +using DefaultBBlockTransferThreadClusterLengths = ck::Sequence<4, 64, 1>; +using DefaultBBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>; +using DefaultBBlockTransferSrcAccessOrder = ck::Sequence<1, 0, 2>; +using DefaultCDEBlockTransferClusterLengths = ck::Sequence<1, 32, 1, 8>; +} // namespace defaults + +// DeviceInstanceForTests - V3 variant with sensible defaults +template +using DeviceInstanceForTests_V3 = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + NDimSpatial, + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + AccDataType, + ADataType, + ck::Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ConvForwardSpecialization, + GemmSpec, + defaults::kBlockSize, + defaults::kMPerBlock, + defaults::kNPerBlock, + defaults::kKPerBlock, + defaults::kAK1, + defaults::kBK1, + defaults::kMPerXDL, + defaults::kNPerXDL, + defaults::kMXdlPerWave, + defaults::kNXdlPerWave, + 
defaults::DefaultABlockTransferThreadClusterLengths, + defaults::DefaultABlockTransferThreadClusterArrangeOrder, + defaults::DefaultABlockTransferSrcAccessOrder, + defaults::kABlockTransferSrcVectorDim, + defaults::kABlockTransferSrcScalarPerVector, + defaults::kABlockTransferDstScalarPerVector_AK1, + defaults::kABlockLdsExtraM, + defaults::DefaultBBlockTransferThreadClusterLengths, + defaults::DefaultBBlockTransferThreadClusterArrangeOrder, + defaults::DefaultBBlockTransferSrcAccessOrder, + defaults::kBBlockTransferSrcVectorDim, + defaults::kBBlockTransferSrcScalarPerVector, + defaults::kBBlockTransferDstScalarPerVector_BK1, + defaults::kBBlockLdsExtraN, + defaults::kCShuffleMXdlPerWavePerShuffle, + defaults::kCShuffleNXdlPerWavePerShuffle, + defaults::DefaultCDEBlockTransferClusterLengths, + defaults::kCDEBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ADataType, + BDataType, + defaults::kDirectLoad>; + +// Test case helper for specialization testing +template +using SpecializationTestInstance = + DeviceInstanceForTests_V3<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + Spec>; + +// Test case helper for layout testing (1D, 2D, 3D) +template +using LayoutTestInstance = DeviceInstanceForTests_V3; + +// Test case helper for data type testing +template +using DataTypeTestInstance = DeviceInstanceForTests_V3<2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + DataType, + DataType, + DataType, + AccDataType>; + +// Test case helper for pipeline version testing +template +using PipelineVersionTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + ck::BlockGemmPipelineScheduler::Intrawave, + PipelineVer>; + +// Test case helper for pipeline scheduler testing +template +using PipelineSchedulerTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + Scheduler>; + +// Test case helper for GEMM padding testing +template +using GemmPaddingTestInstance = DeviceInstanceForTests_V3< + 2, + ck::tensor_layout::convolution::GNHWC, + ck::tensor_layout::convolution::GKYXC, + ck::tensor_layout::convolution::GNHWK, + ck::half_t, + ck::half_t, + ck::half_t, + float, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + 
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
+        GemmSpec>;
+
+// ============================================================================
+// Test Enum Conversion Functions
+// ============================================================================
+
+TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion)
+{
+    using ck_tile::reflect::conv::convert_pipeline_version;
+    using enum ::ck::BlockGemmPipelineVersion;
+    using enum ::ck_tile::builder::PipelineVersion;
+    EXPECT_EQ(convert_pipeline_version<v1>(), V1);
+    EXPECT_EQ(convert_pipeline_version<v2>(), V2);
+    EXPECT_EQ(convert_pipeline_version<v3>(), V3);
+    EXPECT_EQ(convert_pipeline_version<v4>(), V4);
+    EXPECT_EQ(convert_pipeline_version<v5>(), V5);
+}
+
+TEST(InstanceToConvTraits, ConvertsPipelineVersion)
+{
+    using ck_tile::reflect::conv::convert_pipeline_version;
+    using enum ck::PipelineVersion;
+    using enum PipelineVersion;
+    EXPECT_EQ(convert_pipeline_version<v1>(), V1);
+    EXPECT_EQ(convert_pipeline_version<v2>(), V2);
+    EXPECT_EQ(convert_pipeline_version<v4>(), V4);
+    EXPECT_EQ(convert_pipeline_version<weight_only>(), WEIGHT_ONLY);
+}
+
+TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler)
+{
+    using ck_tile::reflect::conv::convert_pipeline_scheduler;
+    using enum ck::BlockGemmPipelineScheduler;
+    using enum PipelineScheduler;
+    EXPECT_EQ(convert_pipeline_scheduler<Intrawave>(), INTRAWAVE);
+    EXPECT_EQ(convert_pipeline_scheduler<Interwave>(), INTERWAVE);
+}
+
+TEST(InstanceToConvTraits, ConvertsLoopScheduler)
+{
+    using ck_tile::reflect::conv::convert_pipeline_scheduler;
+    using enum ck::LoopScheduler;
+    using enum PipelineScheduler;
+    EXPECT_EQ(convert_pipeline_scheduler<Default>(), DEFAULT);
+    EXPECT_EQ(convert_pipeline_scheduler<Interwave>(), INTERWAVE);
+}
+
+// ============================================================================
+// Test Elementwise Operations
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsPassThroughOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::PassThrough>();
+    EXPECT_EQ(op, PASS_THROUGH);
+}
+
+TEST(InstanceToConvTraits, ExtractsScaleOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Scale>();
+    EXPECT_EQ(op, SCALE);
+}
+
+TEST(InstanceToConvTraits, ExtractsReluOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Relu>();
+    EXPECT_EQ(op, RELU);
+}
+
+TEST(InstanceToConvTraits, ExtractsGeluOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Gelu>();
+    EXPECT_EQ(op, GELU);
+}
+
+TEST(InstanceToConvTraits, ExtractsSigmoidOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Sigmoid>();
+    EXPECT_EQ(op, SIGMOID);
+}
+
+TEST(InstanceToConvTraits, ExtractsTanhOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::TanH>();
+    EXPECT_EQ(op, TANH);
+}
+
+TEST(InstanceToConvTraits, ExtractsScaleAddOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ScaleAdd>();
+    EXPECT_EQ(op, SCALE_ADD);
+}
+
+TEST(InstanceToConvTraits, ExtractsSiluOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Silu>();
+    EXPECT_EQ(op, SILU);
+}
+
+TEST(InstanceToConvTraits, ExtractsSwishOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Swish>();
+    EXPECT_EQ(op, SWISH);
+}
+
+TEST(InstanceToConvTraits, ExtractsEluOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::Elu>();
+    EXPECT_EQ(op, ELU);
+}
+
+TEST(InstanceToConvTraits, ExtractsLeakyReluOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::LeakyRelu>();
+    EXPECT_EQ(op, LEAKY_RELU);
+}
+
+TEST(InstanceToConvTraits, ExtractsUnaryConvertOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::UnaryConvert>();
+    EXPECT_EQ(op, UNARY_CONVERT);
+}
+
+TEST(InstanceToConvTraits, ExtractsConvScaleOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ConvScale>();
+    EXPECT_EQ(op, CONV_SCALE);
+}
+
+TEST(InstanceToConvTraits, ExtractsConvScaleAddOperation)
+{
+    using enum ElementwiseOperation;
+    constexpr auto op =
+        ck_tile::reflect::conv::elementwise_op<ck::tensor_operation::element_wise::ConvScaleAdd>();
+    EXPECT_EQ(op, CONV_SCALE_ADD);
+}
+
+// ============================================================================
+// Test Convolution Direction Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, DetectsForwardDirection)
+{
+    using DeviceInstance = DeviceInstanceForTests_V3<>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
+}
+
+// ============================================================================
+// Test Convolution Specialization Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
+{
+    using DeviceInstance = SpecializationTestInstance<
+        ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
+}
+
+TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
+{
+    using DeviceInstance = SpecializationTestInstance<
+        ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
+}
+
+TEST(InstanceToConvTraits, ExtractsFilter1x1Stride1Pad0Specialization)
+{
+    using DeviceInstance = SpecializationTestInstance<
+        ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.conv_specialization,
+              ck_tile::builder::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0);
+}
+
+TEST(InstanceToConvTraits, ExtractsFilter3x3Specialization)
+{
+    using DeviceInstance = SpecializationTestInstance<
+        ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_3x3);
+}
+
+TEST(InstanceToConvTraits, ExtractsOddCSpecialization)
+{
+    using DeviceInstance = SpecializationTestInstance<
+        ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::ODD_C);
+}
+
+// ============================================================================
+// Test 1D Convolution Layout Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsGnwcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<1,
+                                              ck::tensor_layout::convolution::GNWC,
+                                              ck::tensor_layout::convolution::GKXC,
+                                              ck::tensor_layout::convolution::GNWK>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.spatial_dim, 1);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::GNWC, TensorLayout::GKXC, TensorLayout::GNWK));
+}
+
+TEST(InstanceToConvTraits, ExtractsNwgcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<1,
+                                              ck::tensor_layout::convolution::NWGC,
+                                              ck::tensor_layout::convolution::GKXC,
+                                              ck::tensor_layout::convolution::NWGK>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.spatial_dim, 1);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NWGC, TensorLayout::GKXC, TensorLayout::NWGK));
+}
+
+TEST(InstanceToConvTraits, ExtractsNgcwLayout)
+{
+    using DeviceInstance = LayoutTestInstance<1,
+                                              ck::tensor_layout::convolution::NGCW,
+                                              ck::tensor_layout::convolution::GKXC,
+                                              ck::tensor_layout::convolution::NGKW>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.spatial_dim, 1);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NGCW, TensorLayout::GKXC, TensorLayout::NGKW));
+}
+
+// ============================================================================
+// Test 2D Convolution Layout Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsGnhwcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<2,
+                                              ck::tensor_layout::convolution::GNHWC,
+                                              ck::tensor_layout::convolution::GKYXC,
+                                              ck::tensor_layout::convolution::GNHWK>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
+}
+
+TEST(InstanceToConvTraits, ExtractsNhwgcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<2,
+                                              ck::tensor_layout::convolution::NHWGC,
+                                              ck::tensor_layout::convolution::GKYXC,
+                                              ck::tensor_layout::convolution::NHWGK>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK));
+}
+
+TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<2,
+                                              ck::tensor_layout::convolution::NGCHW,
+                                              ck::tensor_layout::convolution::GKYXC,
+                                              ck::tensor_layout::convolution::NGKHW>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW));
+}
+
+TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout)
+{
+    using DeviceInstance = LayoutTestInstance<2,
+                                              ck::tensor_layout::convolution::NGCHW,
+                                              ck::tensor_layout::convolution::GKCYX,
+                                              ck::tensor_layout::convolution::NGKHW>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW));
+}
+
+// ============================================================================
+// Test 3D Convolution Layout Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsGndhwcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<3,
+                                              ck::tensor_layout::convolution::GNDHWC,
+                                              ck::tensor_layout::convolution::GKZYXC,
+                                              ck::tensor_layout::convolution::GNDHWK>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.spatial_dim, 3);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK));
+}
+
+TEST(InstanceToConvTraits, ExtractsNdhwgcLayout)
+{
+    using DeviceInstance = LayoutTestInstance<3,
+                                              ck::tensor_layout::convolution::NDHWGC,
+                                              ck::tensor_layout::convolution::GKZYXC,
+                                              ck::tensor_layout::convolution::NDHWGK>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.spatial_dim, 3);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NDHWGC, TensorLayout::GKZYXC, TensorLayout::NDHWGK));
+}
+
+TEST(InstanceToConvTraits, ExtractsNgcdhwLayout)
+{
+    using DeviceInstance = LayoutTestInstance<3,
+                                              ck::tensor_layout::convolution::NGCDHW,
+                                              ck::tensor_layout::convolution::GKZYXC,
+                                              ck::tensor_layout::convolution::NGKDHW>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.spatial_dim, 3);
+    EXPECT_THAT(traits.layout,
+                ElementsAre(TensorLayout::NGCDHW, TensorLayout::GKZYXC, TensorLayout::NGKDHW));
+}
+
+// ============================================================================
+// Test Data Type Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsFp16DataType)
+{
+    using DeviceInstance = DataTypeTestInstance<ck::half_t>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.data_type, DataType::FP16);
+}
+
+TEST(InstanceToConvTraits, ExtractsBf16DataType)
+{
+    using DeviceInstance = DataTypeTestInstance<ck::bhalf_t>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.data_type, DataType::BF16);
+}
+
+TEST(InstanceToConvTraits, ExtractsFp32DataType)
+{
+    using DeviceInstance = DataTypeTestInstance<float>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.data_type, DataType::FP32);
+}
+
+TEST(InstanceToConvTraits, ExtractsI8DataType)
+{
+    using DeviceInstance = DataTypeTestInstance<int8_t>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.data_type, DataType::I8);
+}
+
+// ============================================================================
+// Test Pipeline Version Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsPipelineV2)
+{
+    using DeviceInstance = PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v2>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V2);
+}
+
+TEST(InstanceToConvTraits, ExtractsPipelineV3)
+{
+    using DeviceInstance = PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v3>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.pipeline_version, PipelineVersion::V3);
+}
+
+TEST(InstanceToConvTraits, ExtractsInterwaveScheduler)
+{
+    using DeviceInstance = PipelineSchedulerTestInstance<ck::BlockGemmPipelineScheduler::Interwave>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTERWAVE);
+}
+
+// ============================================================================
+// Test GEMM Padding Detection
+// ============================================================================
+
+TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::Default>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT);
+}
+
+TEST(InstanceToConvTraits, ExtractsMnkGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::MNK_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsMPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::M_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsNPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::N_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsKPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::KPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::K_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsMnPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::MN_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsMkPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MKPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::MK_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsNkPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NKPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::NK_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsOPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::OPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::O_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsMoPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MOPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::MO_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsNoPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::NOPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::NO_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsKoPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::KOPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::KO_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsMnoPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MNOPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+    EXPECT_EQ(traits.gemm_padding, GemmPadding::MNO_PADDING);
+}
+
+TEST(InstanceToConvTraits, ExtractsMkoPaddingGemmPadding)
+{
+    using DeviceInstance =
+        GemmPaddingTestInstance<ck::tensor_operation::device::GemmSpecialization::MKOPadding>;
+    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
+
EXPECT_EQ(traits.gemm_padding, GemmPadding::MKO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsNkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::NKO_PADDING); +} + +TEST(InstanceToConvTraits, ExtractsMnkoPaddingGemmPadding) +{ + using DeviceInstance = + GemmPaddingTestInstance; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + EXPECT_EQ(traits.gemm_padding, GemmPadding::MNKO_PADDING); +} + +} // anonymous namespace diff --git a/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp new file mode 100644 index 0000000000..38942f9d45 --- /dev/null +++ b/experimental/builder/test/conv/ck/unit_instance_to_conv_traits_instances.cpp @@ -0,0 +1,262 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// ============================================================================ +// Unit Tests for Complete Device Instance Transformations +// ============================================================================ +// +// PURPOSE: +// -------- +// These tests verify the complete instance_to_conv_traits transformation +// for entire Device class templates. Each test validates that all traits +// are correctly extracted from a specific Device class instantiation. +// +// TEST COVERAGE: +// -------------- +// Complete transformation verification for each XDL Device class template: +// 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +// 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +// 3. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +// +// Each test verifies: +// - Spatial dimension extraction +// - Convolution direction +// - Data type detection +// - GEMM padding configuration +// - Tile dimensions (M, N, K per block) +// - Pipeline scheduler and version +// ============================================================================ + +#include + +#include +#include +#include +#include +#include + +namespace { + +using ::ck_tile::builder::ConvDirection; +using ::ck_tile::builder::DataType; +using ::ck_tile::builder::GemmPadding; +using ::ck_tile::builder::PipelineScheduler; +using ::ck_tile::builder::PipelineVersion; + +// ============================================================================ +// Comprehensive Transformation Tests - Per Device Class Template +// ============================================================================ +// These tests verify the complete InstanceTraits → ConvTraits transformation +// for each forward convolution Device class template. 
+// ============================================================================ + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + 
ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration (uses LoopScheduler instead of BlockGemmPipelineScheduler) + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + ck::tensor_operation::device::GemmSpecialization::Default, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // 
ABlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + using InstTraits = ck_tile::reflect::InstanceTraits; + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + // Verify signature information + EXPECT_EQ(traits.spatial_dim, InstTraits::kSpatialDim); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.gemm_padding, GemmPadding::DEFAULT); + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, InstTraits::kMPerBlock); + EXPECT_EQ(traits.tile_dims.n, InstTraits::kNPerBlock); + EXPECT_EQ(traits.tile_dims.k, InstTraits::kKPerBlock); + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +} // anonymous namespace diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 7cb0ae20c3..cc343f6f69 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -2108,7 +2108,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle std::unique_ptr describe() const override { - static_assert(ck_tile::reflect::conv::HasConvTraits, + static_assert(ck_tile::reflect::HasConvTraits, "ConvTraits specialization not found for this device operation. " "If you modified the template parameters of this class, ensure that " "the corresponding ConvTraits specialization in " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 4f410d0cce..c9fb8ca3f6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -1282,7 +1282,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor std::unique_ptr describe() const override { static_assert( - ck_tile::reflect::conv::HasConvTraits, + ck_tile::reflect::HasConvTraits, "ConvTraits specialization not found for this device operation. " "If you modified the template parameters of this class, ensure that " "the corresponding ConvTraits specialization in "
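For orientation, the sketch below shows how the runtime traits produced by instance_to_conv_traits are meant to be consumed outside of the unit tests: extract the traits of a device instance once, then branch on semantic fields rather than raw template parameters. This is a minimal, hypothetical usage sketch and is not part of the patch; the include paths and the report_conv_traits helper are assumptions, and Instance stands for any fully specified forward-convolution device operation, for example the DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instantiation used in TransformsFwdMultipleAbdXdlCShuffleV3 above.

// Hypothetical usage sketch (not part of this patch). Include paths are
// assumptions; adjust them to the actual builder/reflect layout.
#include "ck_tile/builder/reflect/instance_to_conv_traits.hpp" // instance_to_conv_traits<T>()
#include "ck_tile/builder/types.hpp"                           // builder enums (DataType, ...)

#include <iostream>

// 'Instance' is assumed to be a fully specified CK forward-convolution device
// operation for which an instance_to_conv_traits specialization exists.
template <typename Instance>
void report_conv_traits()
{
    // One runtime ConvTraits value, extracted from the instance's compile-time traits.
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    std::cout << "spatial_dim = " << traits.spatial_dim << ", tile (m, n, k) = ("
              << traits.tile_dims.m << ", " << traits.tile_dims.n << ", "
              << traits.tile_dims.k << ")\n";

    // Branch on semantic enums instead of raw template parameters.
    if(traits.data_type == ck_tile::builder::DataType::FP16 &&
       traits.pipeline_version == ck_tile::builder::PipelineVersion::V3)
    {
        std::cout << "fp16 / pipeline v3 candidate\n";
    }
}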