fix(moe): make hidden_states_scale optional in trtllm_fp4_block_scale_moe#2906
kuttivicky wants to merge 1 commit into flashinfer-ai:main
Conversation
`hidden_states_scale` was typed as `Optional[torch.Tensor]` but had no default value, causing a `TypeError` when callers omitted it with bf16 hidden_states (issue flashinfer-ai#2657, error 2). Move the parameter to the end of both the public API and the registered op so it properly defaults to None.

Add a Python-level validation check: when hidden_states dtype is bfloat16, gemm1_weights must be in MxE2m1 (MXFP4) format, since the underlying kernel only supports that combination. Previously the C++ assertion fired with an opaque error; now a clear ValueError is raised with guidance to use trtllm_bf16_moe for all-BF16 cases.

Update the docstring to document the parameter's position and constraints.

Fixes flashinfer-ai#2657

AI-assisted
📝 Walkthrough

Modified MoE function signatures to make `hidden_states_scale` an optional trailing parameter with a None default, and added validation for bfloat16 hidden states.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request updates the trtllm_fp4_block_scale_moe operations to make the hidden_states_scale parameter optional and introduces a validation check to ensure that Bfloat16 hidden states are used exclusively with MxE2m1 weights. It also includes a new test suite to verify these changes and prevent regressions. Feedback was provided regarding a typo in the docstring for hidden_states_scale, where the divisor for mxfp4 should be 32 rather than 16.
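The validation described above can be sketched as a small standalone helper. This is an illustrative sketch, not the actual flashinfer implementation: the function name and the string-typed dtype arguments are assumptions made for clarity.

```python
def check_bf16_hidden_states(hidden_states_dtype: str,
                             gemm1_weights_dtype: str,
                             hidden_states_scale) -> None:
    """Hypothetical sketch of the Python-level check: bf16 activations
    are only valid with MxE2m1 (MXFP4) weights, and then no scale
    tensor may be supplied."""
    if hidden_states_dtype != "bfloat16":
        return
    if gemm1_weights_dtype != "MxE2m1":
        raise ValueError(
            "bfloat16 hidden_states require gemm1_weights in MxE2m1 "
            "(MXFP4) format; use trtllm_bf16_moe for all-BF16 MoE."
        )
    if hidden_states_scale is not None:
        raise ValueError(
            "hidden_states_scale must be None when hidden_states is bfloat16."
        )
```

Raising a ValueError at the Python layer surfaces the constraint before the opaque C++ assertion can fire.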
```
tune_max_num_tokens (int): Maximum number of tokens for tuning. (default: 8192)
output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
    Optional inplace output tensor.
hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
```
The shape description for hidden_states_scale seems to have a typo regarding the scale factor divisor for mxfp4.
Based on the implementation of deduce_trtllm_gen_tensor_dtype, the scale dimension divisor is:
- 16 for nvfp4 (`DtypeTrtllmGen.E2m1`)
- 32 for mxfp4 (`DtypeTrtllmGen.MxE2m1`)
The docstring currently states 16 if mxfp4, which appears to be incorrect. For clarity, I suggest updating the shape description to reflect this, for example: shape [seq_len, hidden_size // D] where D is 32 for mxfp8/mxfp4 and 16 for nvfp4.
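The divisor logic the reviewer describes amounts to a small lookup. This standalone sketch is illustrative only; the names are not the actual `deduce_trtllm_gen_tensor_dtype` code, and the divisor values come from the reviewer's reading of it.

```python
# Block-scale divisor per quantization format, per the review comment:
# 16 for nvfp4, 32 for mxfp4 and mxfp8.
SCALE_DIVISORS = {"nvfp4": 16, "mxfp4": 32, "mxfp8": 32}

def hidden_states_scale_shape(seq_len: int, hidden_size: int, fmt: str):
    """Expected shape of the per-block scale tensor for a given format."""
    return (seq_len, hidden_size // SCALE_DIVISORS[fmt])
```

For example, with hidden_size 4096, nvfp4 yields 256 scale columns while mxfp4 and mxfp8 yield 128, which is why "16 if mxfp4" in the docstring is the wrong divisor.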
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2825-2828: Remove the stale duplicate doc entry for the parameter
"hidden_states_scale" so it is only documented once (keep the correct block that
explains shape, when it's required for mxfp8/nvfp4 and when it must be None for
bfloat16/MxE2m1). Locate the earlier/duplicate "hidden_states_scale" docstring
entry (the older block near the other parameter docs) and delete it, ensuring
the remaining docstring under the function/class (referencing
"hidden_states_scale") accurately describes dtype requirements and shape.
- Line 1869: The fake op _fake_trtllm_fp4_block_scale_moe must have the same
signature as the real op trtllm_fp4_block_scale_moe_op: move hidden_states_scale
to be an optional trailing parameter with default None (so it is not a required
positional arg), and then update the routed wrapper
trtllm_fp4_block_scale_routed_moe to call the op with hidden_states_scale passed
as a keyword argument (hidden_states_scale=hidden_states_scale) instead of as
the 6th positional argument so positional binding remains correct for all
subsequent parameters.
In `@tests/moe/test_bf16_hidden_states_moe.py`:
- Around line 37-40: The skip guard in skip_if_not_sm100() is too broad because
checking cc[0] == 10 admits SM103 variants; update the function to use the repo
helper is_sm100a_supported() (or compare the full compute capability tuple from
get_compute_capability(torch.device("cuda")) against the exact SM100 tuple) so
only true SM100 (Blackwell) GPUs run the test; locate the skip_if_not_sm100
function and replace the cc[0] check with a call to is_sm100a_supported() (or a
full-tuple equality check) and call pytest.skip when that helper returns False.
📒 Files selected for processing (2)
- flashinfer/fused_moe/core.py
- tests/moe/test_bf16_hidden_states_moe.py
```python
    activation_type: int = ActivationType.Swiglu.value,
    output: Optional[torch.Tensor] = None,
    tune_max_num_tokens: int = 8192,
    hidden_states_scale: Optional[torch.Tensor] = None,
```
🧩 Analysis chain
🏁 Script executed:
```bash
#!/bin/bash
python - <<'PY'
import ast
from pathlib import Path

path = Path("flashinfer/fused_moe/core.py")
tree = ast.parse(path.read_text())

def find_fn(name):
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == name:
            return node
    raise RuntimeError(f"{name} not found")

def dump_signature(node):
    args = [arg.arg for arg in node.args.args]
    defaults = ["<required>"] * (len(args) - len(node.args.defaults)) + [
        ast.unparse(default) for default in node.args.defaults
    ]
    print(f"\n{node.name} @ line {node.lineno}")
    for arg, default in zip(args, defaults):
        print(f"  {arg} = {default}")

dump_signature(find_fn("trtllm_fp4_block_scale_moe_op"))
dump_signature(find_fn("_fake_trtllm_fp4_block_scale_moe"))

routed = find_fn("trtllm_fp4_block_scale_routed_moe")
for call in ast.walk(routed):
    if isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute) and call.func.attr == "trtllm_fp4_block_scale_moe":
        print("\ntrtllm_fp4_block_scale_routed_moe -> trtllm_fp4_block_scale_moe")
        for idx, arg in enumerate(call.args[:8], 1):
            print(f"  positional arg {idx}: {ast.unparse(arg)}")
        print("  keywords:", [kw.arg for kw in call.keywords])
        break
PY
```

Repository: flashinfer-ai/flashinfer
Length of output: 2421
Fix signature mismatch between real op, fake op, and routed wrapper.
The real op trtllm_fp4_block_scale_moe_op has hidden_states_scale as a trailing optional parameter (default None), but _fake_trtllm_fp4_block_scale_moe still declares it at position 6 (right after hidden_states). The routed wrapper trtllm_fp4_block_scale_routed_moe passes it positionally at position 6, which will cause every subsequent argument to be mis-bound when called through the real op. Per register_fake_op requirements, the fake op signature must exactly mirror the real op.
Update _fake_trtllm_fp4_block_scale_moe to move hidden_states_scale to the trailing position (with None default), and change the routed wrapper's call to use a keyword argument for hidden_states_scale.
Signature verification output:

```
trtllm_fp4_block_scale_moe_op @ line 1837
  ...
  hidden_states_scale = None

_fake_trtllm_fp4_block_scale_moe @ line 2030
  ...
  hidden_states = <required>
  hidden_states_scale = <required>
  gemm1_weights = <required>
  ...

trtllm_fp4_block_scale_routed_moe -> trtllm_fp4_block_scale_moe
  positional arg 6: hidden_states_scale
```
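A cheap way to guard against this class of drift between a real op and its fake op is to compare signatures directly. The sketch below uses shortened, illustrative parameter lists (not the real op's full signature) to show the mismatch and the fix:

```python
import inspect

def real_op(routing_logits, hidden_states, gemm1_weights,
            hidden_states_scale=None):
    """Stand-in for the real op: scale is a trailing optional parameter."""

def fake_op_stale(routing_logits, hidden_states, hidden_states_scale,
                  gemm1_weights):
    """Stand-in for the stale fake op: scale is still a required
    positional parameter in the middle of the list."""

def fake_op_fixed(routing_logits, hidden_states, gemm1_weights,
                  hidden_states_scale=None):
    """Mirrors the real op parameter-for-parameter."""

def mirrors(real, fake) -> bool:
    # Signature equality covers parameter names, order, kinds, and defaults.
    return inspect.signature(real) == inspect.signature(fake)
```

An assertion like `assert mirrors(real_op, fake_op_fixed)` in a signature test catches the mismatch without a GPU, which is exactly what the PR's signature test aims for.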
```
hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
    Scale tensor for mxfp8 / nvfp4 hidden states. Required when hidden_states dtype is
    mxfp8 or nvfp4 (uint8). Must be None when hidden_states dtype is bfloat16, in which
    case gemm1_weights must be MxE2m1 (MXFP4) format.
```
Remove the stale earlier hidden_states_scale doc entry.
The older hidden_states_scale block around Line 2775 is still present, so this parameter is now documented twice with two different positions.
```python
def skip_if_not_sm100():
    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] not in [10]:
        pytest.skip("trtllm_fp4_block_scale_moe requires SM100 (Blackwell) GPU")
```
Tighten the SM100 skip guard.
cc[0] == 10 also admits SM103, so this can run the regression on unsupported Blackwell variants. Please use the repo helper or compare the full capability tuple instead.
As per coding guidelines "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures".
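The tightened guard reduces to a comparison on the full capability tuple. This sketch assumes `(10, 0)` is the exact SM100 tuple; in the test itself, the repo helpers (`is_sm100a_supported()` or `get_compute_capability()`) should be used as the source of the tuple.

```python
def is_exact_sm100(cc) -> bool:
    """Compare the full (major, minor) tuple so SM103 (10, 3) and other
    major-10 variants are excluded, unlike a bare cc[0] == 10 check."""
    return tuple(cc[:2]) == (10, 0)
```

The skip guard then becomes `if not is_exact_sm100(cc): pytest.skip(...)` rather than testing only the major version.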
📌 Description

`trtllm_fp4_block_scale_moe` documented `hidden_states_scale` as optional but it had no default value, causing a TypeError when omitted with bf16 hidden_states. This moves the parameter to the end of both the public API and the registered op so it defaults to None, and adds a clear Python-level validation error when bf16 activations are paired with non-MxE2m1 weights.

🔍 Related Issues

Fixes #2657
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- [x] I have installed pre-commit by running `pip install pre-commit` (or used my preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks on all files with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed (pytest, unittest, etc.).

Reviewer Notes

The signature change moves `hidden_states_scale` to a keyword argument at the end of the parameter list (matching the pattern already used in the inner registered op). Existing callers passing it positionally will need to update, but the previous behavior was broken anyway. Happy to adjust if the team prefers a different approach.

Tests added in tests/moe/test_bf16_hidden_states_moe.py. The signature test runs without a GPU; the validation test requires SM100 and was not run locally due to hardware unavailability.
Summary by CodeRabbit

Release Notes

- Bug Fixes: `hidden_states_scale` in `trtllm_fp4_block_scale_moe` is now an optional trailing parameter defaulting to None; bfloat16 hidden states paired with non-MxE2m1 weights raise a clear ValueError instead of an opaque C++ assertion.
- Tests: added tests/moe/test_bf16_hidden_states_moe.py covering the signature change and the new validation.