@@ -547,6 +547,7 @@ template <reduce_op Op, class T> struct u_reduce_policy<Op, T, true> {
547547 static constexpr size_t unroll = 8 ;
548548};
549549
550+ #if 0
550551// cast
551552template <> struct u_cast_policy<true> {
552553 static constexpr size_t unroll = 8;
@@ -899,6 +900,183 @@ DEFINE_U_CAST_2_1(half, 16, float_e4m3_t, 8, _Float16, int8_t, f16, float8e4m3)
899900DEFINE_U_CAST_1_2(half, 16, float, 32, _Float16, float, float16, f32)
900901#if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
901902DEFINE_U_CAST_1_2(float_e4m3_t, 8, half, 16, int8_t, _Float16, float8e4m3, f16)
903+ #endif
904+ #else
905+ // cast
906+ template <> struct u_cast_policy <true > {
907+ static constexpr size_t unroll = 8 ;
908+ };
909+
// DEFINE_U_CAST_2_1 -- specializes u_cast for a 2:1 narrowing cast: each
// source element is a 2-D vector<IN_ELEM, 2, NTT_VLEN / IN_BW> and each
// destination element is a single vector<OUT_ELEM, NTT_VLEN / OUT_BW>,
// written with an RVV whole-register store (vs1r.v).
//
// Parameters:
//   IN_ELEM / OUT_ELEM                 -- ntt element types of source / dest
//   IN_BW / OUT_BW                     -- element bit widths (fix lane counts)
//   IN_BUILTIN_ELEM / OUT_BUILTIN_ELEM -- matching C builtin types (kept for
//                                         macro-signature parity; unused here)
//   IN_INTRINSIC_ELEM / OUT_INTRINSIC_ELEM -- RVV type-name fragments; the
//       output fragment must name a valid `v<frag>m1_t` register type.
//
// NOTE(review): the unrolled loop loads via `input + k * input_stride` yet
// advances both pointers by `unroll` (not `unroll * stride`), while the
// remainder loop advances by the strides. This looks like it assumes
// stride == 1 on the unrolled path -- confirm against callers.
#define DEFINE_U_CAST_2_1(IN_ELEM, IN_BW, OUT_ELEM, OUT_BW, IN_BUILTIN_ELEM,   \
                          OUT_BUILTIN_ELEM, IN_INTRINSIC_ELEM,                 \
                          OUT_INTRINSIC_ELEM)                                  \
    template <template <class> class TPostOps, class Stride>                   \
    struct u_cast<true, vector<IN_ELEM, 2, NTT_VLEN / IN_BW>,                  \
                  vector<OUT_ELEM, NTT_VLEN / OUT_BW>, TPostOps, Stride> {     \
      public:                                                                  \
        using T2Elem = OUT_ELEM;                                               \
        using T1 = vector<IN_ELEM, 2, NTT_VLEN / IN_BW>;                       \
        using T2 = vector<OUT_ELEM, NTT_VLEN / OUT_BW>;                        \
        constexpr static size_t in_offset_scale = 2;                           \
                                                                               \
        constexpr void operator()(const T1 *input, Stride input_stride,        \
                                  T2 *output,                                  \
                                  [[maybe_unused]] Stride output_stride,       \
                                  size_t count) noexcept {                     \
            using policy_t = u_cast_policy<true>;                              \
            constexpr auto unroll = policy_t::unroll;                          \
            /* Main loop: cast 8 vectors, apply post-ops, store them. */       \
            while (count / unroll) {                                           \
                auto v0 = ntt::cast_elem<T2Elem>(*(input + 0 * input_stride)); \
                auto v2 = ntt::cast_elem<T2Elem>(*(input + 1 * input_stride)); \
                auto v4 = ntt::cast_elem<T2Elem>(*(input + 2 * input_stride)); \
                auto v6 = ntt::cast_elem<T2Elem>(*(input + 3 * input_stride)); \
                auto v8 = ntt::cast_elem<T2Elem>(*(input + 4 * input_stride)); \
                auto v10 = ntt::cast_elem<T2Elem>(                             \
                    *(input + 5 * input_stride));                              \
                auto v12 = ntt::cast_elem<T2Elem>(                             \
                    *(input + 6 * input_stride));                              \
                auto v14 = ntt::cast_elem<T2Elem>(                             \
                    *(input + 7 * input_stride));                              \
                                                                               \
                v0 = TPostOps<T2>()(v0);                                       \
                v2 = TPostOps<T2>()(v2);                                       \
                v4 = TPostOps<T2>()(v4);                                       \
                v6 = TPostOps<T2>()(v6);                                       \
                v8 = TPostOps<T2>()(v8);                                       \
                v10 = TPostOps<T2>()(v10);                                     \
                v12 = TPostOps<T2>()(v12);                                     \
                v14 = TPostOps<T2>()(v14);                                     \
                                                                               \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v0),             \
                             "r"(output + 0 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v2),             \
                             "r"(output + 1 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v4),             \
                             "r"(output + 2 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v6),             \
                             "r"(output + 3 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v8),             \
                             "r"(output + 4 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v10),            \
                             "r"(output + 5 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v12),            \
                             "r"(output + 6 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v14),            \
                             "r"(output + 7 * output_stride)                   \
                             : "memory");                                      \
                output += unroll;                                              \
                input += unroll;                                               \
                count -= unroll;                                               \
            }                                                                  \
                                                                               \
            /* Remainder loop: one vector at a time, strided advance. */       \
            for (size_t i = 0; i < count; i++) {                               \
                auto v0 = ntt::cast_elem<T2Elem>(*input);                      \
                v0 = TPostOps<T2>()(v0);                                       \
                asm volatile("vs1r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m1_t)v0),             \
                             "r"(output)                                       \
                             : "memory");                                      \
                input += input_stride;                                         \
                output += output_stride;                                       \
            }                                                                  \
        }                                                                      \
    };
987+
988+ DEFINE_U_CAST_2_1 (float , 32 , half, 16 , float , _Float16, f32 , float16)
989+ #if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
990+ DEFINE_U_CAST_2_1 (half, 16 , float_e4m3_t , 8 , _Float16, int8_t , f16 , float8e4m3)
991+ #endif
992+
993+
// DEFINE_U_CAST_1_2 -- specializes u_cast for a 1:2 widening cast: each
// source element is a vector<IN_ELEM, NTT_VLEN / IN_BW> and each destination
// element is a double-width vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>, written
// with an RVV two-register whole-register store (vs2r.v).
//
// Parameters mirror DEFINE_U_CAST_2_1; OUT_INTRINSIC_ELEM must name a valid
// `v<frag>m2_t` register type.
//
// NOTE(review): the unrolled loop applies TPostOps<vector<OUT_ELEM, 2,
// NTT_VLEN / OUT_BW>> (a 2-D vector) while the remainder loop applies
// TPostOps<vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>> (the declared output
// type). These are different instantiations -- confirm they are
// interchangeable for every TPostOps used with this kernel.
// NOTE(review): as in the 2:1 variant, the unrolled loop indexes with
// `k * input_stride` / `k * output_stride` but advances both pointers by
// plain `unroll` -- looks stride==1 only; confirm against callers.
#define DEFINE_U_CAST_1_2(IN_ELEM, IN_BW, OUT_ELEM, OUT_BW, IN_BUILTIN_ELEM,   \
                          OUT_BUILTIN_ELEM, IN_INTRINSIC_ELEM,                 \
                          OUT_INTRINSIC_ELEM)                                  \
    template <template <class> class TPostOps, class Stride>                   \
    struct u_cast<true, vector<IN_ELEM, NTT_VLEN / IN_BW>,                     \
                  vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>, TPostOps, Stride> { \
        constexpr void                                                         \
        operator()(const vector<IN_ELEM, NTT_VLEN / IN_BW> *input,             \
                   [[maybe_unused]] Stride input_stride,                       \
                   vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW> *output,            \
                   Stride output_stride, size_t count) noexcept {              \
            using policy_t = u_cast_policy<true>;                              \
            constexpr auto unroll = policy_t::unroll;                          \
                                                                               \
            using T2Elem = OUT_ELEM;                                           \
            using T1 = vector<IN_ELEM, NTT_VLEN / IN_BW>;                      \
            using T2 = vector<OUT_ELEM, 2 * NTT_VLEN / OUT_BW>;                \
            [[maybe_unused]] constexpr static size_t out_offset_scale = 2;     \
                                                                               \
            /* Main loop: cast 8 vectors, apply post-ops, store them. */       \
            while (count / unroll) {                                           \
                auto v0 = ntt::cast_elem<T2Elem>(*(input + 0 * input_stride)); \
                auto v2 = ntt::cast_elem<T2Elem>(*(input + 1 * input_stride)); \
                auto v4 = ntt::cast_elem<T2Elem>(*(input + 2 * input_stride)); \
                auto v6 = ntt::cast_elem<T2Elem>(*(input + 3 * input_stride)); \
                auto v8 = ntt::cast_elem<T2Elem>(*(input + 4 * input_stride)); \
                auto v10 = ntt::cast_elem<T2Elem>(                             \
                    *(input + 5 * input_stride));                              \
                auto v12 = ntt::cast_elem<T2Elem>(                             \
                    *(input + 6 * input_stride));                              \
                auto v14 = ntt::cast_elem<T2Elem>(                             \
                    *(input + 7 * input_stride));                              \
                                                                               \
                v0 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v0);   \
                v2 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v2);   \
                v4 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v4);   \
                v6 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v6);   \
                v8 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v8);   \
                v10 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v10); \
                v12 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v12); \
                v14 = TPostOps<vector<OUT_ELEM, 2, NTT_VLEN / OUT_BW>>()(v14); \
                                                                               \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v0),             \
                             "r"(output + 0 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v2),             \
                             "r"(output + 1 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v4),             \
                             "r"(output + 2 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v6),             \
                             "r"(output + 3 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v8),             \
                             "r"(output + 4 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v10),            \
                             "r"(output + 5 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v12),            \
                             "r"(output + 6 * output_stride)                   \
                             : "memory");                                      \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v14),            \
                             "r"(output + 7 * output_stride)                   \
                             : "memory");                                      \
                input += unroll;                                               \
                output += unroll;                                              \
                count -= unroll;                                               \
            }                                                                  \
            /* Remainder loop: one vector at a time, strided advance. */       \
            for (size_t i = 0; i < count; i++) {                               \
                auto v0 = ntt::cast_elem<T2Elem>(*input);                      \
                v0 = TPostOps<T2>()(v0);                                       \
                asm volatile("vs2r.v %0, (%1);" ::"vr"(                        \
                                 (v##OUT_INTRINSIC_ELEM##m2_t)v0),             \
                             "r"(output)                                       \
                             : "memory");                                      \
                input += input_stride;                                         \
                output += output_stride;                                       \
            }                                                                  \
        }                                                                      \
    };
1074+
1075+ DEFINE_U_CAST_1_2 (half, 16 , float , 32 , _Float16, float , float16, float32)
1076+ #if defined(NNCASE_XPU_MODULE) && defined(SYS_MODE)
1077+ DEFINE_U_CAST_1_2 (float_e4m3_t , 8 , half, 16 , int8_t , _Float16, float8e4m3, f16 )
1078+ #endif
1079+
9021080#endif
9031081
9041082template <Scalar TProbs, Scalar TIndices, size_t Rank, size_t Axis, bool Norm>
0 commit comments