Skip to content

misc: Add runtime validation for plan/run consistency in BatchMLAPagedAttentionWrapper#2246

Closed
bkryu wants to merge 1 commit intoflashinfer-ai:mainfrom
bkryu:mla_batch_size_check
Closed

misc: Add runtime validation for plan/run consistency in BatchMLAPagedAttentionWrapper#2246
bkryu wants to merge 1 commit intoflashinfer-ai:mainfrom
bkryu:mla_batch_size_check

Conversation

@bkryu
Copy link
Collaborator

@bkryu bkryu commented Dec 19, 2025

📌 Description

Adds a check in run() to validate that q_nope.shape[0] and q_pe.shape[0] match qo_indptr[-1] from plan(). Previously, a mismatch between plan and run configurations could cause CUDA illegal memory access errors, making debugging difficult. This change converts such errors into a clear ValueError with an actionable message.

The check applies to fa2/fa3 backends only, as the cutlass backend doesn't use plan().

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Enhanced validation in multi-head latency-attention operations to detect inconsistencies between plan-time configuration and runtime input parameters. The system now provides explicit and descriptive error messages when parameter mismatches are detected, helping users identify and resolve configuration issues more efficiently. Validation applies selectively to non-cutlass backends while cutlass backend behavior remains unchanged.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Dec 19, 2025

Walkthrough

This change adds plan-time validation to the MLA module to ensure consistency between declared and actual query counts. The plan() method now stores the total query count from index pointer data, while run() validates that actual input shapes match the plan-time count for non-cutlass backends.

Changes

Cohort / File(s) Summary
Query Count Validation
flashinfer/mla.py
plan() now initializes self._total_num_queries from the last value of qo_indptr_host to track total query count. run() adds validation for non-cutlass backends to verify consistency between plan-time and runtime query counts; raises ValueError on mismatch.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Straightforward validation logic with clear error messaging
  • Single file modification with localized scope
  • May want to verify the validation logic correctly handles edge cases for index pointer calculations

Suggested reviewers

  • yzh119
  • cyx-6

Poem

🐰 A rabbit hops through flashinfer's plan,
Counting queries with a careful scan—
Plan and run now dance in sync,
No more mismatches at the brink!
Validation checks keep logic sound,
Where consistency and safety are found. 🌟

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)
Check name Status Explanation Resolution
Title check ❓ Inconclusive The title is related to the changeset but uses overly broad 'misc' prefix that obscures the specific improvement; the actual change is a runtime validation enhancement, not miscellaneous work. Consider revising the title to remove the 'misc:' prefix and be more specific, such as 'Add runtime validation for plan/run consistency in BatchMLAPagedAttentionWrapper'.
✅ Passed checks (2 passed)
Check name Status Explanation
Description check ✅ Passed The PR description adequately covers the changes, rationale, and scope; however, the Tests section checklist is not marked complete despite being part of the template requirements.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a301a1 and 5164e08.

📒 Files selected for processing (1)
  • flashinfer/mla.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
flashinfer/mla.py

443-447: Avoid specifying long messages outside the exception class

(TRY003)


449-453: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/mla.py (2)

349-350: LGTM! Clear validation setup.

Storing the total query count from qo_indptr_host[-1] is the correct approach for later validation. The comment clearly explains the purpose, and converting to a Python int with .item() is appropriate for the comparison in run().


439-454: Effective input validation prevents plan/run consistency issues.

The validation logic correctly ensures that q_nope and q_pe shapes match _total_num_queries for non-cutlass backends, which guards against the mismatches that can cause CUDA illegal memory access errors. The backend check properly skips validation for cutlass, which doesn't use the plan/run pattern. Error messages are clear and help developers quickly identify configuration mismatches. The TRY003 hint about long exception messages is minor—keeping them inline aids debugging in practice.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the BatchMLAPagedAttentionWrapper by introducing a critical runtime validation mechanism. It ensures consistency between the planning and execution phases, specifically verifying that the number of query tokens remains unchanged. This proactive check aims to prevent obscure CUDA memory errors, replacing them with explicit and informative ValueError messages, thereby significantly improving the debugging experience and overall stability of the system for fa2/fa3 backends.

