Skip to content

Commit dc28839

Browse files
MatzeBfacebook-github-bot
authored and committed
Add EmbeddingSpMDMNBitRowWiseSparse autovectorized variant (#5244)
Summary: X-link: facebookresearch/FBGEMM#2235 This adds an autovectorized implementation of the `GenerateEmbeddingSpMDMNBitRowWiseSparse` kernels. This first version is around a 10x improvement compared to the reference C++ implementation, but slower than the asmjit version (which is x86 only); see the list below for how to close this gap. Room for further enhancements (see some of the surrounding kernels that do implement them): - This is only specialized for `bit_rate` right now. We should specialize for common block sizes in the future (we need to know which sizes are common). - It directly operates on the `out` buffer. This is good for unknown block sizes, but if we specialize for fixed small block sizes then a separate buffer is better, as it can be promoted completely to vector registers (for a fixed vector register size anyway; this doesn't work for variable-size AArch64 SVE registers). - No prefetching logic yet. Reviewed By: excelle08 Differential Revision: D89086019
1 parent ecf8466 commit dc28839

File tree

3 files changed

+253
-2
lines changed

3 files changed

+253
-2
lines changed

src/EmbeddingSpMDMAutovec.cc

Lines changed: 220 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,134 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
538538
return current == index_size;
539539
}
540540

541+
template <typename IndexType, typename OffsetType>
542+
static bool ALWAYS_INLINE EmbeddingSpMDMNBitRowWiseSparse_autovec(
543+
const int bit_rate,
544+
const int64_t block_size,
545+
const int64_t output_size,
546+
const int64_t index_size,
547+
const int64_t uncompressed_data_size,
548+
const uint8_t* input,
549+
const IndexType* indices,
550+
const int32_t* compressed_indices_table,
551+
const OffsetType* offsets_or_lengths,
552+
const float* weights,
553+
const bool normalize_by_lengths,
554+
float* out,
555+
const bool is_weight_positional,
556+
const bool use_offsets) {
557+
if (uncompressed_data_size < 0) {
558+
return false;
559+
}
560+
561+
// block_size is the number of elements and fused_block_size is the size in
562+
// bytes of an entire row, including scale and bias.
563+
const int num_elem_per_byte = 8 / bit_rate;
564+
const int64_t scale_bias_size = 2 * sizeof(float16);
565+
const uint64_t scale_bias_offset = div_up(block_size, num_elem_per_byte);
566+
const int64_t fused_block_size = scale_bias_offset + scale_bias_size;
567+
568+
int64_t current = 0;
569+
float* buf = out;
570+
for (int64_t m = 0; m < output_size; ++m) {
571+
const OffsetType len = use_offsets
572+
? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
573+
: offsets_or_lengths[m];
574+
const int64_t end = current + len;
575+
if (end > index_size) {
576+
return false;
577+
}
578+
579+
memset(buf, 0, sizeof(float) * block_size);
580+
581+
const float* weights_addr = weights != nullptr
582+
? (is_weight_positional ? weights : weights + current)
583+
: nullptr;
584+
for (; current < end; ++current) {
585+
int64_t uncompressed_idx = indices[current];
586+
if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
587+
return false;
588+
}
589+
int64_t idx = compressed_indices_table[uncompressed_idx];
590+
if (idx == -1) {
591+
weights_addr++;
592+
continue;
593+
}
594+
595+
const uint8_t* input_row_base = input + fused_block_size * idx;
596+
const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;
597+
598+
float scale =
599+
cpu_half2float(*reinterpret_cast<const float16*>(scale_bias_addr));
600+
float bias = cpu_half2float(
601+
*reinterpret_cast<const float16*>(scale_bias_addr + sizeof(float16)));
602+
603+
if (weights != nullptr) {
604+
float weight = *weights_addr++;
605+
scale *= weight;
606+
bias *= weight;
607+
}
608+
609+
const uint8_t* input_row = input_row_base;
610+
if (bit_rate == 4) {
611+
int64_t j = 0;
612+
#ifdef FBGEMM_VECTOR_WIDTH
613+
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2));
614+
j += 2) {
615+
uint8_t tmp = *input_row++;
616+
float quantized1 = float(tmp & 0xf);
617+
float quantized2 = float(tmp >> 4);
618+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
619+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
620+
}
621+
#endif
622+
for (; j < block_size; j += 2) {
623+
uint8_t tmp = *input_row++;
624+
float quantized1 = float(tmp & 0xf);
625+
float quantized2 = float(tmp >> 4);
626+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
627+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
628+
}
629+
} else if (bit_rate == 2) {
630+
int64_t j = 0;
631+
#ifdef FBGEMM_VECTOR_WIDTH
632+
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4));
633+
j += 4) {
634+
uint8_t tmp = *input_row++;
635+
float quantized1 = float(tmp & 0x3);
636+
float quantized2 = float((tmp & 0xC) >> 2);
637+
float quantized3 = float((tmp & 0x30) >> 4);
638+
float quantized4 = float(tmp >> 6);
639+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
640+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
641+
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
642+
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
643+
}
644+
#endif
645+
for (; j < block_size; j += 4) {
646+
uint8_t tmp = *input_row++;
647+
float quantized1 = float(tmp & 0x3);
648+
float quantized2 = float((tmp & 0xC) >> 2);
649+
float quantized3 = float((tmp & 0x30) >> 4);
650+
float quantized4 = float(tmp >> 6);
651+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
652+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
653+
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
654+
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
655+
}
656+
}
657+
}
658+
if (normalize_by_lengths && len) {
659+
float scale = 1.f / len;
660+
for (int j = 0; j < block_size; ++j) {
661+
buf[j] *= scale;
662+
}
663+
}
664+
buf += block_size;
665+
}
666+
return current == index_size;
667+
}
668+
541669
/// @ingroup tbe-cpu-autovec
542670
///
543671
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
@@ -778,8 +906,8 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
778906
constexpr bool is8bit = std::is_same_v<InType, uint8_t>;
779907

