Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings. Use the following commands to manage reviews, or the checkboxes below for quick actions:
📝 Walkthrough

Renamed the cuDNN GEMM override-shape FP4/MXFP8 APIs and added an optional policy parameter to the graph builders.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Runner as Runner
    participant Builder as cudnn Graph Builder
    participant Graph as cudnn Graph
    participant GPU as GPU / Workspace
    participant Tensors as Prepared Tensors
    Runner->>Builder: build_cudnn_gemm_*_graph_override_shape(params, policy?)
    Builder->>Graph: create graph object
    Builder->>Graph: build_plans(policy)
    Graph-->>Runner: graph (cached by m / power-of-two)
    Runner->>Tensors: expand packed FP4/MXFP8/BF16 -> compute override shapes & strides
    Runner->>GPU: query graph.get_workspace_size()
    alt workspace too small
        GPU-->>Runner: reallocate workspace_buffer
    end
    Runner->>Graph: execute_override_shape(graph, tensors, workspace_buffer)
    Graph->>GPU: run GEMM with override shapes/strides
    GPU-->>Runner: result
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Code Review
This pull request refactors the cuDNN GEMM implementation to provide consistent dynamic-shape (override-shape) support across FP4, MXFP8, and BF16 data types. Key changes include renaming functions for better naming consistency, adding a policy parameter to graph builders, and integrating override-shape logic into the TunableRunner classes for BF16 and FP4. The review feedback focuses on improving the efficiency of workspace buffer handling by suggesting that the code raise a ValueError for undersized buffers instead of performing local re-allocations, and recommends refactoring duplicated logic within the runner classes to improve maintainability.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/gemm/test_cudnn_override_shape.py (1)
17-27: Import the public helpers through `flashinfer.gemm` in this test.

Right now this bypasses `flashinfer.gemm`, so the test still passes even if the package re-export layer regresses. Since this PR changes that surface, the test should exercise it.

🧪 Minimal import split

```diff
-from flashinfer.gemm.gemm_base import (
-    CUDNN_AVAILABLE,
-    build_cudnn_gemm_bf16_graph_override_shape,
-    execute_cudnn_gemm_bf16_graph_override_shape,
-    build_cudnn_gemm_fp4_graph_override_shape,
-    execute_cudnn_gemm_fp4_graph_override_shape,
-    build_cudnn_gemm_mxfp8_graph_override_shape,
-    execute_cudnn_gemm_mxfp8_graph_override_shape,
-    is_cudnn_override_shape_available,
-    _calculate_block_scale_dims,
-)
+from flashinfer.gemm import (
+    build_cudnn_gemm_bf16_graph_override_shape,
+    execute_cudnn_gemm_bf16_graph_override_shape,
+    build_cudnn_gemm_fp4_graph_override_shape,
+    execute_cudnn_gemm_fp4_graph_override_shape,
+    build_cudnn_gemm_mxfp8_graph_override_shape,
+    execute_cudnn_gemm_mxfp8_graph_override_shape,
+    is_cudnn_override_shape_available,
+)
+from flashinfer.gemm.gemm_base import CUDNN_AVAILABLE, _calculate_block_scale_dims
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_cudnn_override_shape.py` around lines 17 - 27, The test directly imports helpers from flashinfer.gemm.gemm_base instead of exercising the package re-export layer; update the import to import the public helpers from flashinfer.gemm (e.g. import CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape, execute_cudnn_gemm_bf16_graph_override_shape, build_cudnn_gemm_fp4_graph_override_shape, execute_cudnn_gemm_fp4_graph_override_shape, build_cudnn_gemm_mxfp8_graph_override_shape, execute_cudnn_gemm_mxfp8_graph_override_shape, is_cudnn_override_shape_available, _calculate_block_scale_dims) via from flashinfer.gemm import <symbols> so the test fails if the package re-export surface regresses.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 25-28: Restore backward-compatible aliases for the renamed
override-shape exports by reintroducing the old names as simple assignments to
the new symbols: for example, set the previous FP4/MXFP8 export names equal to
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape, and
execute_cudnn_gemm_mxfp8_graph_override_shape in flashinfer.gemm.__init__.py so
old imports continue to work; also add the same alias assignments in the
flashinfer.gemm.gemm_base module if that path is a supported public import so
both import surfaces mirror each other.
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2121-2124: Replace the runtime assertion with explicit input
validation that raises a ValueError: instead of using "assert real_a_stride[2]
== 1 and real_b_stride[1] == 1", check those conditions and raise ValueError
with a clear message (e.g., "a and b must be k-major") so invalid
caller-provided strides cannot slip through when Python assertions are disabled;
update the code around _get_bf16_3d_shape_stride and the
real_a_stride/real_b_stride checks accordingly.
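The requested guard can be written as a small standalone check. This is a sketch with a hypothetical helper name; `real_a_stride`/`real_b_stride` are the 3D strides computed by the surrounding code in `gemm_base.py`.

```python
def check_k_major(real_a_stride, real_b_stride):
    """Reject non-k-major layouts; unlike assert, this stays active under python -O."""
    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
        raise ValueError("a and b must be k-major")


# (batch, m, k) strides of a k-major `a` and a k-major `b`: passes silently.
check_k_major((32, 8, 1), (32, 1, 4))
```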
---
Nitpick comments:
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 17-27: The test directly imports helpers from
flashinfer.gemm.gemm_base instead of exercising the package re-export layer;
update the import to import the public helpers from flashinfer.gemm (e.g. import
CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape,
is_cudnn_override_shape_available, _calculate_block_scale_dims) via from
flashinfer.gemm import <symbols> so the test fails if the package re-export
surface regresses.
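The regression this nitpick wants the test to catch can be modeled without flashinfer installed. A minimal sketch, with a hypothetical `missing_reexports` helper and a fake module standing in for `flashinfer.gemm`:

```python
import types


def missing_reexports(pkg, names):
    """Return the public names that `pkg` fails to re-export."""
    return [name for name in names if not hasattr(pkg, name)]


# Model flashinfer.gemm as a module that forgot one re-export.
fake_pkg = types.ModuleType("fake_gemm")
fake_pkg.build_cudnn_gemm_fp4_graph_override_shape = object()

gap = missing_reexports(fake_pkg, [
    "build_cudnn_gemm_fp4_graph_override_shape",
    "execute_cudnn_gemm_fp4_graph_override_shape",
])
```

A test that imports through the package surface fails exactly when `gap` here would be non-empty, which is the point of the suggested import split.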
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 463f5a46-a140-49a0-8802-735d5fba51af
📒 Files selected for processing (3)

- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/gemm/gemm_base.py (2)
1955-1966: ⚠️ Potential issue | 🟠 Major

Keep deprecated aliases for the renamed override-shape helpers.

These functions are part of the exported `flashinfer.gemm` surface, so renaming them without keeping the previous symbols is a hard import break for downstream users. Please keep the old FP4/MXFP8 names as aliases for at least one release cycle, or explicitly treat this PR as a breaking API change.

Also applies to: 2090-2100, 2209-2219, 2321-2330
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 1955 - 1966, The renamed override-shape helper functions must keep their old FP4/MXFP8 symbols as deprecated aliases: add thin alias definitions that point the old names to the new implementations (e.g., define the previous FP4/MXFP8 function names to return/call build_cudnn_gemm_fp4_graph_override_shape and the other two renamed helpers referenced at the other ranges) and emit a warnings.warn(..., DeprecationWarning) when those aliases are called so downstream code keeps working for one release cycle while signaling the rename; ensure the alias names exactly match the previous exported symbols and reference the new functions (use the new function names from this file such as build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation message indicating the new name.
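One way to implement the suggested deprecated aliases, sketched below. The old export name shown is hypothetical, since the pre-rename symbols are not quoted in this thread, and the builder body is a placeholder.

```python
import functools
import warnings


def _deprecated_alias(new_fn, old_name):
    """Wrap new_fn so calls through the old name emit a DeprecationWarning."""

    @functools.wraps(new_fn)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_fn.__name__} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)

    return wrapper


def build_cudnn_gemm_fp4_graph_override_shape(*args, **kwargs):
    return "graph"  # placeholder body standing in for the real builder


# Hypothetical old export name; the pre-rename symbol is not quoted in this thread.
build_cudnn_gemm_fp4_dynamic_graph = _deprecated_alias(
    build_cudnn_gemm_fp4_graph_override_shape,
    "build_cudnn_gemm_fp4_dynamic_graph",
)
```

Old imports keep working for a release cycle while the warning signals the rename.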
2321-2365: ⚠️ Potential issue | 🟠 Major

Apply shape normalization to the MXFP8 override-shape executor for consistency.

`execute_cudnn_gemm_mxfp8_graph_override_shape()` forwards raw tensor shapes/strides, whereas `execute_cudnn_gemm_bf16_graph_override_shape()` uses `_get_bf16_3d_shape_stride()` to normalize 2D inputs to 3D. Even though MXFP8 currently enforces 3D inputs, the inconsistency should be resolved by using the same normalization pattern:

Suggested normalization
```diff
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
     a,
     b,
@@
 ):
     """Execute MXFP8 GEMM cuDNN graph with dynamic-shape overrides."""
+    a_shape, a_stride = _get_bf16_3d_shape_stride(a)
+    b_shape, b_stride = _get_bf16_3d_shape_stride(b)
+    batch = a_shape[0]
+    a_descale_shape, a_descale_stride = _expand_block_scale_tensor_shape(
+        a_descale, batch
+    )
+    b_descale_shape, b_descale_stride = _expand_block_scale_tensor_shape(
+        b_descale, batch
+    )
+    c_shape, c_stride = _get_bf16_3d_shape_stride(c_final)
+
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
         UIDs.BLOCK_DESCALE_A_UID.value: a_descale,
@@
     override_shapes = [
-        list(a.shape),
-        list(b.shape),
-        list(a_descale.shape),
-        list(b_descale.shape),
-        list(c_final.shape),
+        list(a_shape),
+        list(b_shape),
+        list(a_descale_shape),
+        list(b_descale_shape),
+        list(c_shape),
     ]
     override_strides = [
-        list(a.stride()),
-        list(b.stride()),
-        list(a_descale.stride()),
-        list(b_descale.stride()),
-        list(c_final.stride()),
+        list(a_stride),
+        list(b_stride),
+        list(a_descale_stride),
+        list(b_descale_stride),
+        list(c_stride),
     ]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2321 - 2365, The override-shape executor execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw tensor shapes/strides; update it to use the same normalization used by the BF16 path by calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale, b_descale, c_final) and use the returned normalized shape and stride values when building override_shapes and override_strides; ensure you replace the direct list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride results so the MXFP8 override-shape logic is consistent with execute_cudnn_gemm_bf16_graph_override_shape.
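The 2D-to-3D normalization this comment relies on can be sketched without torch. This is an assumption of what `_get_bf16_3d_shape_stride` does based on this thread; the real helper takes a torch tensor, and the singleton batch stride shown follows torch's contiguous convention.

```python
def to_3d_shape_stride(shape, stride):
    """Normalize a 2D (m, k) shape/stride pair to batched 3D (1, m, k).

    The singleton batch dimension gets the full 2D extent as its stride;
    3D inputs pass through unchanged.
    """
    if len(shape) == 3:
        return list(shape), list(stride)
    if len(shape) != 2:
        raise ValueError("expected a 2D or 3D tensor")
    m, k = shape
    return [1, m, k], [m * k, stride[0], stride[1]]
```

Running every input through one normalizer is what keeps the BF16 and MXFP8 override-shape executors consistent.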
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
2113-2115: ⚠️ Potential issue | 🟡 Minor

Use `ValueError` instead of `assert` for the k-major guard.

Line 2115 is validating caller input. Under `python -O`, that `assert` disappears and invalid layouts can slip through to cuDNN.

🛠️ Safer runtime validation
```diff
-    assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
+    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
+        raise ValueError("a and b must be k-major")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2113 - 2115, Replace the runtime assertion that validates k-major layout with an explicit exception: instead of using assert real_a_stride[2] == 1 and real_b_stride[1] == 1, raise a ValueError with the same descriptive message so the check remains active under python -O; locate the check around the call to _get_bf16_3d_shape_stride and update the validation (referring to variables real_a_stride and real_b_stride) to raise ValueError("a and b must be k-major") when the condition fails.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 1955-1966: The renamed override-shape helper functions must keep
their old FP4/MXFP8 symbols as deprecated aliases: add thin alias definitions
that point the old names to the new implementations (e.g., define the previous
FP4/MXFP8 function names to return/call
build_cudnn_gemm_fp4_graph_override_shape and the other two renamed helpers
referenced at the other ranges) and emit a warnings.warn(...,
DeprecationWarning) when those aliases are called so downstream code keeps
working for one release cycle while signaling the rename; ensure the alias names
exactly match the previous exported symbols and reference the new functions (use
the new function names from this file such as
build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation
message indicating the new name.
- Around line 2321-2365: The override-shape executor
execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw tensor
shapes/strides; update it to use the same normalization used by the BF16 path by
calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale, b_descale,
c_final) and use the returned normalized shape and stride values when building
override_shapes and override_strides; ensure you replace the direct
list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride
results so the MXFP8 override-shape logic is consistent with
execute_cudnn_gemm_bf16_graph_override_shape.
---
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2113-2115: Replace the runtime assertion that validates k-major
layout with an explicit exception: instead of using assert real_a_stride[2] == 1
and real_b_stride[1] == 1, raise a ValueError with the same descriptive message
so the check remains active under python -O; locate the check around the call to
_get_bf16_3d_shape_stride and update the validation (referring to variables
real_a_stride and real_b_stride) to raise ValueError("a and b must be k-major")
when the condition fails.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6aab590b-9860-40d0-825b-b2ff37180d78
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
/bot run
[FAILED] Pipeline #47213607: 12/20 passed
Force-pushed a0ff2b2 to fb8afc3
```
a, b, alpha, out_dtype, block_size, use_nvfp4
)

execute_cudnn_gemm_fp4_graph_override_shape(
```
Should these functions be prefixed with `_` to signify internal-only use?
Since `execute_cudnn_gemm_fp4_graph` can be used externally as well, I would assume `execute_cudnn_gemm_fp4_graph_override_shape` to be externally usable too?
I think only things like `mm_fp4` are part of the API; everything else is internal. If you look at the analogous function `_cudnn_gemm_fp4`, it has a `_` prefix to show it is internal.
dhiraj113
left a comment
Could you please add an explicit description to the MR clearly describing the changes that have been made.
📌 Description
Add cudnn override shape support for bf16 and fp4 gemm.
The current bf16/nvfp4 gemm cuDNN backend uses a static-shape approach, which rebuilds the execution plan every time a new shape is encountered, introducing significant compilation overhead and repeated cached graphs.
This PR introduces override-shape support and makes it the default option for the cuDNN backend: the pre-built cuDNN graph and execution plans from the autotuning phase are reused for subsequent problem sizes, avoiding repeated graph/execution-plan caching and compilation overhead.
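The reuse strategy described above (one pre-built graph per power-of-two bucket of m, as the walkthrough diagram notes) can be sketched as follows. Names are illustrative; the real builder and cache live inside the TunableRunner classes.

```python
def _next_pow2(n):
    """Round n up to the next power of two (1 maps to 1)."""
    return 1 << (n - 1).bit_length() if n > 1 else 1


_graph_cache = {}


def get_override_shape_graph(m, n, k, build_fn):
    """Fetch (or build once) the graph for this problem's power-of-two bucket.

    Distinct m values that round to the same power of two share one pre-built
    graph, so only the first shape in a bucket pays the compilation cost; the
    actual m is passed later via override shapes/strides at execution time.
    """
    key = (_next_pow2(m), n, k)
    if key not in _graph_cache:
        _graph_cache[key] = build_fn(*key)
    return _graph_cache[key]
```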
🔍 Related Issues
https://nvbugspro.nvidia.com/bug/5539146
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Added or updated tests as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Refactor
New Features
Bug Fixes
Tests