misc: Add runtime validation for plan/run consistency in BatchMLAPagedAttentionWrapper #2246
bkryu wants to merge 1 commit into flashinfer-ai:main
Conversation
Walkthrough: This change adds validation to the MLA module to ensure consistency between the query count declared at plan() time and the actual query count seen at run() time.
🧰 Additional context used: 🪛 Ruff (0.14.8) on flashinfer/mla.py, lines 443-447 and 449-453: Avoid specifying long messages outside the exception class (TRY003).
Code Review
This pull request adds a runtime validation check in BatchMLAPagedAttentionWrapper to ensure consistency between the plan() and run() methods. Specifically, it verifies that the number of query tokens in q_nope and q_pe tensors passed to run() matches the total number of queries derived from qo_indptr in plan(). This is a valuable addition as it converts potential cryptic CUDA memory errors into clear ValueError exceptions, improving debuggability. The implementation is correct. I've suggested a small refactoring to combine the validation checks, which reduces code duplication and provides a more comprehensive error message.
if q_nope.shape[0] != self._total_num_queries:
    raise ValueError(
        f"q_nope.shape[0] ({q_nope.shape[0]}) does not match "
        f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
        f"The total number of query tokens must be consistent between plan() and run()."
    )
if q_pe.shape[0] != self._total_num_queries:
    raise ValueError(
        f"q_pe.shape[0] ({q_pe.shape[0]}) does not match "
        f"qo_indptr[-1] ({self._total_num_queries}) from plan(). "
        f"The total number of query tokens must be consistent between plan() and run()."
    )
These two checks for q_nope and q_pe can be combined to reduce code duplication. A combined check also provides a more comprehensive error message if both tensors have incorrect shapes, which improves the debugging experience.
if q_nope.shape[0] != self._total_num_queries or q_pe.shape[0] != self._total_num_queries:
    raise ValueError(
        f"Total number of query tokens mismatch. Expected {self._total_num_queries} from plan(), "
        f"but got q_nope={q_nope.shape[0]} and q_pe={q_pe.shape[0]} in run()."
    )
This nit comment is valid.
I am not sure whether to assert q_nope.shape[0] >= self._total_num_queries or q_nope.shape[0] != self._total_num_queries, because the framework possibly does padding to the query. Here is a valid example from an EAGLE run where the actual batch size is 87, but seq_lens is padded to 90.
[INIT METADATA DEBUG] init_forward_metadata called
forward_mode = 6
forward_mode.is_decode_or_idle() = False
forward_mode.is_draft_extend() = True
forward_mode.is_target_verify() = False
batch_size = 87
seq_lens.shape = torch.Size([87])
spec_info type = EagleDraftInput
forward_batch.seq_lens = tensor([1275, 1305, 1302, 1331, 1293, 1277, 1288, 1301, 1290, 1313, 1294, 1287,
1275, 1290, 1272, 1307, 1292, 1284, 1306, 1286, 1272, 1317, 1277, 1281,
1332, 1287, 1353, 1309, 1329, 1319, 1276, 1274, 1272, 1267, 1285, 1374,
1365, 1331, 1310, 1297, 1297, 1269, 1291, 1303, 1282, 1365, 1279, 1278,
1279, 1318, 1336, 1294, 1296, 1299, 1324, 1298, 1301, 1348, 1268, 1327,
1298, 1312, 1275, 1288, 1281, 1294, 1317, 1287, 1307, 1272, 1295, 1283,
1293, 1325, 1323, 1295, 1297, 1330, 1297, 1306, 1312, 1286, 1277, 1291,
1277, 1301, 1279, 1306, 1301, 1], device='cuda:0')
Total q tokens = 180
For the bug in #2236, it is clearly q_nope.shape[0] < qo_indptr[-1] that causes the IMA.
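The two candidate checks discussed above can be sketched as a small standalone helper. This is purely illustrative: the function name, signature, and `allow_padding` flag are hypothetical and not part of the PR; it only contrasts strict equality against the looser lower-bound check that tolerates framework-side padding.

```python
def check_query_count(num_q_rows: int, total_from_plan: int,
                      allow_padding: bool = False) -> None:
    """Hypothetical sketch of the two candidate checks discussed above.

    With allow_padding=True, a framework that pads the query tensor (as in
    the EAGLE example above) passes as long as enough rows are present;
    with allow_padding=False, strict equality with qo_indptr[-1] is required.
    Either way, the under-allocation case behind #2236 (fewer rows than
    plan() declared) is caught before it becomes an illegal memory access.
    """
    if allow_padding:
        if num_q_rows < total_from_plan:
            raise ValueError(
                f"q has {num_q_rows} rows but plan() declared "
                f"{total_from_plan} query tokens"
            )
    elif num_q_rows != total_from_plan:
        raise ValueError(
            f"q has {num_q_rows} rows but plan() declared "
            f"{total_from_plan} query tokens"
        )
```

Under `allow_padding=True`, the padded EAGLE case (90 rows for 87 planned tokens) would pass, while the #2236 under-allocation case would still raise.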
Given that the assumption that q_nope.shape[0] must equal self._total_num_queries does not hold, this PR is maybe not necessary, and I am leaning towards closing it without merging.
This check could be handled at the framework level, and closing this PR looks reasonable to me.
/bot run

[CANCELING] Pipeline #40510455: canceled
self._sm_scale = sm_scale
self._use_profiler = use_profiler
# Store total query count for validation in run()
self._total_num_queries = int(qo_indptr_host[-1].item())
I have some concern that qo_indptr may become a device tensor in the future (because [-1].item() would trigger a CUDA sync in that case). But since we currently expect the user to pass a host-side tensor, I'm good with this.
I initially thought the same, but a few lines above, we have:
qo_indptr_host = qo_indptr.to("cpu")
kv_indptr_host = kv_indptr.to("cpu")
kv_len_arr_host = kv_len_arr.to("cpu")
which means that the plan() function already requires a CUDA sync and therefore cannot be placed in a CUDA graph, so I figured it is okay here.
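For context on why `qo_indptr[-1]` is the total query count: the indptr convention here is an exclusive prefix sum over per-request query lengths. A pure-Python stand-in (no torch; the helper name and example lengths are made up for illustration):

```python
# Pure-Python stand-in for the indptr convention assumed in plan():
# qo_indptr is the exclusive prefix sum of per-request query lengths,
# so qo_indptr[-1] equals the total number of query tokens in the batch.
def build_qo_indptr(qo_lens):
    indptr = [0]
    for n in qo_lens:
        indptr.append(indptr[-1] + n)
    return indptr

qo_indptr = build_qo_indptr([3, 1, 4])  # three requests with 3, 1, 4 tokens
total_num_queries = qo_indptr[-1]       # 8
```

On a real host-side tensor, the same value is what `int(qo_indptr_host[-1].item())` extracts; on a device tensor that `.item()` call is exactly where the sync discussed above would occur.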
        The page table of the paged kv-cache, shape: ``[batch_size, num_pages]``. Required when ``backend`` is ``cutlass``.
    """

if self._backend != "cutlass":
Maybe check for cutlass backend also.
Closing the PR as the check can and might be best handled by the framework. |
📌 Description
Adds a check in run() to validate that q_nope.shape[0] and q_pe.shape[0] match qo_indptr[-1] from plan(). Previously, a mismatch between plan and run configurations could cause CUDA illegal memory access errors, making debugging difficult. This change converts such errors into a clear ValueError with an actionable message. The check applies to fa2/fa3 backends only, as the cutlass backend doesn't use plan().

🔍 Related Issues
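The backend-gated guard the description outlines can be sketched as a standalone function. This is a simplified model, not the code from the PR: it takes plain ints in place of tensor shapes, and the function name is hypothetical.

```python
def validate_plan_run_consistency(backend: str, q_nope_rows: int,
                                  q_pe_rows: int,
                                  total_num_queries: int) -> None:
    """Sketch of the run()-time guard described above.

    Only the fa2/fa3 backends go through plan(), so the cutlass backend
    is skipped entirely; for the others, both query tensors must carry
    exactly the number of tokens that plan() declared via qo_indptr[-1].
    """
    if backend == "cutlass":
        return
    for name, rows in (("q_nope", q_nope_rows), ("q_pe", q_pe_rows)):
        if rows != total_num_queries:
            raise ValueError(
                f"{name}.shape[0] ({rows}) does not match qo_indptr[-1] "
                f"({total_num_queries}) from plan(). The total number of "
                f"query tokens must be consistent between plan() and run()."
            )
```

A mismatch that previously surfaced as a CUDA illegal memory access would instead raise this ValueError on the host, before the kernel launches.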
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- All tests are passing (unittest, etc.).

Reviewer Notes