
Yanqinz/dynamic shape unified api#2910

Open
yanqinz2 wants to merge 6 commits into main from yanqinz/dynamic-shape-unified-api

Conversation

@yanqinz2
Collaborator

@yanqinz2 yanqinz2 commented Mar 29, 2026

📌 Description

Add cuDNN override-shape support for BF16 and FP4 GEMM.
The current BF16/NVFP4 GEMM cuDNN backend uses a static-shape approach that rebuilds the execution plan every time a new shape is encountered, introducing significant compilation overhead and redundant cached graphs.
This PR introduces override-shape support and makes it the default for the cuDNN backend: the cuDNN graph and execution plans built during the autotuning phase are reused for subsequent problem sizes, avoiding repeated graph/execution-plan caching and compilation overhead.
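The reuse scheme described above can be sketched in plain Python. This is a hypothetical, cuDNN-free stand-in: `PlanCache`, `next_pow2`, and the build counter are illustrative assumptions, not the actual FlashInfer implementation; the real code builds a cuDNN graph per power-of-two m bucket and overrides the runtime shapes at execution time.

```python
# Hypothetical sketch of override-shape plan reuse: key the cache on a
# power-of-two bucket of the dynamic dimension m instead of the exact shape,
# so new problem sizes reuse an already-built plan.

def next_pow2(x: int) -> int:
    """Round x up to the next power of two (bucket key for dynamic m)."""
    return 1 << (x - 1).bit_length() if x > 1 else 1

class PlanCache:
    def __init__(self):
        self._plans = {}
        self.builds = 0  # number of expensive graph builds performed

    def get(self, m: int, n: int, k: int) -> dict:
        key = (next_pow2(m), n, k)  # only m is dynamic here
        if key not in self._plans:
            self.builds += 1  # expensive: graph construction + plan build
            self._plans[key] = {"max_shape": key}
        return self._plans[key]

cache = PlanCache()
for m in (3, 5, 7, 100):  # varying runtime problem sizes
    cache.get(m, 128, 256)
print(cache.builds)  # → 3 (buckets m=4, 8, 128), not one build per shape
```

With the static-shape approach every new m would trigger a build (4 in this example); bucketing plus shape overrides caps builds at the number of buckets.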

🔍 Related Issues

https://nvbugspro.nvidia.com/bug/5539146

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor

    • Standardized cuDNN GEMM public API names, removing redundant exported symbols.
  • New Features

    • Added an optional policy parameter to cuDNN graph builders (defaults to heuristic choice).
  • Bug Fixes

    • Improved workspace buffer reallocation for dynamic-shape operations.
    • Corrected shape/stride handling for override-shape execution and added safer fallback when override support is unavailable.
  • Tests

    • Updated tests to use renamed APIs and adjusted FP4 quantization calls.

@coderabbitai
Contributor

coderabbitai bot commented Mar 29, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Renamed the cuDNN GEMM override-shape FP4/MXFP8 APIs, added an optional policy to the graph builders, changed execution to derive override shapes/strides from expanded tensors, ensured workspace reallocation when needed, and added BF16 override-shape caching with a conditional fallback to non-override execution.

Changes

Changes by cohort:

  • Public API Exports — flashinfer/gemm/__init__.py: Replaced the exported FP4/MXFP8 override-shape names with build_cudnn_gemm_fp4_graph_override_shape, execute_cudnn_gemm_fp4_graph_override_shape, build_cudnn_gemm_mxfp8_graph_override_shape, and execute_cudnn_gemm_mxfp8_graph_override_shape.
  • Core Implementation — flashinfer/gemm/gemm_base.py: Renamed the FP4/MXFP8 builders/executors; added an optional policy param and use of graph.build_plans(policy); derive override shapes/strides from expanded (real) tensors; reallocate the workspace when insufficient; add BF16 override-shape graph caching and a conditional fallback to the non-override path.
  • Tests — tests/gemm/test_cudnn_override_shape.py: Updated imports and call sites to the new FP4/MXFP8 builder/executor names; replaced nvfp4_quantize(..., True) with fp4_quantize(...); preserved dynamic-m cached-graph behavior with the updated APIs.
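The workspace-reallocation behavior listed under Core Implementation can be sketched as follows; `ensure_workspace` is a hypothetical helper (a plain-Python bytearray standing in for a CUDA workspace buffer), not the actual FlashInfer code:

```python
def ensure_workspace(buf: bytearray, required: int) -> bytearray:
    """Grow the cached workspace only when it is too small; never shrink,
    so later executions with smaller shapes reuse the same allocation."""
    if len(buf) < required:
        buf = bytearray(required)  # stand-in for reallocating a GPU buffer
    return buf

ws = bytearray(1024)
ws = ensure_workspace(ws, 4096)         # graph reports a larger size: grow
print(len(ws))                          # → 4096
print(ensure_workspace(ws, 512) is ws)  # → True (smaller request reuses)
```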

Sequence Diagram(s)

sequenceDiagram
  participant Runner as Runner
  participant Builder as cudnn Graph Builder
  participant Graph as cudnn Graph
  participant GPU as GPU / Workspace
  participant Tensors as Prepared Tensors

  Runner->>Builder: build_cudnn_gemm_*_graph_override_shape(params, policy?)
  Builder->>Graph: create graph object
  Builder->>Graph: build_plans(policy)
  Graph-->>Runner: graph (cached by m / power-of-two)
  Runner->>Tensors: expand packed FP4/MXFP8/BF16 -> compute override shapes & strides
  Runner->>GPU: query graph.get_workspace_size()
  alt workspace too small
    GPU-->>Runner: reallocate workspace_buffer
  end
  Runner->>Graph: execute_override_shape(graph, tensors, workspace_buffer)
  Graph->>GPU: run GEMM with override shapes/strides
  GPU-->>Runner: result

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • dhiraj113
  • aleozlx
  • bkryu

Poem

🐰 Hopped through names and plans with glee,
Builders learn policies, tensors expand for me,
Workspaces grow when sizes surprise,
BF16 caches dance in power-of-two guise,
A rabbit cheers — small changes, big spree.

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 35.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Title check — ❓ Inconclusive: the title 'Yanqinz/dynamic shape unified api' is vague and generic, using the author's branch name rather than describing the changes. Consider a more descriptive title such as 'Add cuDNN dynamic-shape support for BF16 and FP4 GEMM' or 'Unify cuDNN GEMM graph API naming scheme'.

✅ Passed checks (1 passed)

  • Description check — ✅ Passed: the description covers the main objective and includes a related issue link, but the test completion status is not checked off.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the cuDNN GEMM implementation to provide consistent dynamic-shape (override-shape) support across FP4, MXFP8, and BF16 data types. Key changes include renaming functions for better naming consistency, adding a policy parameter to graph builders, and integrating override-shape logic into the TunableRunner classes for BF16 and FP4. The review feedback focuses on improving the efficiency of workspace buffer handling by suggesting that the code raise a ValueError for undersized buffers instead of performing local re-allocations, and recommends refactoring duplicated logic within the runner classes to improve maintainability.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/gemm/test_cudnn_override_shape.py (1)

17-27: Import the public helpers through flashinfer.gemm in this test.

Right now this bypasses flashinfer.gemm, so the test still passes even if the package re-export layer regresses. Since this PR changes that surface, the test should exercise it.

🧪 Minimal import split
-from flashinfer.gemm.gemm_base import (
-    CUDNN_AVAILABLE,
-    build_cudnn_gemm_bf16_graph_override_shape,
-    execute_cudnn_gemm_bf16_graph_override_shape,
-    build_cudnn_gemm_fp4_graph_override_shape,
-    execute_cudnn_gemm_fp4_graph_override_shape,
-    build_cudnn_gemm_mxfp8_graph_override_shape,
-    execute_cudnn_gemm_mxfp8_graph_override_shape,
-    is_cudnn_override_shape_available,
-    _calculate_block_scale_dims,
-)
+from flashinfer.gemm import (
+    build_cudnn_gemm_bf16_graph_override_shape,
+    execute_cudnn_gemm_bf16_graph_override_shape,
+    build_cudnn_gemm_fp4_graph_override_shape,
+    execute_cudnn_gemm_fp4_graph_override_shape,
+    build_cudnn_gemm_mxfp8_graph_override_shape,
+    execute_cudnn_gemm_mxfp8_graph_override_shape,
+    is_cudnn_override_shape_available,
+)
+from flashinfer.gemm.gemm_base import CUDNN_AVAILABLE, _calculate_block_scale_dims
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_cudnn_override_shape.py` around lines 17 - 27, The test
directly imports helpers from flashinfer.gemm.gemm_base instead of exercising
the package re-export layer; update the import to import the public helpers from
flashinfer.gemm (e.g. import CUDNN_AVAILABLE,
build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape,
is_cudnn_override_shape_available, _calculate_block_scale_dims) via from
flashinfer.gemm import <symbols> so the test fails if the package re-export
surface regresses.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 25-28: Restore backward-compatible aliases for the renamed
override-shape exports by reintroducing the old names as simple assignments to
the new symbols: for example, set the previous FP4/MXFP8 export names equal to
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape, and
execute_cudnn_gemm_mxfp8_graph_override_shape in flashinfer.gemm.__init__.py so
old imports continue to work; also add the same alias assignments in the
flashinfer.gemm.gemm_base module if that path is a supported public import so
both import surfaces mirror each other.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 2121-2124: Replace the runtime assertion with explicit input
validation that raises a ValueError: instead of using "assert real_a_stride[2]
== 1 and real_b_stride[1] == 1", check those conditions and raise ValueError
with a clear message (e.g., "a and b must be k-major") so invalid
caller-provided strides cannot slip through when Python assertions are disabled;
update the code around _get_bf16_3d_shape_stride and the
real_a_stride/real_b_stride checks accordingly.

---

Nitpick comments:
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 17-27: The test directly imports helpers from
flashinfer.gemm.gemm_base instead of exercising the package re-export layer;
update the import to import the public helpers from flashinfer.gemm (e.g. import
CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape,
is_cudnn_override_shape_available, _calculate_block_scale_dims) via from
flashinfer.gemm import <symbols> so the test fails if the package re-export
surface regresses.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 463f5a46-a140-49a0-8802-735d5fba51af

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and 368656f.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gemm/gemm_base.py (2)

1955-1966: ⚠️ Potential issue | 🟠 Major

Keep deprecated aliases for the renamed override-shape helpers.

These functions are part of the exported flashinfer.gemm surface, so renaming them without keeping the previous symbols is a hard import break for downstream users. Please keep the old FP4/MXFP8 names as aliases for at least one release cycle, or explicitly treat this PR as a breaking API change.

Also applies to: 2090-2100, 2209-2219, 2321-2330
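The alias pattern this comment asks for could look like the sketch below. The previous export names are not visible in this thread, so `build_cudnn_gemm_fp4_graph_old` is a placeholder, and the builder body is a stub rather than the real cuDNN graph construction:

```python
import warnings

def build_cudnn_gemm_fp4_graph_override_shape(*args, **kwargs):
    # Stub for the renamed builder; the real one constructs a cuDNN graph.
    return "graph"

def _deprecated_alias(new_fn, old_name):
    """Wrap new_fn so calls through the old name emit a DeprecationWarning."""
    def alias(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_fn.__name__} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)
    alias.__name__ = old_name
    return alias

# Placeholder old name; the real previous export name is not shown here.
build_cudnn_gemm_fp4_graph_old = _deprecated_alias(
    build_cudnn_gemm_fp4_graph_override_shape, "build_cudnn_gemm_fp4_graph_old"
)
```

The same alias would also be re-exported from flashinfer/gemm/__init__.py so both import surfaces keep working for one release cycle.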

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 1955 - 1966, The renamed
override-shape helper functions must keep their old FP4/MXFP8 symbols as
deprecated aliases: add thin alias definitions that point the old names to the
new implementations (e.g., define the previous FP4/MXFP8 function names to
return/call build_cudnn_gemm_fp4_graph_override_shape and the other two renamed
helpers referenced at the other ranges) and emit a warnings.warn(...,
DeprecationWarning) when those aliases are called so downstream code keeps
working for one release cycle while signaling the rename; ensure the alias names
exactly match the previous exported symbols and reference the new functions (use
the new function names from this file such as
build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation
message indicating the new name.

2321-2365: ⚠️ Potential issue | 🟠 Major

Apply shape normalization to MXFP8 override-shape executor for consistency.

execute_cudnn_gemm_mxfp8_graph_override_shape() forwards raw tensor shapes/strides, whereas execute_cudnn_gemm_bf16_graph_override_shape() uses _get_bf16_3d_shape_stride() to normalize 2D inputs to 3D. Even though MXFP8 currently enforces 3D inputs, the inconsistency should be resolved by using the same normalization pattern:

Suggested normalization
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
     a,
     b,
@@
 ):
     """Execute MXFP8 GEMM cuDNN graph with dynamic-shape overrides."""
+    a_shape, a_stride = _get_bf16_3d_shape_stride(a)
+    b_shape, b_stride = _get_bf16_3d_shape_stride(b)
+    batch = a_shape[0]
+    a_descale_shape, a_descale_stride = _expand_block_scale_tensor_shape(
+        a_descale, batch
+    )
+    b_descale_shape, b_descale_stride = _expand_block_scale_tensor_shape(
+        b_descale, batch
+    )
+    c_shape, c_stride = _get_bf16_3d_shape_stride(c_final)
+
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
         UIDs.BLOCK_DESCALE_A_UID.value: a_descale,
@@
     override_shapes = [
-        list(a.shape),
-        list(b.shape),
-        list(a_descale.shape),
-        list(b_descale.shape),
-        list(c_final.shape),
+        list(a_shape),
+        list(b_shape),
+        list(a_descale_shape),
+        list(b_descale_shape),
+        list(c_shape),
     ]
     override_strides = [
-        list(a.stride()),
-        list(b.stride()),
-        list(a_descale.stride()),
-        list(b_descale.stride()),
-        list(c_final.stride()),
+        list(a_stride),
+        list(b_stride),
+        list(a_descale_stride),
+        list(b_descale_stride),
+        list(c_stride),
     ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2321 - 2365, The override-shape
executor execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw
tensor shapes/strides; update it to use the same normalization used by the BF16
path by calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale,
b_descale, c_final) and use the returned normalized shape and stride values when
building override_shapes and override_strides; ensure you replace the direct
list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride
results so the MXFP8 override-shape logic is consistent with
execute_cudnn_gemm_bf16_graph_override_shape.
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

2113-2115: ⚠️ Potential issue | 🟡 Minor

Use ValueError instead of assert for the k-major guard.

Line 2115 is validating caller input. Under python -O, that assert disappears and invalid layouts can slip through to cuDNN.

🛠️ Safer runtime validation
-    assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
+    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
+        raise ValueError("a and b must be k-major")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2113 - 2115, Replace the runtime
assertion that validates k-major layout with an explicit exception: instead of
using assert real_a_stride[2] == 1 and real_b_stride[1] == 1, raise a ValueError
with the same descriptive message so the check remains active under python -O;
locate the check around the call to _get_bf16_3d_shape_stride and update the
validation (referring to variables real_a_stride and real_b_stride) to raise
ValueError("a and b must be k-major") when the condition fails.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 1955-1966: The renamed override-shape helper functions must keep
their old FP4/MXFP8 symbols as deprecated aliases: add thin alias definitions
that point the old names to the new implementations (e.g., define the previous
FP4/MXFP8 function names to return/call
build_cudnn_gemm_fp4_graph_override_shape and the other two renamed helpers
referenced at the other ranges) and emit a warnings.warn(...,
DeprecationWarning) when those aliases are called so downstream code keeps
working for one release cycle while signaling the rename; ensure the alias names
exactly match the previous exported symbols and reference the new functions (use
the new function names from this file such as
build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation
message indicating the new name.
- Around line 2321-2365: The override-shape executor
execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw tensor
shapes/strides; update it to use the same normalization used by the BF16 path by
calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale, b_descale,
c_final) and use the returned normalized shape and stride values when building
override_shapes and override_strides; ensure you replace the direct
list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride
results so the MXFP8 override-shape logic is consistent with
execute_cudnn_gemm_bf16_graph_override_shape.

---

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2113-2115: Replace the runtime assertion that validates k-major
layout with an explicit exception: instead of using assert real_a_stride[2] == 1
and real_b_stride[1] == 1, raise a ValueError with the same descriptive message
so the check remains active under python -O; locate the check around the call to
_get_bf16_3d_shape_stride and update the validation (referring to variables
real_a_stride and real_b_stride) to raise ValueError("a and b must be k-major")
when the condition fails.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6aab590b-9860-40d0-825b-b2ff37180d78

📥 Commits

Reviewing files that changed from the base of the PR and between 368656f and da99fe3.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

@yanqinz2
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !470 has been created, and the CI pipeline #47213607 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47213607: 12/20 passed

yanqinz2 force-pushed the yanqinz/dynamic-shape-unified-api branch from a0ff2b2 to fb8afc3 on March 31, 2026 at 17:11
a, b, alpha, out_dtype, block_size, use_nvfp4
)

execute_cudnn_gemm_fp4_graph_override_shape(
Collaborator

@dhiraj113 dhiraj113 Mar 31, 2026


Should these functions be prefixed with _ to signify internal-only use?

Collaborator Author


Since execute_cudnn_gemm_fp4_graph can be externally used as well, I would assume execute_cudnn_gemm_fp4_graph_override_shape to be externally useable?

Collaborator


I think only things like mm_fp4 are part of the API; everything else is internal. If you look at the analogous function _cudnn_gemm_fp4, it has a _ prefix to show it is internal.

Collaborator

@dhiraj113 dhiraj113 left a comment


Could you please add an explicit description to the MR clearly describing the changes that have been made.
