fix: enable flashinfer backends and allreduce fusion for SM12x Blackwell GPUs #18758
blake-snc wants to merge 3 commits into sgl-project:main
Conversation
b8zhong left a comment
trtllm-gen doesn't support anything other than SM100/103; can you show proof that it actually works?
@b8zhong You're right to question this. After deeper investigation, I'm revising this PR accordingly.
Also closed #18748 and #18754; you were correct that SM120 lacks tcgen05/TMEM. Appreciate the thorough review. From Second Nature Computing, testing on DGX Spark (SM121a, CUDA 13.0).
blake-snc force-pushed from 6ca65fb to ab79ae0
@b8zhong Rebased onto latest main. Here's the proof that flashinfer works on SM120 (NVIDIA GB10 / DGX Spark, SM12.1): this PR routes SM12x to flashinfer-based backends, as validated below.
Tested on DGX Spark hardware at Second Nature Computing.
@b8zhong Just wanted to follow up with the upstream flashinfer work that supports this PR. We've been working on proper SM120 attention kernel paths in flashinfer directly.
All of this has been validated on DGX Spark (SM 12.1, CUDA 13.0). The key point is that this PR routes SM12x to flashinfer, not to trtllm-gen backends. Happy to provide any additional validation results if that would be helpful.
Hi, update from Second Nature Computing: we have now validated the full end-to-end path on DGX Spark. We ran SGLang serving Qwen2.5-0.5B-Instruct on SM121a (GB10) with the flashinfer attention backend. With our flashinfer PRs applied (flashinfer-ai/flashinfer#2559, flashinfer-ai/flashinfer#2560, flashinfer-ai/flashinfer#2561), the backend works correctly: the server started, served requests, and produced correct outputs with no errors. Two additional requirements came up during validation.
These are outside the scope of this PR but worth noting for anyone else bringing up SM120-family support. Let us know if there is anything else needed to move this forward.
SM12x (consumer Blackwell / DGX Spark) was excluded from optimal attention and MoE backend selection because server_args.py checked is_sm100_supported(), which only matches SM100 (major=10). SM12x (major=12) fell through to suboptimal triton/auto backends.

Added is_sm120_supported() alongside is_sm100_supported() at all backend selection sites, following the existing pattern at line 2678.

Affected models:
- DeepSeek V3: now gets trtllm_mla (was triton)
- GPT-OSS: now gets trtllm_mha (was triton)
- Llama4: now gets trtllm_mha + flashinfer_trtllm MoE (was triton)
- Qwen3/3.5 MoE: now gets flashinfer_trtllm (was auto)
- Glm4 MoE: now gets flashinfer_trtllm (was auto)
- Olmo2: now gets trtllm_mha (was triton)

Also fixed:
- communicator.py: allreduce fusion now includes SM12x
- flashinfer_mla_backend.py: MLA backend uses cutlass on SM12x
- TRTLLM MHA validation gate uses is_blackwell_supported()
- Log messages updated from "sm100" to "Blackwell"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
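The capability gating this commit describes can be sketched as follows. This is a minimal illustration with hypothetical signatures; the real helpers in sglang query the CUDA device at runtime (via torch) rather than taking `(major, minor)` arguments, and the real selection logic covers many more backends:

```python
def is_sm100_supported(major: int, minor: int) -> bool:
    # Datacenter Blackwell: compute capability major version 10.
    return major == 10

def is_sm120_supported(major: int, minor: int) -> bool:
    # Consumer Blackwell: RTX 5090 (SM120), DGX Spark GB10 (SM121).
    return major == 12

def pick_attention_backend(major: int, minor: int) -> str:
    # Hypothetical selection mirroring the pattern in server_args.py:
    # without the is_sm120_supported() check, major=12 falls through
    # to the generic "triton" path even though flashinfer works there.
    if is_sm100_supported(major, minor) or is_sm120_supported(major, minor):
        return "flashinfer"
    return "triton"

print(pick_attention_backend(10, 0))  # flashinfer
print(pick_attention_backend(12, 1))  # flashinfer
print(pick_attention_backend(9, 0))   # triton
```

The point of the fix is only the extra `or is_sm120_supported(...)` clause: SM12x joins the fast path without touching the SM100 branch.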
SM120 does not have trtllm-gen FMHA cubins (only SM100/SM103 exist). While decode works via the xqa fallback, prefill always uses trtllm-gen with no fallback, causing failures on SM120.

Reverts SM120 routing to trtllm_mha, trtllm_mla, and flashinfer_trtllm MoE backends. SM120 correctly falls through to flashinfer or triton.

Keeps SM120 support for:
- flashinfer MLA attention (JIT compiled, works on SM120)
- flashinfer allreduce fusion (JIT compiled, works on SM120)
- flashinfer HiCache decode (JIT compiled, works on SM120)
- _handle_mamba_radix_cache (routes to flashinfer/triton, not trtllm)

Also reverts CUTLASS MLA fmha_backend for SM120 since cutlass_mla checks sm_version == 100.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
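The prefill/decode asymmetry is the crux of this revert: decode has a fallback kernel, prefill does not. A toy router (hypothetical function, kernel names taken from the commit message above) makes the failure mode concrete:

```python
# Only SM100/SM103 ship trtllm-gen FMHA cubins, per the commit message.
TRTLLM_GEN_CUBIN_SMS = {(10, 0), (10, 3)}

def pick_fmha_kernel(phase: str, major: int, minor: int) -> str:
    """Hypothetical router showing why SM120 fails only on prefill."""
    if (major, minor) in TRTLLM_GEN_CUBIN_SMS:
        return "trtllm-gen"
    if phase == "decode":
        return "xqa"  # decode can fall back to the xqa kernel
    # Prefill requires trtllm-gen and has no fallback: hard failure.
    raise RuntimeError(f"no trtllm-gen FMHA cubin for SM{major}{minor}")

print(pick_fmha_kernel("decode", 12, 0))   # xqa
print(pick_fmha_kernel("prefill", 10, 0))  # trtllm-gen
```

This is why routing SM120 to trtllm_mha can look fine in decode-heavy smoke tests and still break on the first prefill.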
blake-snc force-pushed from ab79ae0 to fbca440
@b8zhong Friendly follow-up: the upstream flashinfer work this PR depends on has progressed.
As shown in our validation above, this PR routes SM12x to flashinfer rather than to trtllm-gen backends. Also rebased onto latest main and fixed the trtllm_mla error message to correctly say SM100-only (not SM12x). Happy to address any remaining concerns.
The GDN hybrid model assertion assumed all Blackwell GPUs need triton or trtllm_mha for the full attention layers. This is true for SM100 (datacenter Blackwell) but not SM12x (consumer Blackwell, e.g. RTX 5090, DGX Spark GB10). On SM12x:
- trtllm_mha has no SM12x cubins and cannot be used
- triton and flashinfer both work for the full attention layers
- GDNAttnBackend uses Triton kernels regardless of the attention backend

Split the assertion by architecture: SM12x allows triton or flashinfer; SM10x/SM11x keeps the existing triton or trtllm_mha requirement.

Validated on SM121a (DGX Spark GB10): is_blackwell()=True, is_sm120_supported()=True; flashinfer correctly passes the assertion, trtllm_mha correctly blocked.

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
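The split assertion described in this commit can be sketched as follows. Function names here are hypothetical (the real check lives in attention_registry.py and inspects the device at runtime); the allowed-backend sets and the error wording follow the commit message:

```python
def allowed_gdn_full_attn_backends(major: int) -> set:
    # SM12x (consumer Blackwell): trtllm_mha has no cubins, but
    # flashinfer works for the full attention layers.
    if major == 12:
        return {"triton", "flashinfer"}
    # SM10x/SM11x: keep the original requirement unchanged.
    return {"triton", "trtllm_mha"}

def check_gdn_backend(backend: str, major: int) -> None:
    allowed = allowed_gdn_full_attn_backends(major)
    assert backend in allowed, (
        f"{sorted(allowed)} are the only supported backends on "
        f"Blackwell GPUs for hybrid GDN models, got: {backend}"
    )

check_gdn_backend("flashinfer", 12)  # passes on SM12x with this PR
check_gdn_backend("trtllm_mha", 10)  # SM100 behavior unchanged
```

Note that GDNAttnBackend itself uses Triton kernels either way; the assertion only constrains the backend used for the full attention layers.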
Added commit 655f631 to fix the hybrid GDN model assertion for SM12x. The assertion in the attention registry assumed all Blackwell GPUs must use triton or trtllm_mha for hybrid GDN models. The fix splits the check by architecture: SM12x allows triton or flashinfer (trtllm_mha excluded since no SM12x cubins exist), while SM10x/SM11x keeps the original triton or trtllm_mha requirement unchanged. Validated on SM121a: assertion logic confirmed correct.
Correction to previous comment — the assertion logic was verified on SM121a hardware (is_blackwell=True, is_sm120_supported=True confirms the right branch is taken), but end-to-end validation with an actual GDN hybrid model (Qwen3.5-35B-A3B) has not been done yet. Downloading the model now to complete validation before this is considered ready. |
Update: pre-validation logic confirmed on SM121a (full model validation in progress).

While Qwen3.5-35B-A3B (72 GB, 14 shards) finishes downloading, I've verified the complete code path: Qwen3.5-35B-A3B loads as a hybrid GDN model, so the attention registry assertion applies. The old assertion logic would fire and block the flashinfer backend; the new assertion logic in this PR passes.

Full end-to-end server validation (boot + inference) will follow once the model finishes downloading (~50 minutes remaining).
Validation Results — SM121a (DGX Spark GB10)

Hardware: NVIDIA GB10, SM121a (compute capability 12.1)

Before fix (expected failure): the old assertion triggered on any hybrid GDN model run with the flashinfer backend.

After fix — server boot log: no assertion errors; the server boots cleanly.

Test inference:

```bash
curl -s -X POST http://127.0.0.1:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "What is 2+2?", "sampling_params": {"max_new_tokens": 32, "temperature": 0}}'
```

Response:

```json
{
  "text": "\n\n2+2 equals 4. This is a basic arithmetic operation where you add two numbers together.\n",
  "meta_info": {
    "prompt_tokens": 7,
    "completion_tokens": 32,
    "e2e_latency": 2.12
  }
}
```

Launch command used:

```bash
FLASHINFER_DISABLE_VERSION_CHECK=1 \
TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
SGLANG_DISABLE_CUDNN_CHECK=1 \
python3 -m sglang.launch_server \
  --model-path /home/saikambampati/models/Qwen3.5-35B-A3B \
  --attention-backend flashinfer \
  --disable-cuda-graph \
  --host 127.0.0.1 --port 30000
```

Validation: PASS — Qwen3.5-35B-A3B boots and serves inference with `--attention-backend flashinfer` on SM121a.
Community context for reviewers: this was reported in the SGLang Discord today when Qwen3.5-35B-A3B dropped. A user hit the exact assertion this PR fixes on SM12x hardware, and the workaround was to switch to vLLM entirely.
SM12x users on DGX Spark / RTX 5090 currently have to fall back to vLLM to run this model. This PR fixes that. We've now validated the full fix on SM121a hardware; see the validation comment above. The changes are small and targeted.
Summary
Enable SM12x Blackwell GPUs (DGX Spark GB10, RTX 5090/5080) to use flashinfer-based backends instead of falling through to triton for all attention/compute paths. Also fixes a Blackwell assertion in the attention registry that blocked hybrid GDN models (Qwen3.5-MoE, Qwen3Next, etc.) from using the flashinfer backend on SM12x.
Important: This PR intentionally does NOT route SM12x to `trtllm_mha`, `trtllm_mla`, or `flashinfer_trtllm` MoE backends — SM120 lacks `tcgen05` instructions required by trtllm-gen FMHA kernels (only SM100/SM103 cubins exist). Thanks to @b8zhong for catching this in #18748.
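A minimal sketch of the gate described above (illustrative only; the SM pairs come from this paragraph, and the function names are hypothetical rather than sglang's):

```python
def has_tcgen05(major: int, minor: int) -> bool:
    # tcgen05/TMEM exist only on datacenter Blackwell (SM100/SM103).
    return (major, minor) in {(10, 0), (10, 3)}

def may_use_trtllm_gen(major: int, minor: int) -> bool:
    # trtllm-gen FMHA kernels require tcgen05, so SM12x must be
    # excluded even though it is also "Blackwell" (major version 12).
    return has_tcgen05(major, minor)

for sm in [(10, 0), (10, 3), (12, 0), (12, 1)]:
    print(sm, may_use_trtllm_gen(*sm))
```

This is why a plain "is Blackwell" check is the wrong gate for trtllm-gen backends: it has to be "is datacenter Blackwell".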
Changes
`attention_registry.py`: split the hybrid GDN Blackwell assertion by architecture; SM12x allows triton or flashinfer, while SM10x/SM11x keeps the existing triton or trtllm_mha requirement.
`server_args.py`: added `is_sm120_supported()` checks at backend selection sites so SM12x reaches flashinfer-based backends instead of falling through to triton/auto.
`communicator.py`: allreduce fusion now includes SM12x.
What this does NOT change
SM100/SM103 (datacenter Blackwell) backend selection is unchanged, and SM12x is still not routed to `trtllm_mha`, `trtllm_mla`, or `flashinfer_trtllm`.
Validation — SM121a (DGX Spark GB10)
Hardware: NVIDIA GB10, compute capability 12.1
Model: Qwen3.5-35B-A3B (`Qwen3_5MoeConfig` — a `hybrid_gdn_config` model)
Backend: `--attention-backend flashinfer`
Before fix
```
AssertionError: triton or trtllm_mha backend are the only supported backends
on Blackwell GPUs for hybrid GDN models, got: flashinfer
```
After fix — server boot
```
Load weight end. type=Qwen3_5MoeForConditionalGeneration, dtype=torch.bfloat16,
mem usage=61.92 GB
Mamba Cache is allocated. conv_state=0.26 GB, ssm_state=11.13 GB
KV Cache is allocated. #tokens=661722, K size=6.31 GB, V size=6.31 GB
Using hybrid linear attention backend for hybrid GDN models.
Application startup complete.
The server is fired up and ready to roll!
```
Test inference
```bash
curl -X POST http://127.0.0.1:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "What is 2+2?", "sampling_params": {"max_new_tokens": 32, "temperature": 0}}'
```
Response: "2+2 equals 4. This is a basic arithmetic operation..."
Test plan
Tested on DGX Spark (SM121a, CUDA 13.0, aarch64) provided by Second Nature Computing.
🤖 Generated with Claude Code