Highlights

  • Runtime Validation: Implemented a new runtime validation check within the run() method of BatchMLAPagedAttentionWrapper.
  • Consistency Check: Validates that the number of query tokens (q_nope.shape[0] and q_pe.shape[0]) provided to run() matches the total query count established during the plan() phase (qo_indptr[-1]).
  • Error Handling Improvement: Converts potential CUDA illegal memory access errors, which were previously difficult to debug, into clear ValueError exceptions with actionable messages.
  • Backend Specificity: This validation is specifically applied to fa2/fa3 backends, as the cutlass backend does not utilize the plan() method.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@bkryu bkryu requested a review from wenscarl December 19, 2025 17:33
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds a runtime validation check in BatchMLAPagedAttentionWrapper to ensure consistency between the plan() and run() methods. Specifically, it verifies that the number of query tokens in q_nope and q_pe tensors passed to run() matches the total number of queries derived from qo_indptr in plan(). This is a valuable addition as it converts potential cryptic CUDA memory errors into clear ValueError exceptions, improving debuggability. The implementation is correct. I've suggested a small refactoring to combine the validation checks, which reduces code duplication and provides a more comprehensive error message.

Comment on lines +442 to +453
if q_nope.shape[0] != self._total_num_queries:
raise ValueError(
f"q_nope.shape[0] ({q_nope.shape[0]}) does not match "
f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
f"The total number of query tokens must be consistent between plan() and run()."
)
if q_pe.shape[0] != self._total_num_queries:
raise ValueError(
f"q_pe.shape[0] ({q_pe.shape[0]}) does not match "
f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
f"The total number of query tokens must be consistent between plan() and run()."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These two checks for q_nope and q_pe can be combined to reduce code duplication. A combined check also provides a more comprehensive error message if both tensors have incorrect shapes, which improves the debugging experience.

            if q_nope.shape[0] != self._total_num_queries or q_pe.shape[0] != self._total_num_queries:
                raise ValueError(
                    f"Total number of query tokens mismatch. Expected {self._total_num_queries} from plan(), "
                    f"but got q_nope={q_nope.shape[0]} and q_pe={q_pe.shape[0]} in run()."
                )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This nit comment is valid.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether to assert q_nope.shape[0] >= self._total_num_queries orq_nope.shape[0] != self._total_num_queries because framework possibly does padding to query. Here is a valid example from EAGLE run where the actual batch 87, but seq_lens is padded to 90.

[INIT METADATA DEBUG] init_forward_metadata called
  forward_mode = 6
  forward_mode.is_decode_or_idle() = False
  forward_mode.is_draft_extend() = True
  forward_mode.is_target_verify() = False
  batch_size = 87
  seq_lens.shape = torch.Size([87])
  spec_info type = EagleDraftInput
  forward_batch.seq_lens = tensor([1275, 1305, 1302, 1331, 1293, 1277, 1288, 1301, 1290, 1313, 1294, 1287,
        1275, 1290, 1272, 1307, 1292, 1284, 1306, 1286, 1272, 1317, 1277, 1281,
        1332, 1287, 1353, 1309, 1329, 1319, 1276, 1274, 1272, 1267, 1285, 1374,
        1365, 1331, 1310, 1297, 1297, 1269, 1291, 1303, 1282, 1365, 1279, 1278,
        1279, 1318, 1336, 1294, 1296, 1299, 1324, 1298, 1301, 1348, 1268, 1327,
        1298, 1312, 1275, 1288, 1281, 1294, 1317, 1287, 1307, 1272, 1295, 1283,
        1293, 1325, 1323, 1295, 1297, 1330, 1297, 1306, 1312, 1286, 1277, 1291,
        1277, 1301, 1279, 1306, 1301,    1], device='cuda:0')
  Total q tokens = 180

For the bug in #2236, it's clearly q_nope.shape[0] < qo_indptr[-1] to cause the IMA.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the q_nope.shape[0] must equal self._total_num_queries assumption is not true, I am thinking that this PR is maybe not necessary and am leaning towards closing the PR without merging.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check could be handled at the framework level, and closing this PR looks reasonable to me.

@bkryu bkryu self-assigned this Dec 19, 2025
@bkryu
Copy link
Collaborator Author

bkryu commented Dec 19, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !206 has been created, and the CI pipeline #40510455 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[CANCELING] Pipeline #40510455: canceled

self._sm_scale = sm_scale
self._use_profiler = use_profiler
# Store total query count for validation in run()
self._total_num_queries = int(qo_indptr_host[-1].item())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some concern if qo_indptr will be a device tensor in the future (because [-1].item() will call a cuda sync in that case). But considering we expect user to pass a host-side tensor at the moment I'm good with this.

Copy link
Collaborator Author

@bkryu bkryu Dec 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially thought the same, but a few lines above, we have:

        qo_indptr_host = qo_indptr.to("cpu")
        kv_indptr_host = kv_indptr.to("cpu")
        kv_len_arr_host = kv_len_arr.to("cpu")

which means that the plan() function already requires a CUDA sync and is therefore cannot be placed in CUDA graph, so I figured it is okay here

Comment on lines +442 to +453
if q_nope.shape[0] != self._total_num_queries:
raise ValueError(
f"q_nope.shape[0] ({q_nope.shape[0]}) does not match "
f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
f"The total number of query tokens must be consistent between plan() and run()."
)
if q_pe.shape[0] != self._total_num_queries:
raise ValueError(
f"q_pe.shape[0] ({q_pe.shape[0]}) does not match "
f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
f"The total number of query tokens must be consistent between plan() and run()."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed

The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
"""

if self._backend != "cutlass":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe check for cutlass backend also.

Comment on lines +442 to +453
if q_nope.shape[0] != self._total_num_queries:
raise ValueError(
f"q_nope.shape[0] ({q_nope.shape[0]}) does not match "
f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
f"The total number of query tokens must be consistent between plan() and run()."
)
if q_pe.shape[0] != self._total_num_queries:
raise ValueError(
f"q_pe.shape[0] ({q_pe.shape[0]}) does not match "
f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
f"The total number of query tokens must be consistent between plan() and run()."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether to assert q_nope.shape[0] >= self._total_num_queries orq_nope.shape[0] != self._total_num_queries because framework possibly does padding to query. Here is a valid example from EAGLE run where the actual batch 87, but seq_lens is padded to 90.

[INIT METADATA DEBUG] init_forward_metadata called
  forward_mode = 6
  forward_mode.is_decode_or_idle() = False
  forward_mode.is_draft_extend() = True
  forward_mode.is_target_verify() = False
  batch_size = 87
  seq_lens.shape = torch.Size([87])
  spec_info type = EagleDraftInput
  forward_batch.seq_lens = tensor([1275, 1305, 1302, 1331, 1293, 1277, 1288, 1301, 1290, 1313, 1294, 1287,
        1275, 1290, 1272, 1307, 1292, 1284, 1306, 1286, 1272, 1317, 1277, 1281,
        1332, 1287, 1353, 1309, 1329, 1319, 1276, 1274, 1272, 1267, 1285, 1374,
        1365, 1331, 1310, 1297, 1297, 1269, 1291, 1303, 1282, 1365, 1279, 1278,
        1279, 1318, 1336, 1294, 1296, 1299, 1324, 1298, 1301, 1348, 1268, 1327,
        1298, 1312, 1275, 1288, 1281, 1294, 1317, 1287, 1307, 1272, 1295, 1283,
        1293, 1325, 1323, 1295, 1297, 1330, 1297, 1306, 1312, 1286, 1277, 1291,
        1277, 1301, 1279, 1306, 1301,    1], device='cuda:0')
  Total q tokens = 180

For the bug in #2236, it's clearly q_nope.shape[0] < qo_indptr[-1] to cause the IMA.

@bkryu
Copy link
Collaborator Author

bkryu commented Dec 22, 2025

Closing the PR as the check can and might be best handled by the framework.

@bkryu bkryu closed this Dec 22, 2025
@bkryu bkryu deleted the mla_batch_size_check branch December 30, 2025 18:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants