@@ -538,6 +538,134 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
538538 return current == index_size;
539539}
540540
541+ template <typename IndexType, typename OffsetType>
542+ static bool ALWAYS_INLINE EmbeddingSpMDMNBitRowWiseSparse_autovec (
543+ const int bit_rate,
544+ const int64_t block_size,
545+ const int64_t output_size,
546+ const int64_t index_size,
547+ const int64_t uncompressed_data_size,
548+ const uint8_t * input,
549+ const IndexType* indices,
550+ const int32_t * compressed_indices_table,
551+ const OffsetType* offsets_or_lengths,
552+ const float * weights,
553+ const bool normalize_by_lengths,
554+ float * out,
555+ const bool is_weight_positional,
556+ const bool use_offsets) {
557+ if (uncompressed_data_size < 0 ) {
558+ return false ;
559+ }
560+
561+ // block_size is the number of elements and fused_block_size is the size in
562+ // bytes of an entire row, including scale and bias.
563+ const int num_elem_per_byte = 8 / bit_rate;
564+ const int64_t scale_bias_size = 2 * sizeof (float16);
565+ const uint64_t scale_bias_offset = div_up (block_size, num_elem_per_byte);
566+ const int64_t fused_block_size = scale_bias_offset + scale_bias_size;
567+
568+ int64_t current = 0 ;
569+ float * buf = out;
570+ for (int64_t m = 0 ; m < output_size; ++m) {
571+ const OffsetType len = use_offsets
572+ ? offsets_or_lengths[m + 1 ] - offsets_or_lengths[m]
573+ : offsets_or_lengths[m];
574+ const int64_t end = current + len;
575+ if (end > index_size) {
576+ return false ;
577+ }
578+
579+ memset (buf, 0 , sizeof (float ) * block_size);
580+
581+ const float * weights_addr = weights != nullptr
582+ ? (is_weight_positional ? weights : weights + current)
583+ : nullptr ;
584+ for (; current < end; ++current) {
585+ int64_t uncompressed_idx = indices[current];
586+ if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
587+ return false ;
588+ }
589+ int64_t idx = compressed_indices_table[uncompressed_idx];
590+ if (idx == -1 ) {
591+ weights_addr++;
592+ continue ;
593+ }
594+
595+ const uint8_t * input_row_base = input + fused_block_size * idx;
596+ const uint8_t * scale_bias_addr = input_row_base + scale_bias_offset;
597+
598+ float scale =
599+ cpu_half2float (*reinterpret_cast <const float16*>(scale_bias_addr));
600+ float bias = cpu_half2float (
601+ *reinterpret_cast <const float16*>(scale_bias_addr + sizeof (float16)));
602+
603+ if (weights != nullptr ) {
604+ float weight = *weights_addr++;
605+ scale *= weight;
606+ bias *= weight;
607+ }
608+
609+ const uint8_t * input_row = input_row_base;
610+ if (bit_rate == 4 ) {
611+ int64_t j = 0 ;
612+ #ifdef FBGEMM_VECTOR_WIDTH
613+ for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2 ));
614+ j += 2 ) {
615+ uint8_t tmp = *input_row++;
616+ float quantized1 = float (tmp & 0xf );
617+ float quantized2 = float (tmp >> 4 );
618+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
619+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
620+ }
621+ #endif
622+ for (; j < block_size; j += 2 ) {
623+ uint8_t tmp = *input_row++;
624+ float quantized1 = float (tmp & 0xf );
625+ float quantized2 = float (tmp >> 4 );
626+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
627+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
628+ }
629+ } else if (bit_rate == 2 ) {
630+ int64_t j = 0 ;
631+ #ifdef FBGEMM_VECTOR_WIDTH
632+ for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4 ));
633+ j += 4 ) {
634+ uint8_t tmp = *input_row++;
635+ float quantized1 = float (tmp & 0x3 );
636+ float quantized2 = float ((tmp & 0xC ) >> 2 );
637+ float quantized3 = float ((tmp & 0x30 ) >> 4 );
638+ float quantized4 = float (tmp >> 6 );
639+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
640+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
641+ buf[j + 2 ] = std::fma (scale, quantized3, buf[j + 2 ] + bias);
642+ buf[j + 3 ] = std::fma (scale, quantized4, buf[j + 3 ] + bias);
643+ }
644+ #endif
645+ for (; j < block_size; j += 4 ) {
646+ uint8_t tmp = *input_row++;
647+ float quantized1 = float (tmp & 0x3 );
648+ float quantized2 = float ((tmp & 0xC ) >> 2 );
649+ float quantized3 = float ((tmp & 0x30 ) >> 4 );
650+ float quantized4 = float (tmp >> 6 );
651+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
652+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
653+ buf[j + 2 ] = std::fma (scale, quantized3, buf[j + 2 ] + bias);
654+ buf[j + 3 ] = std::fma (scale, quantized4, buf[j + 3 ] + bias);
655+ }
656+ }
657+ }
658+ if (normalize_by_lengths && len) {
659+ float scale = 1 .f / len;
660+ for (int j = 0 ; j < block_size; ++j) {
661+ buf[j] *= scale;
662+ }
663+ }
664+ buf += block_size;
665+ }
666+ return current == index_size;
667+ }
668+
/// @ingroup tbe-cpu-autovec
///
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
@@ -778,8 +906,8 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
778906 constexpr bool is8bit = std::is_same_v<InType, uint8_t >;
779907
780908 if constexpr (is8bit) {
781- // block_size is the number of elements and fused_block_size is the size
782- // of an entire row, including scale and bias.
909+ // block_size is the number of elements and fused_block_size is the size in
910+ // bytes of an entire row, including scale and bias.
783911 const auto scale_bias_offset = 2 * sizeof (float );
784912 const int64_t fused_block_size = block_size + scale_bias_offset;
785913 int64_t current = 0 ;
@@ -2223,6 +2351,96 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
22232351#undef INSTANTIATE_SPMDM_OUT_T
22242352#undef INSTANTIATE_SPMDM_BASE
22252353
2354+ template <typename IndexType, typename OffsetType>
2355+ typename EmbeddingSpMDMRowWiseSparseKernelSignature<
2356+ uint8_t ,
2357+ IndexType,
2358+ OffsetType>::Type
2359+ GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec (
2360+ int bit_rate,
2361+ int64_t block_size,
2362+ bool has_weight,
2363+ bool normalize_by_lengths,
2364+ [[maybe_unused]] int prefetch,
2365+ bool is_weight_positional,
2366+ bool use_offsets) {
2367+ assert ((bit_rate == 2 || bit_rate == 4 ) && " bit_rate must be 2 or 4" );
2368+ using specialization_helper::fixed;
2369+ using specialization_helper::match;
2370+ using specialization_helper::specialize;
2371+ using specialization_helper::var;
2372+
2373+ #define SPECIALIZE ( \
2374+ BIT_RATE, \
2375+ BLOCK_SIZE, \
2376+ HAS_WEIGHT, \
2377+ NORMALIZE_BY_LENGTHS, \
2378+ IS_WEIGHT_POSITIONAL, \
2379+ USE_OFFSETS) \
2380+ if (match (BIT_RATE, bit_rate) && match (BLOCK_SIZE, block_size) && \
2381+ match (HAS_WEIGHT, has_weight) && \
2382+ match (NORMALIZE_BY_LENGTHS, normalize_by_lengths) && \
2383+ match (IS_WEIGHT_POSITIONAL, is_weight_positional) && \
2384+ match (USE_OFFSETS, use_offsets)) { \
2385+ return [=](int64_t output_size, \
2386+ int64_t index_size, \
2387+ int64_t uncompressed_data_size, \
2388+ const uint8_t * input, \
2389+ const IndexType* indices, \
2390+ const OffsetType* offsets_or_lengths, \
2391+ const float * weights, \
2392+ float * out, \
2393+ const int32_t * compressed_indices_table) { \
2394+ if (specialize (HAS_WEIGHT, has_weight)) { \
2395+ __builtin_assume (weights != nullptr ); \
2396+ } else { \
2397+ weights = nullptr ; \
2398+ } \
2399+ return EmbeddingSpMDMNBitRowWiseSparse_autovec ( \
2400+ /* bit_rate=*/ specialize (BIT_RATE, bit_rate), \
2401+ /* block_size=*/ specialize (BLOCK_SIZE, block_size), \
2402+ /* output_size=*/ output_size, \
2403+ /* index_size=*/ index_size, \
2404+ /* uncompressed_data_size=*/ uncompressed_data_size, \
2405+ /* input=*/ input, \
2406+ /* indices=*/ indices, \
2407+ /* compressed_indices_table=*/ compressed_indices_table, \
2408+ /* offsets_or_lengths=*/ offsets_or_lengths, \
2409+ /* weights=*/ weights, /* normalize_by_lengths=*/ \
2410+ specialize (NORMALIZE_BY_LENGTHS, normalize_by_lengths), \
2411+ /* out=*/ out, /* is_weight_positional=*/ \
2412+ specialize (IS_WEIGHT_POSITIONAL, is_weight_positional), \
2413+ /* use_offsets=*/ specialize (USE_OFFSETS, use_offsets)); \
2414+ }; \
2415+ }
2416+
2417+ SPECIALIZE (/* BIT_RATE*/ fixed (4 ), var, var, var, var, var);
2418+ SPECIALIZE (/* BIT_RATE*/ fixed (2 ), var, var, var, var, var);
2419+ abort (); // should not get here
2420+ #undef SPECIALIZE
2421+ }
2422+
2423+ #define INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE (INDEX_TYPE, OFFSET_TYPE ) \
2424+ template typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
2425+ uint8_t , \
2426+ INDEX_TYPE, \
2427+ OFFSET_TYPE>::Type \
2428+ GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<INDEX_TYPE, OFFSET_TYPE>( \
2429+ int bit_rate, \
2430+ int64_t block_size, \
2431+ bool has_weight, \
2432+ bool normalize_by_lengths, \
2433+ int prefetch, \
2434+ bool is_weight_positional, \
2435+ bool use_offsets);
2436+
2437+ INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE (int32_t , int32_t )
2438+ INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE (int32_t , int64_t )
2439+ INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE (int64_t , int32_t )
2440+ INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE (int64_t , int64_t )
2441+
2442+ #undef INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE
2443+
22262444} // namespace fbgemm
22272445
22282446#endif // #ifdef __linux__
0 commit comments