[bugfix] Support MIXED forward mode in TBO splitter for DP attention#24241
Merged
Conversation
Combining `--enable-dp-attention` with `--enable-mixed-chunk` crashed at the first chunked-prefill iteration that absorbed a running decode. Under DP attention, `prepare_mlp_sync_batch_raw` unconditionally calls `TboDPAttentionPreparer.prepare_all_gather`, which calls `compute_split_seq_index`/`compute_split_token_index` on the local batch's `forward_mode`, even when `--enable-two-batch-overlap` is off. Both functions raised `NotImplementedError` for `ForwardMode.MIXED`.

Fix: treat MIXED like EXTEND in both functions. `mix_with_running` already appends length-1 entries to `extend_lens` for the running decode reqs, so `_split_extend_seqs` and the cumulative-sum split fall through correctly.

Validated on DeepSeek-Coder-V2-Lite, dp_size=2, tp_size=2, `--chunked-prefill-size=256`: gsm8k 200 examples / 128 threads scored 0.795 (>= 0.6 threshold), with mixed batches confirmed in scheduler logs.

Note: `--enable-two-batch-overlap` + MIXED is still unsupported; the per-arch operations strategies in `batch_overlap/operations_strategy.py` also need MIXED handling. That is out of scope here; non-TBO DP attention gates around it via `can_run_tbo=False`.
Code Review
This pull request updates the batch overlap logic to include ForwardMode.MIXED when computing split sequence and token indices. It also adds a new test case, TestDPAttentionMixedChunk, to verify DP attention with mixed chunking. The reviewer suggested using the in operator for enum comparisons to improve readability.
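The reviewer's suggestion, as a minimal self-contained sketch (the enum here is a stand-in, not the actual sglang `ForwardMode`):

```python
from enum import Enum, auto

class ForwardMode(Enum):
    EXTEND = auto()
    MIXED = auto()
    DECODE = auto()

def is_extend_like(mode: ForwardMode) -> bool:
    # Chained equality works but reads worse as the set of modes grows:
    #   mode == ForwardMode.EXTEND or mode == ForwardMode.MIXED
    # The suggested alternative: a membership test with `in`.
    return mode in (ForwardMode.EXTEND, ForwardMode.MIXED)
```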
Author
/tag-and-rerun-ci
Motivation
`--enable-dp-attention` together with `--enable-mixed-chunk` crashes
deterministically at the first chunked-prefill iteration that merges a
running decode. The combination is silently broken — there's no validation
in `server_args.py` rejecting it, but no scheduler iteration with
`forward_mode = MIXED` ever survives long enough to reach the model.
Under DP attention, `prepare_mlp_sync_batch_raw` (`scheduler_dp_attn_mixin.py:187`)
unconditionally calls `TboDPAttentionPreparer.prepare_all_gather`, which
calls `compute_split_seq_index` on the local batch's `forward_mode` —
even when `--enable-two-batch-overlap` is off. That function (and
its companion `compute_split_token_index`) only handled
`EXTEND / TARGET_VERIFY / DECODE / IDLE / PREBUILT` and raised
`NotImplementedError` for `ForwardMode.MIXED`.
Repro on the failing combination:
```
File "python/sglang/srt/managers/scheduler_dp_attn_mixin.py", line 187, in prepare_mlp_sync_batch_raw
local_can_run_tbo, local_forward_mode = tbo_preparer.prepare_all_gather(local_batch)
File "python/sglang/srt/batch_overlap/two_batch_overlap.py", line 398, in prepare_all_gather
self.local_tbo_split_seq_index = compute_split_seq_index(
File "python/sglang/srt/batch_overlap/two_batch_overlap.py", line 94, in compute_split_seq_index
raise NotImplementedError()
NotImplementedError
```
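The failure reduces to a mode dispatch with no MIXED arm. A toy self-contained sketch of the pre-fix shape (names mirror the report, but this is not the actual sglang source, and the EXTEND branch is a placeholder):

```python
from enum import Enum, auto

class ForwardMode(Enum):
    EXTEND = auto()
    MIXED = auto()
    DECODE = auto()
    IDLE = auto()

def split_seq_index_pre_fix(forward_mode, extend_lens):
    # Pre-fix: only EXTEND / DECODE / IDLE are handled. A mixed
    # chunked-prefill batch (forward_mode == MIXED) falls through
    # to the raise below, which is exactly the traceback above.
    if forward_mode == ForwardMode.EXTEND:
        return len(extend_lens) // 2  # placeholder for the real split
    elif forward_mode in (ForwardMode.DECODE, ForwardMode.IDLE):
        return 0
    raise NotImplementedError()
```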
Modifications
Treat `ForwardMode.MIXED` like `ForwardMode.EXTEND` in both
`compute_split_seq_index` and `compute_split_token_index`. After
`mix_with_running`, the running-decode reqs are appended to
`extend_lens` as length-1 entries, so `_split_extend_seqs` and the
cumulative-sum split logic apply unchanged.
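To illustrate why the EXTEND path is safe for MIXED, here is a toy cumulative-sum split (not the real `_split_extend_seqs`, just the same idea): after `mix_with_running`, the running-decode reqs contribute length-1 entries to `extend_lens`, and the balanced token split still lands on a valid sequence boundary.

```python
def split_seq_index(extend_lens):
    """Toy cumulative-sum split: return the sequence index at which
    the left half first reaches at least half of the total tokens."""
    total = sum(extend_lens)
    cum = 0
    for i, n in enumerate(extend_lens):
        cum += n
        if cum * 2 >= total:
            return i + 1
    return len(extend_lens)

# A MIXED batch: two chunked-prefill reqs plus three running decodes
# that mix_with_running appended as length-1 extend entries.
mixed_extend_lens = [120, 96, 1, 1, 1]
idx = split_seq_index(mixed_extend_lens)  # splits after the first req
```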
Two-line fix in `python/sglang/srt/batch_overlap/two_batch_overlap.py`.
Added a regression test `TestDPAttentionMixedChunk` in
`test/registered/distributed/test_dp_attention.py` that exercises
`--enable-dp-attention --enable-mixed-chunk --chunked-prefill-size 256`
on the existing 2-GPU CI suite.
Out of scope: `--enable-two-batch-overlap` + `--enable-mixed-chunk`
is still unsupported. `OperationsStrategy.init_new_tbo` →
`compute_moe*_layer_operations_strategy_tbo` in
`batch_overlap/operations_strategy.py` (lines 84, 163, 240) also raise
`NotImplementedError` on MIXED. Non-TBO setups never reach those paths
because `can_run_tbo` requires `enable_two_batch_overlap=True`. Happy
to extend in a follow-up if there's interest.
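The gating described above can be sketched as follows (hypothetical simplified names; the real logic lives in `prepare_all_gather` and the operations-strategy constructors):

```python
from enum import Enum, auto

class ForwardMode(Enum):
    EXTEND = auto()
    MIXED = auto()

def init_tbo_strategy(forward_mode):
    # Stand-in for compute_moe*_layer_operations_strategy_tbo:
    # still no MIXED arm, so TBO + mixed-chunk stays unsupported.
    if forward_mode == ForwardMode.EXTEND:
        return "extend-strategy"
    raise NotImplementedError()

def prepare_batch(forward_mode, enable_two_batch_overlap):
    # Non-TBO runs force can_run_tbo False up front, so the strategy
    # constructor (and its NotImplementedError) is never reached.
    can_run_tbo = enable_two_batch_overlap
    if can_run_tbo:
        return init_tbo_strategy(forward_mode)
    return None
```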
Accuracy Tests
DeepSeek-Coder-V2-Lite-Instruct, `--tp 2 --enable-dp-attention --dp 2 --enable-mixed-chunk --chunked-prefill-size 256`,
gsm8k 200 examples / 128 threads:
```
gsm8k_score = 0.795 (threshold: 0.6)
output_throughput = 1138 token/s
```
Mixed batches confirmed in scheduler logs (`Prefill batch ... #running-req: 36`).
8/8 outputs on a separate hand-crafted long+short prompt mix exactly match
a `--enable-dp-attention` (no mixed_chunk) baseline run, indicating
correct cross-DP token routing under MIXED.
Local run had to substitute `deepseek-coder-v2-lite` (cached) for
`lmsys/sglang-ci-dsv3-test` (gated on this box) and use
`--attention-backend triton --disable-cuda-graph` (an outdated local
`sgl-kernel` install passes an unsupported `out=` kwarg to FA3). The
registered CI test does not include those overrides — CI will run with
the canonical FA3 + cuda graph configuration.
Speed Tests and Profiling
N/A — bugfix on a previously crashing path.