780908
if constexpr (is8bit) {
781-
// block_size is the number of elements and fused_block_size is the size
782-
// of an entire row, including scale and bias.
909+
// block_size is the number of elements and fused_block_size is the size in
910+
// bytes of an entire row, including scale and bias.
783911
const auto scale_bias_offset = 2 * sizeof(float);
784912
const int64_t fused_block_size = block_size + scale_bias_offset;
785913
int64_t current = 0;
@@ -2223,6 +2351,96 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
22232351
#undef INSTANTIATE_SPMDM_OUT_T
22242352
#undef INSTANTIATE_SPMDM_BASE
22252353

2354+
/// Factory that returns an autovectorized N-bit row-wise-sparse embedding
/// lookup kernel, specialized for the given configuration.
///
/// The SPECIALIZE machinery below appears to work as follows (see
/// specialization_helper): match() checks at runtime whether the actual
/// argument is compatible with a compile-time `fixed(...)` value (`var`
/// matches anything), and specialize() substitutes the compile-time constant
/// into the call so the compiler emits a kernel body specialized for it —
/// TODO confirm against specialization_helper's definition.
///
/// @param bit_rate must be 2 or 4 (asserted)
/// @param prefetch unused — this variant has no prefetching logic yet
template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    uint8_t,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
    int bit_rate,
    int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    [[maybe_unused]] int prefetch,
    bool is_weight_positional,
    bool use_offsets) {
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
  using specialization_helper::fixed;
  using specialization_helper::match;
  using specialization_helper::specialize;
  using specialization_helper::var;

// Emits a guarded return of a lambda wrapping the kernel, with each
// parameter either passed through (`var`) or baked in (`fixed(...)`).
#define SPECIALIZE(                                                       \
    BIT_RATE,                                                             \
    BLOCK_SIZE,                                                           \
    HAS_WEIGHT,                                                           \
    NORMALIZE_BY_LENGTHS,                                                 \
    IS_WEIGHT_POSITIONAL,                                                 \
    USE_OFFSETS)                                                          \
  if (match(BIT_RATE, bit_rate) && match(BLOCK_SIZE, block_size) &&       \
      match(HAS_WEIGHT, has_weight) &&                                    \
      match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) &&                \
      match(IS_WEIGHT_POSITIONAL, is_weight_positional) &&                \
      match(USE_OFFSETS, use_offsets)) {                                  \
    return [=](int64_t output_size,                                       \
               int64_t index_size,                                        \
               int64_t uncompressed_data_size,                            \
               const uint8_t* input,                                      \
               const IndexType* indices,                                  \
               const OffsetType* offsets_or_lengths,                      \
               const float* weights,                                      \
               float* out,                                                \
               const int32_t* compressed_indices_table) {                 \
      if (specialize(HAS_WEIGHT, has_weight)) {                           \
        __builtin_assume(weights != nullptr);                             \
      } else {                                                            \
        weights = nullptr;                                                \
      }                                                                   \
      return EmbeddingSpMDMNBitRowWiseSparse_autovec(                     \
          /*bit_rate=*/specialize(BIT_RATE, bit_rate),                    \
          /*block_size=*/specialize(BLOCK_SIZE, block_size),              \
          /*output_size=*/output_size,                                    \
          /*index_size=*/index_size,                                      \
          /*uncompressed_data_size=*/uncompressed_data_size,              \
          /*input=*/input,                                                \
          /*indices=*/indices,                                            \
          /*compressed_indices_table=*/compressed_indices_table,          \
          /*offsets_or_lengths=*/offsets_or_lengths,                      \
          /*weights=*/weights, /*normalize_by_lengths=*/                  \
          specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths),         \
          /*out=*/out, /*is_weight_positional=*/                          \
          specialize(IS_WEIGHT_POSITIONAL, is_weight_positional),         \
          /*use_offsets=*/specialize(USE_OFFSETS, use_offsets));          \
    };                                                                    \
  }

  // Only bit_rate is specialized for now; all other parameters stay dynamic.
  SPECIALIZE(/*BIT_RATE*/ fixed(4), var, var, var, var, var);
  SPECIALIZE(/*BIT_RATE*/ fixed(2), var, var, var, var, var);
  // Unreachable: the assert above guarantees one of the two cases matched.
  abort(); // should not get here
#undef SPECIALIZE
}
2422+
2423+
// Explicit instantiations of the templated generator for every supported
// (index type, offset type) pair, so its definition can stay in this .cc
// file while being used from other translation units.
#define INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(INDEX_TYPE, OFFSET_TYPE)      \
  template typename EmbeddingSpMDMRowWiseSparseKernelSignature<             \
      uint8_t,                                                              \
      INDEX_TYPE,                                                           \
      OFFSET_TYPE>::Type                                                    \
  GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<INDEX_TYPE, OFFSET_TYPE>( \
      int bit_rate,                                                         \
      int64_t block_size,                                                   \
      bool has_weight,                                                      \
      bool normalize_by_lengths,                                            \
      int prefetch,                                                         \
      bool is_weight_positional,                                            \
      bool use_offsets);

INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int64_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int64_t)

#undef INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE
2443+
22262444
} // namespace fbgemm
22272445

22282446
#endif // #ifdef __linux__

src/EmbeddingSpMDMAutovec.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,20 @@ GenerateEmbeddingSpMDMRowWiseSparse_autovec(
9292
bool is_weight_positional,
9393
bool use_offsets);
9494

95+
/// Returns an autovectorized kernel for N-bit (2- or 4-bit) quantized
/// embedding lookup over a row-wise sparse (pruned) table, specialized
/// for the given configuration. `prefetch` is currently unused by this
/// autovec variant.
template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    uint8_t,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
    int bit_rate,
    int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch,
    bool is_weight_positional,
    bool use_offsets);
108+
95109
} // namespace fbgemm
96110

97111
#endif // #ifdef __linux__

src/EmbeddingSpMDMNBit.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,25 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
12961296
}
12971297
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
12981298

1299+
#ifdef FBGEMM_AUTOVEC_AVAILABLE
1300+
if (!cpuinfo_initialize()) {
1301+
throw std::runtime_error("Failed to initialize cpuinfo!");
1302+
}
1303+
if ((fbgemmHasArmSve2Support() && !is_autovec_disabled()) ||
1304+
is_autovec_forced()) {
1305+
return GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<
1306+
/*IndexType=*/indxType,
1307+
/*OffsetType=*/offsetType>(
1308+
/*bit_rate=*/bit_rate,
1309+
/*block_size=*/block_size,
1310+
/*has_weight=*/has_weight,
1311+
/*normalize_by_lengths=*/normalize_by_lengths,
1312+
/*prefetch=*/prefetch,
1313+
/*is_weight_positional=*/is_weight_positional,
1314+
/*use_offsets=*/use_offsets);
1315+
}
1316+
#endif
1317+
12991318
#ifdef VLOG
13001319
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
13011320
#endif

0 commit comments

Comments
 (0)