
fix(moe): make hidden_states_scale optional in trtllm_fp4_block_scale_moe#2906

Open
kuttivicky wants to merge 1 commit into flashinfer-ai:main from kuttivicky:fix-bf16-moe-issue

Conversation


@kuttivicky kuttivicky commented Mar 28, 2026

hidden_states_scale was typed as Optional[torch.Tensor] but had no default value, causing a TypeError when callers omitted it with bf16 hidden_states (error 2 in issue #2657). Move the parameter to the end of both the public API and the registered op so it properly defaults to None.

Add a Python-level validation check: when hidden_states dtype is bfloat16, gemm1_weights must be in MxE2m1 (MXFP4) format — the underlying kernel only supports that combination. Previously the C++ assertion fired with an opaque error; now a clear ValueError is raised with guidance to use trtllm_bf16_moe for all-BF16 cases.

Update docstring to document the parameter's position and constraints.
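The two changes above can be sketched in miniature. This is a hedged illustration only: the simplified parameter list and dtype strings stand in for the real function in flashinfer/fused_moe/core.py, which takes many more tensor arguments.

```python
from typing import Optional


def trtllm_fp4_block_scale_moe_sketch(
    hidden_states,                  # torch.Tensor in the real API
    gemm1_weights_dtype: str,       # stands in for the deduced weight dtype
    hidden_states_dtype: str,       # stands in for hidden_states.dtype
    hidden_states_scale: Optional[object] = None,  # now trailing, defaults to None
):
    # New Python-level validation: bf16 activations require MxE2m1 (MXFP4) weights,
    # since that is the only combination the underlying kernel supports.
    if hidden_states_dtype == "bfloat16" and gemm1_weights_dtype != "MxE2m1":
        raise ValueError(
            "bfloat16 hidden_states require MxE2m1 (MXFP4) gemm1_weights; "
            "use trtllm_bf16_moe for all-BF16 cases"
        )
    return hidden_states_scale  # placeholder for the kernel dispatch
```

With the default in place, omitting hidden_states_scale for bf16 inputs no longer raises a TypeError, and the mismatch case fails early with an actionable ValueError instead of an opaque C++ assertion.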

Fixes #2657

AI-assisted

📌 Description

trtllm_fp4_block_scale_moe documented hidden_states_scale as optional but it had no default value, causing a TypeError when omitted with bf16 hidden_states. This moves the parameter to the end of both the public API and the registered op so it defaults to None, and adds a clear Python-level validation error when bf16 activations are paired with non-MxE2m1 weights.

🔍 Related Issues

Fixes #2657

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

The signature change moves hidden_states_scale to a keyword argument at the end of the parameter list (matching the pattern already used in the inner registered op). Existing callers passing it positionally will need to update, but the previous behavior was broken anyway. Happy to adjust if the team prefers a different approach.
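The migration for positional callers can be shown with toy before/after signatures. These are illustrative stand-ins; the real parameter lists in core.py are much longer.

```python
from typing import Optional


# Hypothetical "before": hidden_states_scale sat mid-list with no default.
def old_api(hidden_states, hidden_states_scale: Optional[object], gemm1_weights):
    return hidden_states_scale


# Hypothetical "after": the parameter trails with a None default.
def new_api(hidden_states, gemm1_weights, hidden_states_scale: Optional[object] = None):
    return hidden_states_scale


# Old positional call bound the scale by position...
assert old_api("hs", "scale", "w1") == "scale"
# ...so existing positional callers must switch to a keyword argument:
assert new_api("hs", "w1", hidden_states_scale="scale") == "scale"
# Omitting it now works instead of raising TypeError:
assert new_api("hs", "w1") is None
```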

Tests added in tests/moe/test_bf16_hidden_states_moe.py. The signature test runs without a GPU; the validation test requires SM100 and was not run locally due to hardware unavailability.

Summary by CodeRabbit

Release Notes

  • New Features

    • FP4 block-scale MoE operations now support optional hidden-state scaling parameters.
  • Bug Fixes

    • Added dtype compatibility validation for FP4 MoE operations with BF16 hidden states.
  • Tests

    • Added regression test suite for BF16 hidden-states MoE support.

Contributor

coderabbitai bot commented Mar 28, 2026

📝 Walkthrough

Modified MoE function signatures to make hidden_states_scale parameter optional with default None value and added dtype compatibility validation. Added regression tests for BF16 hidden states support with MoE operations.

Changes

  • Core API Updates (flashinfer/fused_moe/core.py):
    Made the hidden_states_scale parameter optional with default value None in trtllm_fp4_block_scale_moe_op and trtllm_fp4_block_scale_moe. Added a dtype compatibility check that raises ValueError when BF16 hidden states are paired with non-MxE2m1 weights.
  • Regression Tests (tests/moe/test_bf16_hidden_states_moe.py):
    New SM100-only test module with a parameter-signature regression test and a runtime dtype validation test ensuring correct error handling when BF16 hidden states are used with incompatible weight formats.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Suggested labels

ready, op: moe

Suggested reviewers

  • aleozlx
  • yzh119
  • jiahanc
  • nv-yunzheq
  • IwakuraRein

Poem

🐰 A hop, a skip, a default None,
BF16 hidden states—now they're fun!
Scales optional, validation tight,
MoE kernels running oh so right! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 57.14%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check — The title 'fix(moe): make hidden_states_scale optional in trtllm_fp4_block_scale_moe' clearly and concisely describes the main change: making hidden_states_scale optional by adding a default value.
  • Description check — The PR description adequately covers the problem statement, solution, and changes made. It includes the linked issue reference, test additions, and reviewer notes explaining the signature change.
  • Linked Issues check — The PR fully addresses issue #2657 by making hidden_states_scale optional with a default value and adding Python-level validation for BF16 hidden_states paired with non-MxE2m1 weights.
  • Out of Scope Changes check — All changes are directly scoped to issue #2657: parameter signature updates in core.py and test additions to validate the fix. No unrelated changes detected.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the trtllm_fp4_block_scale_moe operations to make the hidden_states_scale parameter optional and introduces a validation check to ensure that Bfloat16 hidden states are used exclusively with MxE2m1 weights. It also includes a new test suite to verify these changes and prevent regressions. Feedback was provided regarding a typo in the docstring for hidden_states_scale, where the divisor for mxfp4 should be 32 rather than 16.

tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192)
output (Optional[torch.Tensor]): shape [seq_len, hidden_size]
Optional inplace output tensor.
hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
Contributor


Severity: medium

The shape description for hidden_states_scale seems to have a typo regarding the scale factor divisor for mxfp4.

Based on the implementation of deduce_trtllm_gen_tensor_dtype, the scale dimension divisor is:

  • 16 for nvfp4 (DtypeTrtllmGen.E2m1)
  • 32 for mxfp4 (DtypeTrtllmGen.MxE2m1)

The docstring currently states 16 if mxfp4, which appears to be incorrect. For clarity, I suggest updating the shape description to reflect this, for example: shape [seq_len, hidden_size // D] where D is 32 for mxfp8/mxfp4 and 16 for nvfp4.
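The divisor mapping the reviewer describes can be captured in a small helper. This is a hedged sketch based on the reviewer's reading of deduce_trtllm_gen_tensor_dtype; the dictionary and function names here are illustrative, not part of the library.

```python
# Scale-dimension divisors per format, per the review comment above:
# one scale element per 16 values for nvfp4 (E2m1), per 32 for mxfp8/mxfp4.
SCALE_DIVISOR = {
    "nvfp4": 16,  # DtypeTrtllmGen.E2m1
    "mxfp4": 32,  # DtypeTrtllmGen.MxE2m1
    "mxfp8": 32,
}


def hidden_states_scale_shape(seq_len: int, hidden_size: int, fmt: str) -> tuple:
    """Expected shape of hidden_states_scale under the mapping above."""
    return (seq_len, hidden_size // SCALE_DIVISOR[fmt])
```

Under this mapping, the docstring's "16 if mxfp4" is indeed a typo; the suggested wording "D is 32 for mxfp8/mxfp4 and 16 for nvfp4" matches the table.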

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2825-2828: Remove the stale duplicate doc entry for the parameter
"hidden_states_scale" so it is only documented once (keep the correct block that
explains shape, when it's required for mxfp8/nvfp4 and when it must be None for
bfloat16/MxE2m1). Locate the earlier/duplicate "hidden_states_scale" docstring
entry (the older block near the other parameter docs) and delete it, ensuring
the remaining docstring under the function/class (referencing
"hidden_states_scale") accurately describes dtype requirements and shape.
- Line 1869: The fake op _fake_trtllm_fp4_block_scale_moe must have the same
signature as the real op trtllm_fp4_block_scale_moe_op: move hidden_states_scale
to be an optional trailing parameter with default None (so it is not a required
positional arg), and then update the routed wrapper
trtllm_fp4_block_scale_routed_moe to call the op with hidden_states_scale passed
as a keyword argument (hidden_states_scale=hidden_states_scale) instead of as
the 6th positional argument so positional binding remains correct for all
subsequent parameters.

In `@tests/moe/test_bf16_hidden_states_moe.py`:
- Around line 37-40: The skip guard in skip_if_not_sm100() is too broad because
checking cc[0] == 10 admits SM103 variants; update the function to use the repo
helper is_sm100a_supported() (or compare the full compute capability tuple from
get_compute_capability(torch.device("cuda")) against the exact SM100 tuple) so
only true SM100 (Blackwell) GPUs run the test; locate the skip_if_not_sm100
function and replace the cc[0] check with a call to is_sm100a_supported() (or a
full-tuple equality check) and call pytest.skip when that helper returns False.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f9cbeac8-6393-469a-a389-ec6399a09b32

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and fa19699.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/core.py
  • tests/moe/test_bf16_hidden_states_moe.py

activation_type: int = ActivationType.Swiglu.value,
output: Optional[torch.Tensor] = None,
tune_max_num_tokens: int = 8192,
hidden_states_scale: Optional[torch.Tensor] = None,
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
import ast
from pathlib import Path

path = Path("flashinfer/fused_moe/core.py")
tree = ast.parse(path.read_text())

def find_fn(name):
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == name:
            return node
    raise RuntimeError(f"{name} not found")

def dump_signature(node):
    args = [arg.arg for arg in node.args.args]
    defaults = ["<required>"] * (len(args) - len(node.args.defaults)) + [
        ast.unparse(default) for default in node.args.defaults
    ]
    print(f"\n{node.name} @ line {node.lineno}")
    for arg, default in zip(args, defaults):
        print(f"  {arg} = {default}")

dump_signature(find_fn("trtllm_fp4_block_scale_moe_op"))
dump_signature(find_fn("_fake_trtllm_fp4_block_scale_moe"))

routed = find_fn("trtllm_fp4_block_scale_routed_moe")
for call in ast.walk(routed):
    if isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute) and call.func.attr == "trtllm_fp4_block_scale_moe":
        print("\ntrtllm_fp4_block_scale_routed_moe -> trtllm_fp4_block_scale_moe")
        for idx, arg in enumerate(call.args[:8], 1):
            print(f"  positional arg {idx}: {ast.unparse(arg)}")
        print("  keywords:", [kw.arg for kw in call.keywords])
        break
PY

Repository: flashinfer-ai/flashinfer

Length of output: 2421


Fix signature mismatch between real op, fake op, and routed wrapper.

The real op trtllm_fp4_block_scale_moe_op has hidden_states_scale as a trailing optional parameter (default None), but _fake_trtllm_fp4_block_scale_moe still declares it at position 6 (right after hidden_states). The routed wrapper trtllm_fp4_block_scale_routed_moe passes it positionally at position 6, which will cause every subsequent argument to be mis-bound when called through the real op. Per register_fake_op requirements, the fake op signature must exactly mirror the real op.

Update _fake_trtllm_fp4_block_scale_moe to move hidden_states_scale to the trailing position (with None default), and change the routed wrapper's call to use a keyword argument for hidden_states_scale.

Signature verification output
trtllm_fp4_block_scale_moe_op @ line 1837
  ...
  hidden_states_scale = None

_fake_trtllm_fp4_block_scale_moe @ line 2030
  ...
  hidden_states = <required>
  hidden_states_scale = <required>
  gemm1_weights = <required>
  ...

trtllm_fp4_block_scale_routed_moe call:
  positional arg 6: hidden_states_scale
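The mis-binding risk can be demonstrated with a toy op. These functions are stand-ins for the real trtllm_fp4_block_scale_moe_op and its registered fake; the point is how a positional scale argument shifts every later parameter.

```python
from typing import Optional


# Toy version of the real op after the fix: hidden_states_scale trails.
def real_op(hidden_states, gemm1_weights, gemm2_weights,
            hidden_states_scale: Optional[object] = None):
    return {"w1": gemm1_weights, "w2": gemm2_weights, "scale": hidden_states_scale}


# A wrapper still passing hidden_states_scale positionally (old position 2 here)
# shifts every subsequent argument by one slot:
shifted = real_op("hs", "scale", "w1")
assert shifted["w1"] == "scale"   # the scale landed in gemm1_weights: wrong binding

# Passing it as a keyword keeps positional binding correct:
ok = real_op("hs", "w1", "w2", hidden_states_scale="scale")
assert ok["w1"] == "w1" and ok["scale"] == "scale"
```

This is why the fake op's signature must mirror the real op exactly, and why the routed wrapper should switch to hidden_states_scale=... as a keyword.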

Comment on lines +2825 to +2828
hidden_states_scale (Optional[torch.Tensor]): shape [seq_len, hidden_size // (32 if mxfp8, 16 if mxfp4)]
Scale tensor for mxfp8 / nvfp4 hidden states. Required when hidden_states dtype is
mxfp8 or nvfp4 (uint8). Must be None when hidden_states dtype is bfloat16, in which
case gemm1_weights must be MxE2m1 (MXFP4) format.
Contributor


⚠️ Potential issue | 🟡 Minor

Remove the stale earlier hidden_states_scale doc entry.

The older hidden_states_scale block around Line 2775 is still present, so this parameter is now documented twice with two different positions.


Comment on lines +37 to +40
def skip_if_not_sm100():
    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] not in [10]:
        pytest.skip("trtllm_fp4_block_scale_moe requires SM100 (Blackwell) GPU")
Contributor


⚠️ Potential issue | 🟡 Minor

Tighten the SM100 skip guard.

cc[0] == 10 also admits SM103, so this can run the regression on unsupported Blackwell variants. Please use the repo helper or compare the full capability tuple instead.

As per coding guidelines "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures".
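The over-broad check can be shown with a full-tuple comparison. This is a hedged sketch: in the repo the preferred fix is the flashinfer.utils helper is_sm100a_supported(); the stand-in below only illustrates why the major-version check admits SM103, assuming get_compute_capability returns a (major, minor) tuple.

```python
def is_exact_sm100(cc: tuple) -> bool:
    """Return True only for true SM100 (Blackwell) devices."""
    # cc[0] == 10 alone also matches SM103 variants such as (10, 3);
    # comparing the full (major, minor) tuple excludes them.
    return cc == (10, 0)


def skip_reason(cc: tuple):
    # In the real test this branch would call pytest.skip(...) instead.
    if not is_exact_sm100(cc):
        return "trtllm_fp4_block_scale_moe requires SM100 (Blackwell) GPU"
    return None
```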



Development

Successfully merging this pull request may close these issues.

BF16 hidden_states for trtllm_fp4_block_scale_moe
