@@ -538,6 +538,134 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
538538 return current == index_size;
539539}
540540
541+ template <typename IndexType, typename OffsetType>
542+ static bool ALWAYS_INLINE EmbeddingSpMDMNBitRowWiseSparse_autovec (
543+ const int bit_rate,
544+ const int64_t block_size,
545+ const int64_t output_size,
546+ const int64_t index_size,
547+ const int64_t uncompressed_data_size,
548+ const uint8_t * input,
549+ const IndexType* indices,
550+ const int32_t * compressed_indices_table,
551+ const OffsetType* offsets_or_lengths,
552+ const float * weights,
553+ const bool normalize_by_lengths,
554+ float * out,
555+ const bool is_weight_positional,
556+ const bool use_offsets) {
557+ if (uncompressed_data_size < 0 ) {
558+ return false ;
559+ }
560+
561+ // block_size is the number of elements and fused_block_size is the size in
562+ // bytes of an entire row, including scale and bias.
563+ const int num_elem_per_byte = 8 / bit_rate;
564+ const int64_t scale_bias_size = 2 * sizeof (float16);
565+ const uint64_t scale_bias_offset = div_up (block_size, num_elem_per_byte);
566+ const int64_t fused_block_size = scale_bias_offset + scale_bias_size;
567+
568+ int64_t current = 0 ;
569+ float * buf = out;
570+ for (int64_t m = 0 ; m < output_size; ++m) {
571+ const OffsetType len = use_offsets
572+ ? offsets_or_lengths[m + 1 ] - offsets_or_lengths[m]
573+ : offsets_or_lengths[m];
574+ const int64_t end = current + len;
575+ if (end > index_size) {
576+ return false ;
577+ }
578+
579+ memset (buf, 0 , sizeof (float ) * block_size);
580+
581+ const float * weights_addr = weights != nullptr
582+ ? (is_weight_positional ? weights : weights + current)
583+ : nullptr ;
584+ for (; current < end; ++current) {
585+ int64_t uncompressed_idx = indices[current];
586+ if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
587+ return false ;
588+ }
589+ int64_t idx = compressed_indices_table[uncompressed_idx];
590+ if (idx == -1 ) {
591+ weights_addr++;
592+ continue ;
593+ }
594+
595+ const uint8_t * input_row_base = input + fused_block_size * idx;
596+ const uint8_t * scale_bias_addr = input_row_base + scale_bias_offset;
597+
598+ float scale =
599+ cpu_half2float (*reinterpret_cast <const float16*>(scale_bias_addr));
600+ float bias = cpu_half2float (
601+ *reinterpret_cast <const float16*>(scale_bias_addr + sizeof (float16)));
602+
603+ if (weights != nullptr ) {
604+ float weight = *weights_addr++;
605+ scale *= weight;
606+ bias *= weight;
607+ }
608+
609+ const uint8_t * input_row = input_row_base;
610+ if (bit_rate == 4 ) {
611+ int64_t j = 0 ;
612+ #ifdef FBGEMM_VECTOR_WIDTH
613+ for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2 ));
614+ j += 2 ) {
615+ uint8_t tmp = *input_row++;
616+ float quantized1 = float (tmp & 0xf );
617+ float quantized2 = float (tmp >> 4 );
618+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
619+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
620+ }
621+ #endif
622+ for (; j < block_size; j += 2 ) {
623+ uint8_t tmp = *input_row++;
624+ float quantized1 = float (tmp & 0xf );
625+ float quantized2 = float (tmp >> 4 );
626+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
627+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
628+ }
629+ } else if (bit_rate == 2 ) {
630+ int64_t j = 0 ;
631+ #ifdef FBGEMM_VECTOR_WIDTH
632+ for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4 ));
633+ j += 4 ) {
634+ uint8_t tmp = *input_row++;
635+ float quantized1 = float (tmp & 0x3 );
636+ float quantized2 = float ((tmp & 0xC ) >> 2 );
637+ float quantized3 = float ((tmp & 0x30 ) >> 4 );
638+ float quantized4 = float (tmp >> 6 );
639+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
640+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
641+ buf[j + 2 ] = std::fma (scale, quantized3, buf[j + 2 ] + bias);
642+ buf[j + 3 ] = std::fma (scale, quantized4, buf[j + 3 ] + bias);
643+ }
644+ #endif
645+ for (; j < block_size; j += 4 ) {
646+ uint8_t tmp = *input_row++;
647+ float quantized1 = float (tmp & 0x3 );
648+ float quantized2 = float ((tmp & 0xC ) >> 2 );
649+ float quantized3 = float ((tmp & 0x30 ) >> 4 );
650+ float quantized4 = float (tmp >> 6 );
651+ buf[j] = std::fma (scale, quantized1, buf[j] + bias);
652+ buf[j + 1 ] = std::fma (scale, quantized2, buf[j + 1 ] + bias);
653+ buf[j + 2 ] = std::fma (scale, quantized3, buf[j + 2 ] + bias);
654+ buf[j + 3 ] = std::fma (scale, quantized4, buf[j + 3 ] + bias);
655+ }
656+ }
657+ }
658+ if (normalize_by_lengths && len) {
659+ float scale = 1 .f / len;
660+ for (int j = 0 ; j < block_size; ++j) {
661+ buf[j] *= scale;
662+ }
663+ }
664+ buf += block_size;
665+ }
666+ return current == index_size;
667+ }
668+
/// @ingroup tbe-cpu-autovec
///
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
@@ -2153,6 +2281,96 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
21532281#undef INSTANTIATE_SPMDM_OUT_T
21542282#undef INSTANTIATE_SPMDM_BASE
21552283
/// Factory for the row-wise-sparse N-bit SpMDM autovec kernel.
///
/// Returns a lambda matching EmbeddingSpMDMRowWiseSparseKernelSignature that
/// forwards to EmbeddingSpMDMNBitRowWiseSparse_autovec. The specialization
/// helpers bake compile-time-known arguments (here: bit_rate 2 or 4) into the
/// lambda so the compiler can fold the per-bit-rate branches away.
///
/// @param bit_rate bits per quantized element; must be 2 or 4.
/// @param prefetch unused by the autovec path (kept for signature parity with
///        the asmjit generators).
template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    uint8_t,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
    int bit_rate,
    int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    [[maybe_unused]] int prefetch,
    bool is_weight_positional,
    bool use_offsets) {
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
  using specialization_helper::fixed;
  using specialization_helper::match;
  using specialization_helper::specialize;
  using specialization_helper::var;

// For each parameter, `match(fixed(v), x)` is true only when x == v (and
// `match(var, x)` always matches); `specialize` then substitutes the
// compile-time constant so the called kernel is specialized per combination.
#define SPECIALIZE(                                                   \
    BIT_RATE,                                                         \
    BLOCK_SIZE,                                                       \
    HAS_WEIGHT,                                                       \
    NORMALIZE_BY_LENGTHS,                                             \
    IS_WEIGHT_POSITIONAL,                                             \
    USE_OFFSETS)                                                      \
  if (match(BIT_RATE, bit_rate) && match(BLOCK_SIZE, block_size) &&   \
      match(HAS_WEIGHT, has_weight) &&                                \
      match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) &&            \
      match(IS_WEIGHT_POSITIONAL, is_weight_positional) &&            \
      match(USE_OFFSETS, use_offsets)) {                              \
    return [=](int64_t output_size,                                   \
               int64_t index_size,                                    \
               int64_t uncompressed_data_size,                        \
               const uint8_t* input,                                  \
               const IndexType* indices,                              \
               const OffsetType* offsets_or_lengths,                  \
               const float* weights,                                  \
               float* out,                                            \
               const int32_t* compressed_indices_table) {             \
      if (specialize(HAS_WEIGHT, has_weight)) {                       \
        __builtin_assume(weights != nullptr);                         \
      } else {                                                        \
        weights = nullptr;                                            \
      }                                                               \
      return EmbeddingSpMDMNBitRowWiseSparse_autovec(                 \
          /*bit_rate=*/specialize(BIT_RATE, bit_rate),                \
          /*block_size=*/specialize(BLOCK_SIZE, block_size),          \
          /*output_size=*/output_size,                                \
          /*index_size=*/index_size,                                  \
          /*uncompressed_data_size=*/uncompressed_data_size,          \
          /*input=*/input,                                            \
          /*indices=*/indices,                                        \
          /*compressed_indices_table=*/compressed_indices_table,      \
          /*offsets_or_lengths=*/offsets_or_lengths,                  \
          /*weights=*/weights, /*normalize_by_lengths=*/              \
          specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths),     \
          /*out=*/out, /*is_weight_positional=*/                      \
          specialize(IS_WEIGHT_POSITIONAL, is_weight_positional),     \
          /*use_offsets=*/specialize(USE_OFFSETS, use_offsets));      \
    };                                                                \
  }

  // Only bit_rate is specialized; all other parameters stay runtime values.
  // Exactly one of these must match (the assert above guarantees it).
  SPECIALIZE(/*BIT_RATE*/ fixed(4), var, var, var, var, var);
  SPECIALIZE(/*BIT_RATE*/ fixed(2), var, var, var, var, var);
  abort(); // should not get here
#undef SPECIALIZE
}
2352+
// Explicitly instantiate the factory for every supported (index, offset) type
// combination so the definitions are emitted in this translation unit and can
// be linked against from the dispatch layer.
#define INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(INDEX_TYPE, OFFSET_TYPE)      \
  template typename EmbeddingSpMDMRowWiseSparseKernelSignature<             \
      uint8_t,                                                              \
      INDEX_TYPE,                                                           \
      OFFSET_TYPE>::Type                                                    \
  GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<INDEX_TYPE, OFFSET_TYPE>( \
      int bit_rate,                                                         \
      int64_t block_size,                                                   \
      bool has_weight,                                                      \
      bool normalize_by_lengths,                                            \
      int prefetch,                                                         \
      bool is_weight_positional,                                            \
      bool use_offsets);

INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int64_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int64_t)

#undef INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE
2373+
21562374} // namespace fbgemm
21572375
21582376#endif // #ifdef __linux__
0 commit comments