
[CPU] Add software prefetch to overlap bandwidth for scaled_embedding_bag#4171

Open
yanbing-j wants to merge 3 commits into pytorch:main from yanbing-j:yanbing/optimize_scaled_embedding_bag_fp8

Conversation

@yanbing-j
Contributor

This PR adds software prefetching to scaled_embedding_bag so that DRAM accesses overlap with computation.

The embedding table (e.g., 40M rows × 128 dims = 5 GB for fp8) in scaled_embedding_bag far exceeds the CPU's L3 cache, so every unique index access results in a DRAM miss that the hardware prefetcher cannot predict due to the random access pattern. Without software prefetching, each of the ~218 unique DRAM fetches per batch must be served serially at ~100 ns latency, making the kernel latency-bound (22 µs) rather than bandwidth-bound (~8 µs).
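
As a sanity check, the latency-bound figure follows directly from the numbers above. A minimal back-of-envelope sketch (an estimate using the description's ~218-fetch and ~100 ns figures, not a measurement):

```cpp
// Back-of-envelope check of the latency-bound claim: if every unique DRAM
// fetch is served serially, total latency is fetches x per-miss latency.
// Both constants are taken from the PR description above.
constexpr double kUniqueFetches = 218.0;  // unique DRAM fetches per batch
constexpr double kDramLatencyNs = 100.0;  // per-miss DRAM latency
constexpr double kSerialUs = kUniqueFetches * kDramLatencyNs / 1000.0;
// kSerialUs is ~21.8 us, consistent with the measured ~22 us kernel latency,
// versus ~8 us if the same traffic were limited only by bandwidth.
```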

@pytorch-bot

pytorch-bot bot commented Mar 25, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4171

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c081cb6 with merge base 8fca033:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the "CLA Signed" label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) Mar 25, 2026
Collaborator

@Xia-Weiwen Xia-Weiwen left a comment


Please also try prefetching less data, because the overhead of issuing too many prefetch instructions can be significant on CPU.
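
One hypothetical way to follow this suggestion is to cap the number of cache lines prefetched per row. The sketch below is illustrative only: the helper name, the max_lines parameter, and the use of the GCC/Clang __builtin_prefetch builtin are assumptions, not code from this PR.

```cpp
#include <cstdint>

// Hypothetical variant of the PR's row-prefetch helper that bounds how many
// prefetch instructions are issued per row, trading coverage for lower
// issue overhead. Name and max_lines parameter are illustrative.
template <typename data_t>
static inline void _prefetch_emb_row_capped(const data_t *base,
                                            int64_t emb_dim,
                                            int64_t max_lines) {
  const char *ptr = reinterpret_cast<const char *>(base);
  const int64_t emb_bytes = emb_dim * static_cast<int64_t>(sizeof(data_t));
  int64_t issued = 0;
  // Stop after max_lines cache lines even if the row is longer.
  for (int64_t off = 0; off < emb_bytes && issued < max_lines;
       off += 64, ++issued) {
    __builtin_prefetch(ptr + off, /*rw=*/0, /*locality=*/3);
  }
}
```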

Contributor

Copilot AI left a comment


Pull request overview

Adds software prefetching to the CPU AVX512 path of scaled_embedding_bag to better overlap DRAM latency for large embedding tables with random access patterns.

Changes:

  • Introduce _prefetch_emb_row helper to prefetch cache lines for a full embedding row.
  • Add a fixed prefetch lookahead (PREFETCH_DIST) in the AVX512 kernel loop to prefetch future batch entries’ referenced rows.


Comment on lines +177 to +186
// Prefetch all cache lines of an embedding row (all blocks).
// emb_bytes = emb_dim * sizeof(data_t). Cache line = 64 bytes.
template <typename data_t>
static inline void _prefetch_emb_row(const data_t *base, int64_t emb_dim) {
  const char *ptr = reinterpret_cast<const char *>(base);
  const int64_t emb_bytes = emb_dim * static_cast<int64_t>(sizeof(data_t));
  for (int64_t off = 0; off < emb_bytes; off += 64) {
    _mm_prefetch(ptr + off, _MM_HINT_T0);
  }
}

Copilot AI Mar 26, 2026


_prefetch_emb_row uses the x86 intrinsic _mm_prefetch but the helper is defined unconditionally (outside any x86/AVX512 compile guards). This can break compilation on non-x86 targets (e.g., AArch64) even though the call site is inside the AVX512 path. Consider either (1) moving the helper definition under #if defined(CPU_CAPABILITY_AVX512) / an x86 guard, or (2) replacing _mm_prefetch with a portable alternative like __builtin_prefetch (and keeping the call site AVX512-gated).
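
A sketch of option (2), keeping one helper but selecting the prefetch primitive per target. This is illustrative, not the final PR code; it assumes a GCC/Clang toolchain for the __builtin_prefetch fallback:

```cpp
#include <cstdint>
#if defined(CPU_CAPABILITY_AVX512)
#include <immintrin.h>
#endif

// Sketch: gate the x86 intrinsic on the AVX512 guard macro discussed in
// this thread, and fall back to the portable GCC/Clang builtin elsewhere,
// so non-x86 targets (e.g., AArch64) still compile.
template <typename data_t>
static inline void _prefetch_emb_row(const data_t *base, int64_t emb_dim) {
  const char *ptr = reinterpret_cast<const char *>(base);
  const int64_t emb_bytes = emb_dim * static_cast<int64_t>(sizeof(data_t));
  for (int64_t off = 0; off < emb_bytes; off += 64) {
#if defined(CPU_CAPABILITY_AVX512)
    _mm_prefetch(ptr + off, _MM_HINT_T0);
#else
    __builtin_prefetch(ptr + off, /*rw=*/0, /*locality=*/3);
#endif
  }
}
```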

Collaborator


Good catch. @yanbing-j Do you think it makes sense? If so, please wrap this part with #if defined(CPU_CAPABILITY_AVX512) and provide a pass for #else.

Contributor Author


I have moved the prefetch into CPU_CAPABILITY_AVX512; other scenarios will go into the fallback path.

// How many batch entries ahead to prefetch. Each entry has ~3 rows to fetch
// from a 40M-row table; DRAM latency ~100 ns means we must keep enough
// in-flight requests to hide latency.
constexpr int64_t PREFETCH_DIST = 8;
Collaborator


Move this to #if defined(CPU_CAPABILITY_AVX512), too?

Contributor Author


ok.

@Xia-Weiwen Xia-Weiwen marked this pull request as ready for review March 26, 2026 08:48
@Xia-Weiwen Xia-Weiwen requested a review from jerryzh168 March 26, 2026 08:48
@Xia-Weiwen Xia-Weiwen added the "module: not user facing" label (Use this tag if you don't want this PR to show up in release notes) Mar 26, 2026

Labels

CLA Signed — This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
module: not user facing — Use this tag if you don't want this PR to show up in release notes.


3 participants