Skip to content

Commit 043def2

Browse files
MatzeB authored and facebook-github-bot committed
Add EmbeddingSpMDMNBitRowWiseSparse autovectorized variant (#5244)
Summary: X-link: facebookresearch/FBGEMM#2235 This adds an autovectorized implementation of the `GenerateEmbeddingSpMDMNBitRowWiseSparse` kernels. This first version is around a 10x improvement compared to the reference C++ implementation, but slower than the asmjit version (which is x86 only); see the list below for how to close this gap. Room for further enhancements (see some of the surrounding kernels that do implement them): - This is only specialized for `bit_rate` right now. Specializing for common block sizes typically nets good improvements. - Directly operates on the `out` buffer. This is good for unknown block sizes, but if we specialize for fixed small block sizes then a separate buffer is better, as it can be promoted completely to vector registers (for a fixed vector register size anyway; this doesn't work for variable-size AArch64 SVE registers). - No prefetching logic yet. Reviewed By: excelle08 Differential Revision: D89086019
1 parent 04455cd commit 043def2

File tree

3 files changed

+251
-0
lines changed

3 files changed

+251
-0
lines changed

src/EmbeddingSpMDMAutovec.cc

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,134 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
538538
return current == index_size;
539539
}
540540

541+
template <typename IndexType, typename OffsetType>
542+
static bool ALWAYS_INLINE EmbeddingSpMDMNBitRowWiseSparse_autovec(
543+
const int bit_rate,
544+
const int64_t block_size,
545+
const int64_t output_size,
546+
const int64_t index_size,
547+
const int64_t uncompressed_data_size,
548+
const uint8_t* input,
549+
const IndexType* indices,
550+
const int32_t* compressed_indices_table,
551+
const OffsetType* offsets_or_lengths,
552+
const float* weights,
553+
const bool normalize_by_lengths,
554+
float* out,
555+
const bool is_weight_positional,
556+
const bool use_offsets) {
557+
if (uncompressed_data_size < 0) {
558+
return false;
559+
}
560+
561+
// block_size is the number of elements and fused_block_size is the size in
562+
// bytes of an entire row, including scale and bias.
563+
const int num_elem_per_byte = 8 / bit_rate;
564+
const int64_t scale_bias_size = 2 * sizeof(float16);
565+
const uint64_t scale_bias_offset = div_up(block_size, num_elem_per_byte);
566+
const int64_t fused_block_size = scale_bias_offset + scale_bias_size;
567+
568+
int64_t current = 0;
569+
float* buf = out;
570+
for (int64_t m = 0; m < output_size; ++m) {
571+
const OffsetType len = use_offsets
572+
? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
573+
: offsets_or_lengths[m];
574+
const int64_t end = current + len;
575+
if (end > index_size) {
576+
return false;
577+
}
578+
579+
memset(buf, 0, sizeof(float) * block_size);
580+
581+
const float* weights_addr = weights != nullptr
582+
? (is_weight_positional ? weights : weights + current)
583+
: nullptr;
584+
for (; current < end; ++current) {
585+
int64_t uncompressed_idx = indices[current];
586+
if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
587+
return false;
588+
}
589+
int64_t idx = compressed_indices_table[uncompressed_idx];
590+
if (idx == -1) {
591+
weights_addr++;
592+
continue;
593+
}
594+
595+
const uint8_t* input_row_base = input + fused_block_size * idx;
596+
const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;
597+
598+
float scale =
599+
cpu_half2float(*reinterpret_cast<const float16*>(scale_bias_addr));
600+
float bias = cpu_half2float(
601+
*reinterpret_cast<const float16*>(scale_bias_addr + sizeof(float16)));
602+
603+
if (weights != nullptr) {
604+
float weight = *weights_addr++;
605+
scale *= weight;
606+
bias *= weight;
607+
}
608+
609+
const uint8_t* input_row = input_row_base;
610+
if (bit_rate == 4) {
611+
int64_t j = 0;
612+
#ifdef FBGEMM_VECTOR_WIDTH
613+
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2));
614+
j += 2) {
615+
uint8_t tmp = *input_row++;
616+
float quantized1 = float(tmp & 0xf);
617+
float quantized2 = float(tmp >> 4);
618+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
619+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
620+
}
621+
#endif
622+
for (; j < block_size; j += 2) {
623+
uint8_t tmp = *input_row++;
624+
float quantized1 = float(tmp & 0xf);
625+
float quantized2 = float(tmp >> 4);
626+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
627+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
628+
}
629+
} else if (bit_rate == 2) {
630+
int64_t j = 0;
631+
#ifdef FBGEMM_VECTOR_WIDTH
632+
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4));
633+
j += 4) {
634+
uint8_t tmp = *input_row++;
635+
float quantized1 = float(tmp & 0x3);
636+
float quantized2 = float((tmp & 0xC) >> 2);
637+
float quantized3 = float((tmp & 0x30) >> 4);
638+
float quantized4 = float(tmp >> 6);
639+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
640+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
641+
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
642+
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
643+
}
644+
#endif
645+
for (; j < block_size; j += 4) {
646+
uint8_t tmp = *input_row++;
647+
float quantized1 = float(tmp & 0x3);
648+
float quantized2 = float((tmp & 0xC) >> 2);
649+
float quantized3 = float((tmp & 0x30) >> 4);
650+
float quantized4 = float(tmp >> 6);
651+
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
652+
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
653+
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
654+
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
655+
}
656+
}
657+
}
658+
if (normalize_by_lengths && len) {
659+
float scale = 1.f / len;
660+
for (int j = 0; j < block_size; ++j) {
661+
buf[j] *= scale;
662+
}
663+
}
664+
buf += block_size;
665+
}
666+
return current == index_size;
667+
}
668+
541669
/// @ingroup tbe-cpu-autovec
542670
///
543671
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
@@ -2153,6 +2281,96 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
21532281
#undef INSTANTIATE_SPMDM_OUT_T
21542282
#undef INSTANTIATE_SPMDM_BASE
21552283

2284+
/// Factory for the autovectorized N-bit (2- or 4-bit) row-wise-sparse
/// embedding lookup kernel. Returns a closure over the compile-time
/// specialization selected by the SPECIALIZE dispatch below. Only
/// `bit_rate` is specialized today; block-size specialization is a known
/// follow-up. `prefetch` is accepted for signature parity but unused here.
template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    uint8_t,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
    int bit_rate,
    int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    [[maybe_unused]] int prefetch,
    bool is_weight_positional,
    bool use_offsets) {
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
  using specialization_helper::fixed;
  using specialization_helper::match;
  using specialization_helper::specialize;
  using specialization_helper::var;

  // SPECIALIZE(...) expands to: if the runtime arguments match the requested
  // pattern (fixed(N) or var for each slot), return a lambda calling the
  // kernel with matched slots replaced by compile-time constants via
  // specialize(). A slot given as `var` stays a runtime value.
#define SPECIALIZE(                                                       \
    BIT_RATE,                                                             \
    BLOCK_SIZE,                                                           \
    HAS_WEIGHT,                                                           \
    NORMALIZE_BY_LENGTHS,                                                 \
    IS_WEIGHT_POSITIONAL,                                                 \
    USE_OFFSETS)                                                          \
  if (match(BIT_RATE, bit_rate) && match(BLOCK_SIZE, block_size) &&       \
      match(HAS_WEIGHT, has_weight) &&                                    \
      match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) &&                \
      match(IS_WEIGHT_POSITIONAL, is_weight_positional) &&                \
      match(USE_OFFSETS, use_offsets)) {                                  \
    return [=](int64_t output_size,                                       \
               int64_t index_size,                                        \
               int64_t uncompressed_data_size,                            \
               const uint8_t* input,                                      \
               const IndexType* indices,                                  \
               const OffsetType* offsets_or_lengths,                      \
               const float* weights,                                      \
               float* out,                                                \
               const int32_t* compressed_indices_table) {                 \
      /* Let the optimizer assume weights is non-null in the weighted  */ \
      /* specialization; null it out otherwise so the kernel's         */ \
      /* weights-dependent branches fold away.                         */ \
      if (specialize(HAS_WEIGHT, has_weight)) {                           \
        __builtin_assume(weights != nullptr);                             \
      } else {                                                            \
        weights = nullptr;                                                \
      }                                                                   \
      return EmbeddingSpMDMNBitRowWiseSparse_autovec(                     \
          /*bit_rate=*/specialize(BIT_RATE, bit_rate),                    \
          /*block_size=*/specialize(BLOCK_SIZE, block_size),              \
          /*output_size=*/output_size,                                    \
          /*index_size=*/index_size,                                      \
          /*uncompressed_data_size=*/uncompressed_data_size,              \
          /*input=*/input,                                                \
          /*indices=*/indices,                                            \
          /*compressed_indices_table=*/compressed_indices_table,          \
          /*offsets_or_lengths=*/offsets_or_lengths,                      \
          /*weights=*/weights, /*normalize_by_lengths=*/                  \
          specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths),         \
          /*out=*/out, /*is_weight_positional=*/                          \
          specialize(IS_WEIGHT_POSITIONAL, is_weight_positional),         \
          /*use_offsets=*/specialize(USE_OFFSETS, use_offsets));          \
    };                                                                    \
  }

  // bit_rate is asserted to be 2 or 4 above, so one of these always returns.
  SPECIALIZE(/*BIT_RATE*/ fixed(4), var, var, var, var, var);
  SPECIALIZE(/*BIT_RATE*/ fixed(2), var, var, var, var, var);
  abort(); // should not get here
