Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp32")
{
return run_gemm_example_prec_type<GemmConfig, Invoker, float>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig,
Invoker,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig,
Expand Down
44 changes: 32 additions & 12 deletions example/ck_tile/03_gemm/gemm_basic_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@ struct BasicInvoker
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
{
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
}

constexpr bool is_fp32_input = std::is_same_v<ADataType, float>;
[[maybe_unused]] constexpr bool is_tf32_compute =
std::is_same_v<ComputeDataType, ck_tile::tf32_t>;

// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t K_Tile = 64;

#if CK_TILE_USE_WMMA
Expand All @@ -37,13 +42,24 @@ struct BasicInvoker
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(CK_GFX950_SUPPORT)
// gfx950: fp32 uses 16x16x16 tile (native MFMA)
// tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation)
constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
// Fallback or other architectures
constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2;
constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = is_fp32_input ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_fp32_input ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif

Expand All @@ -61,11 +77,15 @@ struct BasicInvoker
BLayout,
CLayout>;

using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;

using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

Expand Down
11 changes: 11 additions & 0 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
Invoker,
ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig<float, float>,
Invoker,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
Expand Down
21 changes: 13 additions & 8 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ struct SplitKTwoStageInvoker
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)

{
Expand Down Expand Up @@ -61,13 +62,17 @@ struct SplitKTwoStageInvoker
GemmConfig::Preshuffle>;
constexpr auto scheduler = GemmConfig::Scheduler;

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
Expand Down
7 changes: 7 additions & 0 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig<float>, float, float, float, ck_tile::tf32_t>(
a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Expand Down
20 changes: 19 additions & 1 deletion example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,24 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
// Primary template, declared here without a definition: it is keyed on an (A, B, C)
// element-type triple, with B and C defaulting to the A type. The members come from
// the explicit specializations.
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

// Single-precision configuration: the A and B operands, the accumulator and the
// C output are all plain `float`. Because of the defaulted parameters this is the
// specialization for GemmTypeConfig<float, float, float>.
template <>
struct GemmTypeConfig<float>
{
    // Operand element types.
    using ADataType = float;
    using BDataType = ADataType;

    // Accumulation and output element types.
    using AccDataType = float;
    using CDataType   = AccDataType;
};

// Configuration selected by the tag triple (tf32_t, tf32_t, float). Every member
// type it exposes is plain `float`; the tf32 tag only chooses this specialization
// and does not appear in any of the aliases below.
// NOTE(review): presumably tf32 is a compute-only format applied to float-stored
// operands -- confirm against the pipeline that consumes ComputeDataType.
template <>
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
{
    // Operand element types.
    using ADataType = float;
    using BDataType = ADataType;

    // Accumulation and output element types.
    using AccDataType = float;
    using CDataType   = AccDataType;
};

template <>
struct GemmTypeConfig<ck_tile::half_t>
{
Expand Down Expand Up @@ -446,7 +464,7 @@ inline auto create_args()
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
Expand Down
7 changes: 7 additions & 0 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig<float>, float, float, float, ck_tile::tf32_t>(
a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Expand Down
19 changes: 12 additions & 7 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ struct WeightPreshuffleInvoker
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)

{
Expand Down Expand Up @@ -48,12 +49,16 @@ struct WeightPreshuffleInvoker
GemmConfig::Preshuffle>;
constexpr auto scheduler = GemmConfig::Scheduler;

using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;

using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
Expand Down
57 changes: 32 additions & 25 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ template <typename GemmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
typename ComputeDataType = ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
Expand Down Expand Up @@ -151,7 +152,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
DsLayout,
CLayout,
true,
CDEElementWise>(
CDEElementWise,
ComputeDataType>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
Expand All @@ -169,7 +171,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
DsLayout,
CLayout,
false,
CDEElementWise>(
CDEElementWise,
ComputeDataType>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
Expand Down Expand Up @@ -209,11 +212,12 @@ std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_ge
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout = ck_tile::tensor_layout::gemm::RowMajor,
typename BLayout = ck_tile::tensor_layout::gemm::ColumnMajor,
typename CLayout = ck_tile::tensor_layout::gemm::RowMajor,
typename ComputeDataType = ADataType>
int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
Expand Down Expand Up @@ -349,21 +353,22 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent,
flush_cache,
rotating_count);
CLayout,
ComputeDataType>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent,
flush_cache,
rotating_count);

c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

Expand Down Expand Up @@ -393,7 +398,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,

if(arg_parser.get_int("v") == 1)
{
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType, ComputeDataType>(
a_m_k, b_k_n, c_m_n_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
Expand Down Expand Up @@ -427,7 +432,9 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
CLayout,
ComputeDataType>(
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);

c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());

Expand Down
11 changes: 8 additions & 3 deletions example/ck_tile/03_gemm/run_gemm_example_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
template <typename GemmConfig,
typename Invoker,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
typename BPrecType = APrecType,
typename CPrecType = APrecType,
typename ComputeDataType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
Expand Down Expand Up @@ -54,7 +55,11 @@ int run_gemm_example_prec_type(std::string a_layout,
Invoker,
APrecType,
BPrecType,
CPrecType>(
CPrecType,
decltype(a_layout_type),
decltype(b_layout_type),
Row,
ComputeDataType>(
arg_parser, a_layout_type, b_layout_type, Row{});
}
},
Expand Down
11 changes: 11 additions & 0 deletions example/ck_tile/03_gemm/universal_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
#ifdef CK_GFX950_SUPPORT
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig<float>,
Invoker,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
#endif
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Expand Down
Loading