@@ -104,7 +104,7 @@ concept EpilogueDescriptor = requires(T t) {

// Concept for the thread cluster access order
template <typename T>
-concept AccessOrderDescriptor = requires(T t) {
concept ThreadClusterOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
} || requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 4>>;
@@ -231,16 +231,16 @@ concept SpecifiesLdsTransfer = requires(T t) {

// Concept to check if a struct specifies thread cluster access order info.
template <typename T>
-concept SpecifiesThreadClusterAccessOrder = requires(T t) {
-{ T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
-{ T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
concept SpecifiesThreadClusterArrangeOrder = requires(T t) {
{ T::transfer.a.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor;
{ T::transfer.b.thread_cluster_arrange_order } -> ThreadClusterOrderDescriptor;
};

// Concept to check if a struct specifies source access order info.
template <typename T>
concept SpecifiesSourceAccessOrder = requires(T t) {
-{ T::transfer.a.src_access_order } -> AccessOrderDescriptor;
-{ T::transfer.b.src_access_order } -> AccessOrderDescriptor;
{ T::transfer.a.src_access_order } -> ThreadClusterOrderDescriptor;
{ T::transfer.b.src_access_order } -> ThreadClusterOrderDescriptor;
};
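// Illustrative sketch only, not part of this change: a minimal parameter struct that
// models SpecifiesThreadClusterArrangeOrder and SpecifiesSourceAccessOrder. The names
// below are hypothetical; the concepts only look for static `transfer.a` / `transfer.b`
// members whose order descriptors expose an `order` array of rank 3 or 4.
namespace example {
struct OrderDesc3D { std::array<size_t, 3> order; };
struct TensorTransferDesc {
    OrderDesc3D thread_cluster_arrange_order;
    OrderDesc3D src_access_order;
};
struct TransferDesc { TensorTransferDesc a; TensorTransferDesc b; };
struct TileTransferParams {
    static constexpr TransferDesc transfer{};
};
static_assert(SpecifiesThreadClusterArrangeOrder<TileTransferParams>);
static_assert(SpecifiesSourceAccessOrder<TileTransferParams>);
} // namespace example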

// Concept to check if struct specifies block GEMM.
@@ -5,6 +5,9 @@

#include <type_traits>
#include <concepts>
#include <utility>
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/arch/arch.hpp"

namespace ck_tile::builder {

@@ -45,4 +48,224 @@ concept AccessOrderLimits4D = requires {
(Value.Size() == 4));
};

namespace detail {

// Helper to check if access order is a valid permutation
template <auto Value>
constexpr bool is_valid_permutation()
{
constexpr auto size = Value.Size();

// Check all values are in range [0, size)
for(size_t i = 0; i < size; ++i)
{
if(Value[i] < 0 || Value[i] >= static_cast<decltype(Value[0])>(size))
return false;
}

// Check all values are unique (valid permutation)
for(size_t i = 0; i < size; ++i)
{
for(size_t j = i + 1; j < size; ++j)
{
if(Value[i] == Value[j])
return false;
}
}

return true;
}

} // namespace detail

// Generic access order limits. Must be a valid permutation of {0, 1, ..., Dims-1}.
// Works with both 3D and 4D (or any dimensionality) access orders.
template <auto Value, size_t Dims>
concept AccessOrderLimits = requires {
requires Value.Size() == Dims;
requires detail::is_valid_permutation<Value>();
};
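// Illustrative sketch only, not part of this change: AccessOrderLimits accepts any
// array-like value that exposes a constexpr Size() and operator[] and whose contents
// form a permutation of {0, ..., Dims-1}. The Order type below is hypothetical.
namespace example {
template <size_t N>
struct Order {
    size_t data[N];
    static constexpr size_t Size() { return N; }
    constexpr size_t operator[](size_t i) const { return data[i]; }
};
static_assert(AccessOrderLimits<Order<3>{{1, 0, 2}}, 3>);  // valid permutation
static_assert(!AccessOrderLimits<Order<3>{{0, 0, 2}}, 3>); // repeated index is rejected
static_assert(!AccessOrderLimits<Order<3>{{1, 0, 2}}, 4>); // rank mismatch is rejected
} // namespace example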

namespace detail {

// Helper trait to get compile-time size from ck::Array
template <typename T>
concept HasStaticSize = requires {
{ T::Size() } -> std::convertible_to<size_t>;
};

// Helper trait to get compile-time size from std::array and similar
template <typename T>
concept HasTupleSize = requires {
{ std::tuple_size<T>::value } -> std::convertible_to<size_t>;
};

// Helper for dependent static_assert
template <typename>
constexpr bool always_false = false;

// Get compile-time size of a range
template <typename Range>
constexpr size_t get_range_size()
{
if constexpr(HasStaticSize<Range>)
{
return Range::Size();
}
else if constexpr(HasTupleSize<Range>)
{
return std::tuple_size_v<Range>;
}
else
{
static_assert(always_false<Range>, "Unsupported type of range object.");
}
}

// Fold expression implementation for product calculation
template <typename Range, size_t... Is>
constexpr auto get_cluster_size_impl(const Range& range, std::index_sequence<Is...>)
{
using value_type = std::remove_cvref_t<decltype(range[0])>;
return ((range[Is]) * ... * value_type{1});
}

// Generic function that calculates the product of all elements in a range
// Works with any indexable range with compile-time size (ck::Array, std::array, etc.)
template <typename Range>
requires requires(Range r) {
r[0]; // Must be indexable
get_range_size<Range>(); // Must have compile-time size
}
constexpr auto get_cluster_size(const Range& range)
{
return get_cluster_size_impl(range, std::make_index_sequence<get_range_size<Range>()>{});
}
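// Illustrative sketch only, not part of this change: any indexable range with a
// compile-time size works, e.g. std::array via the std::tuple_size branch (assuming
// <array> is available here). A 4x64x4 thread cluster yields 1024 threads.
namespace example {
inline constexpr std::array<size_t, 3> cluster_dims{4, 64, 4};
static_assert(get_cluster_size(cluster_dims) == 4 * 64 * 4);
} // namespace example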

// Calculate K dimension coverage (k0 * k1, with vectorization if applicable)
template <auto BlockTransfer>
constexpr auto get_k_coverage()
{
auto k0 = BlockTransfer.thread_cluster_dims[0];
auto k1 = BlockTransfer.thread_cluster_dims[2];
auto k_total = k0 * k1;

// If vectorization is on k0 (dim 0) or k1 (dim 2), multiply by vector size
if constexpr(BlockTransfer.src_vector_dim == 0 || BlockTransfer.src_vector_dim == 2)
{
k_total *= BlockTransfer.src_scalar_per_vector;
}

return k_total;
}

// Calculate M/N dimension coverage (m_n, with vectorization if applicable)
template <auto BlockTransfer>
constexpr auto get_mn_coverage()
{
auto mn = BlockTransfer.thread_cluster_dims[1];

// If vectorization is on m_n (dim 1), multiply by vector size
if constexpr(BlockTransfer.src_vector_dim == 1)
{
mn *= BlockTransfer.src_scalar_per_vector;
}

return mn;
}
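// Illustrative sketch only, not part of this change (hypothetical descriptor, assuming
// <array> is available): for a {k0, m_n, k1} cluster of {4, 64, 4} with 8-wide vector
// loads along k1 (src_vector_dim == 2), one block transfer covers K = 4 * 4 * 8 = 128
// and M/N = 64 elements.
namespace example {
struct BlockTransferSketch {
    std::array<size_t, 3> thread_cluster_dims;
    size_t src_vector_dim;
    size_t src_scalar_per_vector;
};
inline constexpr BlockTransferSketch a_transfer{{4, 64, 4}, 2, 8};
static_assert(get_k_coverage<a_transfer>() == 128);
static_assert(get_mn_coverage<a_transfer>() == 64);
} // namespace example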

template <size_t DataTypeSize>
constexpr auto get_data_max_vec_size()
{
constexpr auto max_vec_inst_size_bytes = get_max_mem_vec_inst_width();
static_assert(max_vec_inst_size_bytes % DataTypeSize == 0,
"The max vec instruction size is not a multiple of given data type size.");
return max_vec_inst_size_bytes / DataTypeSize;
}

} // namespace detail

// product of thread cluster lengths must be <= workgroup size
template <auto BlockTransfer, size_t BlockSize>
concept ValidBlockTransferClusterSize =
requires { requires detail::get_cluster_size(BlockTransfer.thread_cluster_dims) <= BlockSize; };

// Check that thread cluster covers the K and M dimensions for A transfer
template <auto ABlockTransfer, auto TileSize>
concept ThreadsCoverATile = requires {
// K dimension: k0 * k1 * (vectorization) must divide K
requires TileSize.k % detail::get_k_coverage<ABlockTransfer>() == 0;
// M dimension: m_n * (vectorization) must divide M
requires TileSize.m % detail::get_mn_coverage<ABlockTransfer>() == 0;
};

// Check that thread cluster covers the K and N dimensions for B transfer
template <auto BBlockTransfer, auto TileSize>
concept ThreadsCoverBTile = requires {
// K dimension: k0 * k1 * (vectorization) must divide K
requires TileSize.k % detail::get_k_coverage<BBlockTransfer>() == 0;
// N dimension: m_n * (vectorization) must divide N
requires TileSize.n % detail::get_mn_coverage<BBlockTransfer>() == 0;
};

template <auto CBlockTransfer, auto TileSize>
concept ThreadsCoverCTile = requires {
// M dimension: m_wave_per_xdl must divide M
requires TileSize.m % CBlockTransfer.thread_cluster_dims[1] == 0;
// N dimension: n_wave_per_xdl * (vectorization) must divide N
requires TileSize.n % (CBlockTransfer.thread_cluster_dims[3] *
CBlockTransfer.scalar_per_vector) == 0;
};
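// Illustrative sketch only, not part of this change (hypothetical descriptors, assuming
// <array> is available): a {4, 64, 4} cluster with 8-wide K vector loads has 1024
// threads, covers K = 4 * 4 * 8 = 128 and M = 64 per pass, so it fits a 1024-thread
// workgroup and evenly tiles a 256x128x128 (m, n, k) block.
namespace example {
struct TileSizeSketch { size_t m; size_t n; size_t k; };
struct ABlockTransferSketch {
    std::array<size_t, 3> thread_cluster_dims;
    size_t src_vector_dim;
    size_t src_scalar_per_vector;
};
inline constexpr ABlockTransferSketch a_desc{{4, 64, 4}, 2, 8};
inline constexpr TileSizeSketch tile{256, 128, 128};
static_assert(ValidBlockTransferClusterSize<a_desc, 1024>);
static_assert(ThreadsCoverATile<a_desc, tile>);
} // namespace example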

template <size_t Value>
concept IsPowerOf2 = (Value > 0) && ((Value & (Value - 1)) == 0);

template <size_t ScalarPerVec, size_t DataTypeSize>
concept IsVectorSizeValid =
IsPowerOf2<ScalarPerVec> && (ScalarPerVec <= detail::get_data_max_vec_size<DataTypeSize>());
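// Illustrative sketch only, not part of this change: scalar-per-vector values must be
// powers of two, e.g.
static_assert(IsPowerOf2<8>);
static_assert(!IsPowerOf2<12>);
// IsVectorSizeValid<8, 2> would additionally require that eight 2-byte elements fit in
// the widest vector memory instruction reported by get_max_mem_vec_inst_width(), which
// is architecture-dependent, so it is not asserted here.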

// Composite concept for input block transfer validation (A)
// Includes all validations: vector transfer limits, access order, cluster size,
// vector size validity, and tile coverage
template <auto A_BLOCK_TRANSFER,
typename DataType,
size_t BLOCK_SIZE,
auto TILE_SIZE,
size_t DIMS = 3>
concept ValidABlockTransfer =
InputVectorTransferLimits<A_BLOCK_TRANSFER> &&
AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order, DIMS> &&
ValidBlockTransferClusterSize<A_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<A_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
IsVectorSizeValid<A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverATile<A_BLOCK_TRANSFER, TILE_SIZE>;

// Composite concept for input block transfer validation (B)
template <auto B_BLOCK_TRANSFER,
typename DataType,
size_t BLOCK_SIZE,
auto TILE_SIZE,
size_t DIMS = 3>
concept ValidBBlockTransfer =
InputVectorTransferLimits<B_BLOCK_TRANSFER> &&
AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order, DIMS> &&
AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order, DIMS> &&
ValidBlockTransferClusterSize<B_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<B_BLOCK_TRANSFER.src_scalar_per_vector, sizeof(DataType)> &&
IsVectorSizeValid<B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverBTile<B_BLOCK_TRANSFER, TILE_SIZE>;

// Composite concept for output block transfer validation (C)
template <auto C_BLOCK_TRANSFER, typename DataType, size_t BLOCK_SIZE, auto TILE_SIZE>
concept ValidCBlockTransfer =
OutputVectorTransferLimits<C_BLOCK_TRANSFER> &&
ValidBlockTransferClusterSize<C_BLOCK_TRANSFER, BLOCK_SIZE> &&
IsVectorSizeValid<C_BLOCK_TRANSFER.scalar_per_vector, sizeof(DataType)> &&
ThreadsCoverCTile<C_BLOCK_TRANSFER, TILE_SIZE>;

// Usage: IsValidLayout<ACTUAL_LAYOUT, VALID_LAYOUT_1, VALID_LAYOUT_2, ...>
template <auto ACTUAL_LAYOUT, auto... VALID_LAYOUTS>
concept IsValidLayout = ck_tile::is_any_value_of(ACTUAL_LAYOUT, VALID_LAYOUTS...);
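// Illustrative sketch only, not part of this change (hypothetical enum, assuming
// ck_tile::is_any_value_of is a generic value-membership check as its use here
// suggests): the concept whitelists a compile-time value against an allowed set.
namespace example {
enum class LayoutSketch { RowMajor, ColMajor, Blocked };
static_assert(IsValidLayout<LayoutSketch::RowMajor, LayoutSketch::RowMajor, LayoutSketch::ColMajor>);
static_assert(!IsValidLayout<LayoutSketch::Blocked, LayoutSketch::RowMajor, LayoutSketch::ColMajor>);
} // namespace example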

} // namespace ck_tile::builder
@@ -11,7 +11,7 @@ namespace ck_tile::builder::factory {
template <typename T, size_t ThreadClusterRank = 3>
concept TileTransferParameters =
SpecifiesBlockTransfer<T, ThreadClusterRank> && SpecifiesLdsTransfer<T> &&
-SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T>;
SpecifiesThreadClusterArrangeOrder<T> && SpecifiesSourceAccessOrder<T>;

template <typename T>
concept SpecifiesTileTransferParameters3D = TileTransferParameters<T, 3>;
@@ -46,14 +46,55 @@ struct ConvFwdLargeTensorFactory
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();

-// Check limits for the algorithm parameters.
-static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
-static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
-static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
-static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.thread_cluster_order>);
-static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.thread_cluster_order>);
-static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
-static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
// Check limits for the data transfer parameters.
static_assert(ValidABlockTransfer<A_BLOCK_TRANSFER,
typename Types::InDataType,
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidBBlockTransfer<B_BLOCK_TRANSFER,
typename Types::WeiDataType,
BLOCK.block_size,
BLOCK.per_block>);
static_assert(ValidCBlockTransfer<C_BLOCK_TRANSFER,
typename Types::OutDataType,
BLOCK.block_size,
BLOCK.per_block>);

using enum TensorLayout;
static_assert(IsValidLayout<SIGNATURE.input.config.layout,
G_NW_C_strided,
G_NHW_C_strided,
G_NDHW_C_strided,
GNWC,
GNHWC,
GNDHWC,
NWGC,
NHWGC,
NDHWGC> &&
A_BLOCK_TRANSFER.src_vector_dim == 2);

static_assert(IsValidLayout<SIGNATURE.weight.config.layout,
G_K_X_C_strided,
G_K_YX_C_strided,
G_K_ZYX_C_strided,
GKXC,
GKYXC,
GKZYXC,
KXGC,
KYXGC,
KZYXGC> &&
B_BLOCK_TRANSFER.src_vector_dim == 2);

static_assert(IsValidLayout<SIGNATURE.output.config.layout,
G_NW_K_strided,
G_NHW_K_strided,
G_NDHW_K_strided,
GNWK,
GNHWK,
GNDHWK,
NWGK,
NHWGK,
NDHWGK>);

// The forward convolution kernel class instance with large tensor support.
using Instance =