[CPU] Add software prefetch to overlap bandwidth for scaled_embedding_bag #4171
yanbing-j wants to merge 3 commits into pytorch:main
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4171
✅ No Failures as of commit c081cb6 with merge base 8fca033.
Xia-Weiwen left a comment:
Please also try prefetching less data, because the overhead of issuing too many prefetch instructions can sometimes be significant for the CPU.
Pull request overview
Adds software prefetching to the CPU AVX512 path of scaled_embedding_bag to better overlap DRAM latency for large embedding tables with random access patterns.
Changes:
- Introduce a _prefetch_emb_row helper to prefetch all cache lines of a full embedding row.
- Add a fixed prefetch lookahead (PREFETCH_DIST) in the AVX512 kernel loop to prefetch future batch entries' referenced rows.
```cpp
// Prefetch all cache lines of an embedding row (all blocks).
// emb_bytes = emb_dim * sizeof(data_t). Cache line = 64 bytes.
template <typename data_t>
static inline void _prefetch_emb_row(const data_t *base, int64_t emb_dim) {
  const char *ptr = reinterpret_cast<const char *>(base);
  const int64_t emb_bytes = emb_dim * static_cast<int64_t>(sizeof(data_t));
  for (int64_t off = 0; off < emb_bytes; off += 64) {
    _mm_prefetch(ptr + off, _MM_HINT_T0);
  }
}
```
_prefetch_emb_row uses the x86 intrinsic _mm_prefetch but the helper is defined unconditionally (outside any x86/AVX512 compile guards). This can break compilation on non-x86 targets (e.g., AArch64) even though the call site is inside the AVX512 path. Consider either (1) moving the helper definition under #if defined(CPU_CAPABILITY_AVX512) / an x86 guard, or (2) replacing _mm_prefetch with a portable alternative like __builtin_prefetch (and keeping the call site AVX512-gated).
Good catch. @yanbing-j Do you think it makes sense? If so, please wrap this part with #if defined(CPU_CAPABILITY_AVX512) and provide a pass for #else.
I have moved the prefetch into CPU_CAPABILITY_AVX512; other scenarios will go into the fallback path.
```cpp
// How many batch entries ahead to prefetch. Each entry has ~3 rows to fetch
// from a 40M-row table; DRAM latency ~100 ns means we must keep enough
// in-flight requests to hide latency.
constexpr int64_t PREFETCH_DIST = 8;
```
Move this to #if defined(CPU_CAPABILITY_AVX512), too?
This PR is to add software prefetch to overlap bandwidth for scaled_embedding_bag.
The embedding table (e.g., 40M rows × 128 dims = 5 GB for fp8) in scaled_embedding_bag far exceeds the CPU's L3 cache, so every unique index access results in a DRAM miss that the hardware prefetcher cannot predict due to the random access pattern. Without software prefetching, each of the ~218 unique DRAM fetches per batch must be served serially at ~100 ns latency, making the kernel latency-bound (22 µs) rather than bandwidth-bound (~8 µs).