#undef SPECIALIZE
}
2352+
2353+
// Explicitly instantiate the factory for every supported index/offset type
// combination so out-of-TU callers (e.g. EmbeddingSpMDMNBit.cc) can link
// against these definitions.
#define INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(INDEX_TYPE, OFFSET_TYPE)      \
  template typename EmbeddingSpMDMRowWiseSparseKernelSignature<             \
      uint8_t,                                                              \
      INDEX_TYPE,                                                           \
      OFFSET_TYPE>::Type                                                    \
  GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<INDEX_TYPE, OFFSET_TYPE>( \
      int bit_rate,                                                         \
      int64_t block_size,                                                   \
      bool has_weight,                                                      \
      bool normalize_by_lengths,                                            \
      int prefetch,                                                         \
      bool is_weight_positional,                                            \
      bool use_offsets);

INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int32_t, int64_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int32_t)
INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE(int64_t, int64_t)

#undef INSTANTIATE_SPMDM_NBIT_ROWWISE_SPARSE
2373+
21562374
} // namespace fbgemm
21572375

21582376
#endif // #ifdef __linux__

src/EmbeddingSpMDMAutovec.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,20 @@ GenerateEmbeddingSpMDMRowWiseSparse_autovec(
9292
bool is_weight_positional,
9393
bool use_offsets);
9494

95+
/// Autovectorized factory for the N-bit (2- or 4-bit) row-wise-sparse
/// embedding lookup kernel. Mirrors the signature of the asmjit-based
/// GenerateEmbeddingSpMDMNBitRowWiseSparse factory; the returned kernel
/// has the EmbeddingSpMDMRowWiseSparseKernelSignature call signature.
/// NOTE(review): `prefetch` is marked [[maybe_unused]] in the definition —
/// the autovec path has no prefetching logic yet.
template <typename IndexType, typename OffsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    uint8_t,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec(
    int bit_rate,
    int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch,
    bool is_weight_positional,
    bool use_offsets);
108+
95109
} // namespace fbgemm
96110

97111
#endif // #ifdef __linux__

src/EmbeddingSpMDMNBit.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,25 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
12961296
}
12971297
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
12981298

1299+
#ifdef FBGEMM_AUTOVEC_AVAILABLE
1300+
if (!cpuinfo_initialize()) {
1301+
throw std::runtime_error("Failed to initialize cpuinfo!");
1302+
}
1303+
if ((fbgemmHasArmSve2Support() && !is_autovec_disabled()) ||
1304+
is_autovec_forced()) {
1305+
return GenerateEmbeddingSpMDMNBitRowWiseSparse_autovec<
1306+
/*IndexType=*/indxType,
1307+
/*OffsetType=*/offsetType>(
1308+
/*bit_rate=*/bit_rate,
1309+
/*block_size=*/block_size,
1310+
/*has_weight=*/has_weight,
1311+
/*normalize_by_lengths=*/normalize_by_lengths,
1312+
/*prefetch=*/prefetch,
1313+
/*is_weight_positional=*/is_weight_positional,
1314+
/*use_offsets=*/use_offsets);
1315+
}
1316+
#endif
1317+
12991318
#ifdef VLOG
13001319
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
13011320
#endif

0 commit comments

Comments
 (0)