
bench: fix GPU power throttling in benchmark utilities #2899

Open

Edenzzzz wants to merge 5 commits into flashinfer-ai:main from Edenzzzz:fix_cudagraph_bench

Conversation

Contributor

@Edenzzzz Edenzzzz commented Mar 26, 2026

Summary

  • Fix GPU power throttling in `bench_gpu_time_with_cuda_event`, `bench_gpu_time_with_cudagraph`, and `bench_gpu_time_with_cupti`, which causes artificially low benchmark numbers
  • Insert cooldown gaps every ~5ms of sustained compute to prevent SM clock throttling
  • For cudagraph, also cap num_iters_within_graph to keep single-replay duration under the threshold

Problem

When `repeat` is set instead of `repeat_time_ms`, running kernels back-to-back causes sustained peak power draw on modern GPUs, forcing clock-frequency throttling. On B200:

  • SM clocks drop from 1965 MHz to ~1580 MHz (19% reduction)
  • Power spikes to ~1000W (vs ~250W with natural CPU gaps)
  • TFLOPS degrades by up to 15% over many iterations

Before fix — per-iteration TFLOPS degradation over 100 event-timed iterations (b=4 s=8192 h=16 d=128, BF16 FA4 on B200):

Iters   0-  9: avg=1425 TFLOPS
Iters  10- 19: avg=1450 TFLOPS  (peak)
Iters  30- 39: avg=1373 TFLOPS  (throttling begins)
Iters  70- 79: avg=1237 TFLOPS  (fully throttled, -15%)

Fix

Introduce `GPU_POWER_THROTTLE_THRESHOLD_MS = 5.0` — the empirical threshold (on B200) beyond which sustained compute triggers clock throttling. All three timing paths now insert sync+sleep cooldown gaps at this interval.
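The burst-and-cooldown scheme can be sketched as below. This is a minimal illustration, not the actual code in `flashinfer/testing/utils.py`: `run_kernel` and `estimated_kernel_ms` stand in for the real callable and probed kernel time, and the `torch.cuda.synchronize()` call is left as a comment so the sketch runs without a GPU.

```python
import time

GPU_POWER_THROTTLE_THRESHOLD_MS = 5.0  # empirical threshold measured on B200


def iters_per_burst(estimated_kernel_ms: float,
                    threshold_ms: float = GPU_POWER_THROTTLE_THRESHOLD_MS) -> int:
    """How many back-to-back iterations fit under the throttling threshold."""
    if estimated_kernel_ms <= 0:
        return 1
    return max(1, int(threshold_ms / estimated_kernel_ms))


def timed_loop(run_kernel, num_iters: int, estimated_kernel_ms: float) -> None:
    """Run the kernel num_iters times, pausing after each ~5ms burst so
    GPU clocks can recover before the next burst."""
    burst = iters_per_burst(estimated_kernel_ms)
    for i in range(num_iters):
        run_kernel()
        if (i + 1) % burst == 0:
            # torch.cuda.synchronize()  # drain queued work before cooling down
            time.sleep(estimated_kernel_ms / 1000)  # cooldown gap, ~one burst long
```

The sleep happens after the burst's timing events have been recorded, so it adds wall-clock time but does not contaminate the measured kernel durations.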

After fix — Event vs Graph comparison (60 iters each, BF16 FA4 on B200):

| Config | Events TFLOPS | Graph TFLOPS | Delta |
| --- | --- | --- | --- |
| b=1 s=256 h=16 d=128 | 35 | 59 | +67.7% |
| b=1 s=1024 h=16 d=128 | 338 | 457 | +35.2% |
| b=4 s=4096 h=16 d=128 | 1413 | 1381 | -2.3% |
| b=4 s=8192 h=16 d=128 | 1432 | 1428 | -0.2% |
| b=2 s=16384 h=16 d=128 | 1413 | 1425 | +0.8% |
| b=1 s=32768 h=16 d=128 | 1443 | 1455 | +0.9% |
| b=4 s=4096 h=32 d=128 | 1306 | 1390 | +6.4% |
| b=4 s=8192 h=32 d=128 | 1393 | 1409 | +1.1% |
| b=1 s=32768 h=12 d=128 | 1375 | 1395 | +1.4% |
| b=1 s=32768 h=24 d=128 | 1450 | 1457 | +0.5% |

Graph ≥ events for 8/10 shapes.

Reproducer

Install flash-attention-fp4 and run the FP4 FA4 benchmark on a B200 GPU:

```shell
git clone https://github.com/hao-ai-lab/flash-attention-fp4.git && cd flash-attention-fp4
pip install -e .
cd flash_attn/cute
CUTE_DSL_ENABLE_TVM_FFI=1 python benchmarks/bench_fp4.py
```

See benchmark results and setup instructions. The throttling is observable on any compute-bound kernel with repeat_iters > 30 on B200.

Test plan

  • Verified per-iteration TFLOPS stability over 60 iterations — all 3 paths (Events, Graph, CUPTI)
  • Verified graph ≥ events for 8/10 FA4 benchmark shapes (B200)
  • Verified tiny kernels still benefit from graph launch overhead reduction (+35-68%)
  • CUPTI verified on CUDA 13.0, B200: stable 1625 TFLOPS across 60 iterations (otherwise idle GPU)
  • Test on other GPU architectures (H100, A100) to validate threshold generality

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Improvements
    • Enhanced GPU benchmarking utility with improved thermal management to provide more reliable kernel timing measurements and prevent performance variations due to power throttling.

Edenzzzz and others added 3 commits March 26, 2026 17:06
When bench_gpu_time_with_cudagraph runs many kernels back-to-back within
a single graph replay, sustained peak power draw forces the GPU to
throttle clock frequency (up to 20% on B200), producing artificially
lower benchmark numbers.

Fix: (1) Cap num_iters_within_graph so total graph duration stays under
5ms (empirical power throttling threshold). (2) Insert sync+sleep gaps
between graph replays when cumulative compute would exceed the threshold.

Before: b=4 s=8192 h=16 d=128 BF16 FA4 graph=1268 vs event=1455 (-13%)
After:  graph=1423 vs event=1446 (-1.6%), xlarge shapes now +10% faster

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Sustained back-to-back kernel execution causes GPU clock throttling (up
to 20% on B200), producing artificially lower benchmark numbers. This
affects both bench_gpu_time_with_cuda_event (at high repeat_iters) and
bench_gpu_time_with_cudagraph (within and between graph replays).

Fix: Insert sync+sleep cooldown gaps every ~5ms of sustained compute to
let GPU clocks recover. For cudagraph, also cap num_iters_within_graph
so a single graph replay doesn't exceed the throttling threshold.

Before (b=4 s=8192 h=16 d=128, 60 iters):
  Events iter 0-9: 1450 TFLOPS → iter 70-79: 1237 TFLOPS (-15%)
  Graph(n=10): 1268 TFLOPS vs events: 1455 (-13%)

After (same config, 60 iters):
  Events: stable 1399 TFLOPS throughout
  Graph: 1427 TFLOPS (+2% over events, from reduced launch overhead)

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Contributor

coderabbitai bot commented Mar 26, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 89143757-c07f-4235-aeab-98ab4a5603e2

📥 Commits

Reviewing files that changed from the base of the PR and between 9d3a1ab and 224e851.

📒 Files selected for processing (1)
  • flashinfer/testing/utils.py

📝 Walkthrough


Added a GPU power-throttling mitigation mechanism to benchmarking utilities: introduced a module-level threshold and inserted periodic synchronization/sleep cooldowns in GPU timing loops; CUDA-graph path gains a pre-probe to estimate kernel/graph duration and adjust replay counts accordingly.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| GPU power-throttle mitigation<br>`flashinfer/testing/utils.py` | Added `GPU_POWER_THROTTLE_THRESHOLD_MS` constant. In `bench_gpu_time_with_cuda_event` and the CUPTI path, compute `iters_per_burst` from the estimated kernel time and insert periodic `torch.cuda.synchronize()` + `time.sleep(...)` when `sleep_after_run` is false. In `bench_gpu_time_with_cudagraph`, add a temporary-graph probe to measure single-kernel/graph duration, reduce `num_iters_within_graph` if the threshold is exceeded, compute `replays_per_burst`, and insert periodic synchronization/sleep after replay bursts. Updated comments to note that graph timing is via replay. |
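The replay-count arithmetic described in the walkthrough can be sketched as follows. The function name is hypothetical and only the planning math is shown — the temporary-graph probe that produces `single_kernel_ms` is omitted, and the exact rounding in the real code may differ.

```python
GPU_POWER_THROTTLE_THRESHOLD_MS = 5.0  # B200-derived default per the PR


def plan_graph_replays(single_kernel_ms: float,
                       num_iters_within_graph: int,
                       threshold_ms: float = GPU_POWER_THROTTLE_THRESHOLD_MS):
    """Cap iterations per graph so one replay stays under the throttle
    threshold, then compute how many replays fit in one cooldown burst."""
    # Cap num_iters_within_graph so a single graph replay doesn't throttle.
    if single_kernel_ms * num_iters_within_graph > threshold_ms:
        num_iters_within_graph = max(1, int(threshold_ms / single_kernel_ms))
    graph_duration_ms = single_kernel_ms * num_iters_within_graph
    # How many back-to-back replays fit before a sync+sleep gap is needed.
    replays_per_burst = max(1, int(threshold_ms / graph_duration_ms))
    return num_iters_within_graph, replays_per_burst
```

For example, a 1 ms kernel with the default of 10 iterations per graph would be capped to 5 iterations (one replay already fills the 5 ms budget), while a 0.1 ms kernel keeps all 10 iterations and gets 5 replays per burst.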

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰
I measure hops in bursts so bright,
Then pause to let the circuits light,
A sync, a nap, a gentle sigh,
So kernels hum and temps stay shy,
Hop on—smooth runs through day and night.

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and concisely describes the main fix: addressing GPU power throttling in benchmark utilities, which directly matches the primary change. |
| Description check | ✅ Passed | The PR description comprehensively covers the problem, fix, results, and test plan, but the PR checklist items (pre-commit, tests) are incomplete. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a mechanism to prevent GPU power throttling during benchmarking by inserting synchronization and sleep intervals when sustained compute exceeds a 5ms threshold. It includes logic to estimate kernel execution times and dynamically adjust iteration counts within CUDA graphs to maintain stable clock frequencies. Feedback was provided to cap the sleep duration at 10ms to prevent benchmarks from becoming excessively slow when executing long-running kernels.

```python
    sleep_after_kernel_run(estimated_kernel_execution_time)
elif (iter_idx + 1) % iters_per_burst == 0:
    torch.cuda.synchronize()
    time.sleep(estimated_kernel_execution_time / 1000)
```

medium

The sleep duration here is unbounded. If estimated_kernel_execution_time is very large (e.g., >1s), this will cause the benchmark to sleep for a long time after each burst, significantly slowing it down.

Consider capping the sleep duration to a reasonable value, for instance 10ms, to prevent unexpectedly long benchmark runs while still providing an effective cooldown period.

Suggested change:

```diff
-time.sleep(estimated_kernel_execution_time / 1000)
+time.sleep(min(estimated_kernel_execution_time / 1000, 0.01))
```

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
flashinfer/testing/utils.py (2)

33-43: Consider making the throttle threshold configurable or architecture-aware.

The 5.0ms threshold is empirically measured on B200. Per the PR description, testing on H100 and A100 is pending. Different GPU architectures may have different power throttling characteristics, and a hardcoded value may not be optimal across all hardware.

Consider either:

  1. Adding an optional parameter to the benchmarking functions to override this threshold
  2. Using torch.cuda.get_device_properties() to detect architecture and adjust accordingly
  3. At minimum, documenting that this value is B200-specific
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/utils.py` around lines 33 - 43, The constant
GPU_POWER_THROTTLE_THRESHOLD_MS is hardcoded to a B200-specific value; make the
threshold configurable and architecture-aware by replacing direct uses of
GPU_POWER_THROTTLE_THRESHOLD_MS with a resolved value that can be overridden:
add an optional parameter (e.g., throttle_threshold_ms) to the
benchmarking/measurement entrypoints that currently rely on
GPU_POWER_THROTTLE_THRESHOLD_MS, implement a helper (e.g.,
resolve_throttle_threshold) that uses torch.cuda.get_device_properties() to
return a sensible default per GPU family (B200 vs H100/A100) and falls back to
the original 5.0ms, and update callers to pass through the new parameter or call
the helper; also update the module docstring to note the default is B200-derived
and refer to the new override.

1489-1493: Add a warning when num_iters_within_graph is silently reduced.

When the caller's requested num_iters_within_graph would cause throttling, the code silently reduces it. Callers (e.g., those relying on the default of 10) might be surprised when their measurement granularity changes unexpectedly.

🔧 Suggested fix
```diff
     max_sustained_ms = GPU_POWER_THROTTLE_THRESHOLD_MS
     if single_kernel_ms * num_iters_within_graph > max_sustained_ms:
+        original_iters = num_iters_within_graph
         num_iters_within_graph = max(1, int(max_sustained_ms / single_kernel_ms))
+        warnings.warn(
+            f"num_iters_within_graph reduced from {original_iters} to "
+            f"{num_iters_within_graph} to avoid GPU power throttling "
+            f"(single kernel: {single_kernel_ms:.2f}ms, threshold: {max_sustained_ms}ms).",
+            category=UserWarning,
+            stacklevel=2,
+        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/utils.py` around lines 1489 - 1493, The code silently
reduces num_iters_within_graph when single_kernel_ms * num_iters_within_graph
exceeds GPU_POWER_THROTTLE_THRESHOLD_MS; change this to log or warn the caller
before mutating the value by emitting a clear message that includes the original
requested num_iters_within_graph, the computed reduced value, single_kernel_ms,
and GPU_POWER_THROTTLE_THRESHOLD_MS so callers are aware the granularity
changed; locate the logic around num_iters_within_graph and single_kernel_ms and
add a warnings.warn or logger.warning call just prior to assigning the reduced
num_iters_within_graph.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ac4aba7a-2ca1-40a3-b4ae-54df3bcfd0af

📥 Commits

Reviewing files that changed from the base of the PR and between d426b18 and 3f97829.

📒 Files selected for processing (1)
  • flashinfer/testing/utils.py

Comment on lines 919 to +923
```python
if sleep_after_run:
    sleep_after_kernel_run(estimated_kernel_execution_time)
elif (iter_idx + 1) % iters_per_burst == 0:
    torch.cuda.synchronize()
    time.sleep(estimated_kernel_execution_time / 1000)
```

⚠️ Potential issue | 🟡 Minor

Semantic change: sleep_after_run=False now injects cooldown sleep.

Previously, sleep_after_run=False meant no sleeping between iterations. Now it triggers automatic cooldown gaps. This silently changes behavior for existing callers (e.g., benchmarks/routines/moe.py, benchmarks/routines/attention.py) that explicitly pass sleep_after_run=False expecting no additional delays.

Consider one of:

  1. Add a new parameter like disable_throttle_mitigation=False to let callers opt out
  2. Rename sleep_after_run to better reflect the new behavior (breaking change)
  3. Document this behavioral change prominently in the docstring

The measurement timing is correct since sleep occurs after end_events[iter_idx].record().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/utils.py` around lines 919 - 923, The current change makes
sleep_after_run=False accidentally trigger a cooldown sleep; restore opt-out by
adding a new boolean parameter (e.g., disable_throttle_mitigation=False) to the
function that contains sleep_after_run and use it in the branch logic so: if
sleep_after_run: call sleep_after_kernel_run(...); elif not
disable_throttle_mitigation and (iter_idx + 1) % iters_per_burst == 0:
torch.cuda.synchronize(); time.sleep(estimated_kernel_execution_time / 1000).
Update the function signature and any callers that need the new behavior and
adjust docstring to describe disable_throttle_mitigation, referencing
sleep_after_run, sleep_after_kernel_run, estimated_kernel_execution_time,
iter_idx, and iters_per_burst.

Comment on lines 1546 to +1551
```python
if sleep_after_run:
    sleep_after_kernel_run(estimated_kernel_execution_time)
elif (iter_idx + 1) % replays_per_burst == 0:
    # Cooldown gap to prevent clock throttling from sustained compute.
    torch.cuda.synchronize()
    time.sleep(graph_duration_ms / 1000)
```

⚠️ Potential issue | 🟡 Minor

Same semantic change concern as noted for bench_gpu_time_with_cuda_event.

The sleep_after_run=False branch now injects cooldown gaps, changing the behavior for existing callers. See the earlier comment on lines 919-923 for suggested mitigations.

The cooldown implementation itself is correct—sleep occurs after end_events[iter_idx].record() so it doesn't affect measurement accuracy.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/utils.py` around lines 1546 - 1551, The new branch that
injects cooldown gaps when sleep_after_run is False changes behavior for
existing callers; restore previous semantics by making cooldown injection
opt-in: add a boolean flag (e.g., inject_cooldown_when_no_sleep default False)
to the function that contains this block (and propagate it through
bench_gpu_time_with_cuda_event callers), then change the branch to only perform
the torch.cuda.synchronize()/time.sleep(...) cooldown when
inject_cooldown_when_no_sleep is True and (iter_idx + 1) % replays_per_burst ==
0; keep the existing sleep_after_kernel_run(estimated_kernel_execution_time)
path unchanged and ensure the cooldown still happens after
end_events[iter_idx].record() so measurements remain correct.

Edenzzzz and others added 2 commits March 26, 2026 20:08
CUPTI measures pure GPU kernel time, but throttled clocks still produce
longer kernel durations. Apply the same cooldown gap logic.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Contributor Author

Edenzzzz commented Apr 1, 2026

cc @yzh119 should be ready

