Skip to content

Commit 31674ff

Browse files
committed
Refactor ntt cast.
1 parent bb35bab commit 31674ff

File tree

3 files changed

+220
-246
lines changed

3 files changed

+220
-246
lines changed

ntt/include/nncase/ntt/arch/riscv64/ukernels.h

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
547547
static constexpr size_t unroll = 8;
548548
};
549549

550+
#if 0
550551
// cast
551552
template <> struct u_cast_policy<true> {
552553
static constexpr size_t unroll = 8;
@@ -899,6 +900,183 @@ DEFINE_U_CAST_2_1(half, 16, float_e4m3_t, 8, _Float16, int8_t, f16, float8e4m3)
899900
DEFINE_U_CAST_1_2(half, 16, float, 32, _Float16, float, float16, f32)
900901
#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
901902
DEFINE_U_CAST_1_2(float_e4m3_t, 8, half, 16, int8_t, _Float16, float8e4m3, f16)
903+
#endif
904+
#else
905+
// cast
// Unroll policy for the vectorized cast ukernels below: each main loop
// processes `unroll` vectors per iteration before a scalar tail handles
// the remainder.
template <> struct u_cast_policy<true> {
    static constexpr size_t unroll = 8;
};
909+
910+
// Narrowing cast ukernel: one 2-register (LMUL=2) input vector yields one
// 1-register (LMUL=1) output vector (e.g. f32 -> f16). Only
// OUT_INTRINSIC_ELEM is pasted into code (as the v<stem>m1_t cast type for
// the whole-register store); the other *_BUILTIN/*_INTRINSIC parameters are
// kept for signature symmetry with the sibling DEFINE_U_CAST_* macros.
#define DEFINE_U_CAST_2_1(IN_ELEM, IN_BW, OUT_ELEM, OUT_BW, IN_BUILTIN_ELEM, \
                          OUT_BUILTIN_ELEM, IN_INTRINSIC_ELEM,                \
                          OUT_INTRINSIC_ELEM)                                 \
    template <template <class> class TPostOps, class Stride>                  \
    struct u_cast<true, vector<IN_ELEM, 2, NTT_VLEN / IN_BW>,                 \
                  vector<OUT_ELEM, NTT_VLEN / OUT_BW>, TPostOps, Stride> {    \
      public:                                                                 \
        using T2Elem = OUT_ELEM;                                              \
        using T1 = vector<IN_ELEM, 2, NTT_VLEN / IN_BW>;                      \
        using T2 = vector<OUT_ELEM, NTT_VLEN / OUT_BW>;                       \
        /* one input element covers two output registers' worth of input */   \
        constexpr static size_t in_offset_scale = 2;                          \
                                                                              \
        constexpr void operator()(const T1 *input, Stride input_stride,       \
                                  T2 *output,                                 \
                                  [[maybe_unused]] Stride output_stride,      \
                                  size_t count) noexcept {                    \
            using policy_t = u_cast_policy<true>;                             \
            constexpr auto unroll = policy_t::unroll;                         \
            /* Main loop: 8-way unrolled cast + post-op + vs1r.v store.    */ \
            /* NOTE(review): loads/stores index by *_stride but the        */ \
            /* pointers advance by a flat `unroll`; consistent only when   */ \
            /* stride == 1 -- confirm intended Stride semantics.           */ \
            while (count / unroll) {                                          \
                auto v0 = ntt::cast_elem<T2Elem>(*(input + 0 * input_stride)); \
                auto v2 = ntt::cast_elem<T2Elem>(*(input + 1 * input_stride)); \
                auto v4 = ntt::cast_elem<T2Elem>(*(input + 2 * input_stride)); \
                auto v6 = ntt::cast_elem<T2Elem>(*(input + 3 * input_stride)); \
                auto v8 = ntt::cast_elem<T2Elem>(*(input + 4 * input_stride)); \
                auto v10 = ntt::cast_elem<T2Elem>(*(input + 5 * input_stride)); \
                auto v12 = ntt::cast_elem<T2Elem>(*(input + 6 * input_stride)); \
                auto v14 = ntt::cast_elem<T2Elem>(*(input + 7 * input_stride)); \
                                                                              \
                /* apply the caller-supplied element-wise post op */          \
                v0 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v0);     \
                v2 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v2);     \
                v4 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v4);     \
                v6 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v6);     \
                v8 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v8);     \
                v10 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v10);   \
                v12 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v12);   \
                v14 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v14);   \
                                                                              \
                /* vs1r.v: whole-register store of one LMUL=1 vector */       \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v0), \
                             "r"(output + 0 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v2), \
                             "r"(output + 1 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v4), \
                             "r"(output + 2 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v6), \
                             "r"(output + 3 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v8), \
                             "r"(output + 4 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v10), \
                             "r"(output + 5 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v12), \
                             "r"(output + 6 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v14), \
                             "r"(output + 7 * output_stride)                  \
                             : "memory");                                     \
                output += unroll;                                             \
                input += unroll;                                              \
                count -= unroll;                                              \
            }                                                                 \
                                                                              \
            /* scalar tail: one vector per step, advancing by stride */       \
            for (size_t i = 0; i < count; i++) {                              \
                auto v0 = ntt::cast_elem<T2Elem>(*input);                     \
                v0 = TPostOps<vector<OUT_ELEM, NTT_VLEN / OUT_BW>>()(v0);     \
                asm volatile("vs1r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m1_t)v0), \
                             "r"(output)                                      \
                             : "memory");                                     \
                input += input_stride;                                        \
                output += output_stride;                                      \
            }                                                                 \
        }                                                                     \
    };
987+
988+
// f32 (2 regs) -> f16 (1 reg) narrowing cast; "float16" -> vfloat16m1_t.
DEFINE_U_CAST_2_1(float, 32, half, 16, float, _Float16, f32, float16)
#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
// f16 -> f8e4m3 narrowing cast. "float8e4m3" yields vfloat8e4m3m1_t --
// presumably an XPU-toolchain-provided type (not standard RVV), hence the
// module/sys-mode guard; confirm against the XPU toolchain headers.
DEFINE_U_CAST_2_1(half, 16, float_e4m3_t, 8, _Float16, int8_t, f16, float8e4m3)
#endif
992+
993+
994+
// Widening cast ukernel: one 1-register (LMUL=1) input vector yields one
// 2-register (LMUL=2) output vector (e.g. f16 -> f32). Only
// OUT_INTRINSIC_ELEM is pasted into code (as the v<stem>m2_t cast type for
// the whole-register store); the other *_BUILTIN/*_INTRINSIC parameters are
// kept for signature symmetry with the sibling DEFINE_U_CAST_* macros.
#define DEFINE_U_CAST_1_2(IN_ELEM, IN_BW, OUT_ELEM, OUT_BW, IN_BUILTIN_ELEM, \
                          OUT_BUILTIN_ELEM, IN_INTRINSIC_ELEM,                \
                          OUT_INTRINSIC_ELEM)                                 \
    template <template <class> class TPostOps, class Stride>                  \
    struct u_cast<true, vector<IN_ELEM, NTT_VLEN / IN_BW>,                    \
                  vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>, TPostOps, Stride> { \
        constexpr void                                                        \
        operator()(const vector<IN_ELEM, NTT_VLEN / IN_BW> *input,            \
                   [[maybe_unused]] Stride input_stride,                      \
                   vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW> *output,           \
                   Stride output_stride, size_t count) noexcept {             \
            using policy_t = u_cast_policy<true>;                             \
            constexpr auto unroll = policy_t::unroll;                         \
            /* NOTE(review): half_unroll is never read below -- leftover? */  \
            constexpr auto half_unroll = unroll / 2;                          \
                                                                              \
            using T2Elem = OUT_ELEM;                                          \
            using T1 = vector<IN_ELEM, NTT_VLEN / IN_BW>;                     \
            using T2 = vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>;               \
            [[maybe_unused]] constexpr static size_t out_offset_scale = 2;    \
                                                                              \
            /* Main loop: 8-way unrolled cast + post-op + vs2r.v store.    */ \
            /* NOTE(review): loads/stores index by *_stride but the        */ \
            /* pointers advance by a flat `unroll`; consistent only when   */ \
            /* stride == 1 -- confirm intended Stride semantics.           */ \
            while (count / unroll) {                                          \
                /* NOTE(review): vl_in/vl_out appear unused in this body */   \
                constexpr auto vl_in = NTT_VLEN / IN_BW;                      \
                constexpr auto vl_out = 2 * NTT_VLEN / OUT_BW;                \
                auto v0 = ntt::cast_elem<T2Elem>(*(input + 0 * input_stride)); \
                auto v2 = ntt::cast_elem<T2Elem>(*(input + 1 * input_stride)); \
                auto v4 = ntt::cast_elem<T2Elem>(*(input + 2 * input_stride)); \
                auto v6 = ntt::cast_elem<T2Elem>(*(input + 3 * input_stride)); \
                auto v8 = ntt::cast_elem<T2Elem>(*(input + 4 * input_stride)); \
                auto v10 = ntt::cast_elem<T2Elem>(*(input + 5 * input_stride)); \
                auto v12 = ntt::cast_elem<T2Elem>(*(input + 6 * input_stride)); \
                auto v14 = ntt::cast_elem<T2Elem>(*(input + 7 * input_stride)); \
                                                                              \
                /* apply the caller-supplied element-wise post op.         */ \
                /* NOTE(review): instantiated with the 2-D                 */ \
                /* vector<OUT_ELEM, 2, ...> here but with the flat T2      */ \
                /* vector<OUT_ELEM, 2 * ...> in the tail loop below --     */ \
                /* confirm the two TPostOps instantiations are meant to    */ \
                /* differ.                                                 */ \
                v0 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v0);  \
                v2 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v2);  \
                v4 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v4);  \
                v6 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v6);  \
                v8 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v8);  \
                v10 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v10); \
                v12 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v12); \
                v14 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v14); \
                                                                              \
                /* vs2r.v: whole-register store of one LMUL=2 vector group */ \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v0), \
                             "r"(output + 0 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v2), \
                             "r"(output + 1 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v4), \
                             "r"(output + 2 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v6), \
                             "r"(output + 3 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v8), \
                             "r"(output + 4 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v10), \
                             "r"(output + 5 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v12), \
                             "r"(output + 6 * output_stride)                  \
                             : "memory");                                     \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v14), \
                             "r"(output + 7 * output_stride)                  \
                             : "memory");                                     \
                input += unroll;                                              \
                output += unroll;                                             \
                count -= unroll;                                              \
            }                                                                 \
            /* scalar tail: one vector per step, advancing by stride */       \
            for (size_t i = 0; i < count; i++) {                              \
                auto v0 = ntt::cast_elem<T2Elem>(*input);                     \
                v0 = TPostOps<vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>>()(v0); \
                asm volatile("vs2r.v %0, (%1);" ::"vr"((v##OUT_INTRINSIC_ELEM##m2_t)v0), \
                             "r"(output)                                      \
                             : "memory");                                     \
                input += input_stride;                                        \
                output += output_stride;                                      \
            }                                                                 \
        }                                                                     \
    };
1074+
1075+
// f16 (1 reg) -> f32 (2 regs) widening cast; "float32" -> vfloat32m2_t.
DEFINE_U_CAST_1_2(half, 16, float, 32, _Float16, float, float16, float32)
#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
// f8e4m3 -> f16 widening cast.
// Fix: OUT_INTRINSIC_ELEM must be the full RVV type stem ("float16", not
// "f16") -- the macro body pastes it into v<stem>m2_t, so "f16" would name
// the nonexistent type vf16m2_t. Matches the f32 line above, which this
// refactor updated from "f32" to "float32" for the same reason.
DEFINE_U_CAST_1_2(float_e4m3_t, 8, half, 16, int8_t, _Float16, float8e4m3, float16)
#endif
1079+
9021080
#endif
9031081

9041082
template <Scalar TProbs, Scalar TIndices, size_t Rank, size_t Axis, bool Norm>

0 commit comments

Comments
 (0)