fix: enable flashinfer backends and allreduce fusion for SM12x Blackwell GPUs #18758

Open
blake-snc wants to merge 3 commits into sgl-project:main from blake-snc:fix/sm12x-backend-selection

Conversation

@blake-snc
Contributor

@blake-snc blake-snc commented Feb 13, 2026

Summary

Enable SM12x Blackwell GPUs (DGX Spark GB10, RTX 5090/5080) to use flashinfer-based backends instead of falling through to triton for all attention/compute paths. Also fixes a Blackwell assertion in the attention registry that blocked hybrid GDN models (Qwen3.5-MoE, Qwen3Next, etc.) from using the flashinfer backend on SM12x.

Important: This PR intentionally does NOT route SM12x to `trtllm_mha`, `trtllm_mla`, or `flashinfer_trtllm` MoE backends — SM120 lacks `tcgen05` instructions required by trtllm-gen FMHA kernels (only SM100/SM103 cubins exist). Thanks to @b8zhong for catching this in #18748.

Changes

`attention_registry.py`:

  • Split the Blackwell GDN assertion to allow `flashinfer` on SM12x. Previously the assertion required `triton` or `trtllm_mha` for all Blackwell GPUs on hybrid GDN models; SM12x (which doesn't support trtllm) can now use `flashinfer`.
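The split can be sketched as follows; `check_gdn_backend` and its boolean arguments are illustrative stand-ins for the actual registry code, not SGLang's real API:

```python
# Sketch of the split Blackwell GDN assertion described above. The
# function name and boolean flags are assumptions for illustration.
def check_gdn_backend(attention_backend: str, is_blackwell: bool, is_sm120: bool) -> None:
    if not is_blackwell:
        return
    if is_sm120:
        # SM12x: trtllm_mha has no cubins here, so flashinfer joins triton.
        assert attention_backend in ("triton", "flashinfer"), (
            f"triton or flashinfer are the only supported backends "
            f"on SM12x for hybrid GDN models, got: {attention_backend}"
        )
    else:
        # SM100/SM103 keep the original requirement.
        assert attention_backend in ("triton", "trtllm_mha"), (
            f"triton or trtllm_mha are the only supported backends "
            f"on Blackwell GPUs for hybrid GDN models, got: {attention_backend}"
        )

check_gdn_backend("flashinfer", is_blackwell=True, is_sm120=True)  # passes on SM12x
```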

`server_args.py`:

  • Enable flashinfer allreduce fusion for SM12x (alongside SM90/SM100)
  • Route SM12x MLA models to flashinfer attention (instead of triton fallback)
  • Route SM12x HiCache decode to flashinfer (instead of triton)
  • Include SM12x in `_handle_mamba_radix_cache` defaults (routes to flashinfer or triton per-model, never trtllm)
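As a rough sketch of the resulting MLA routing (the function and flag names paraphrase server_args.py and are assumptions, not SGLang's actual code structure):

```python
# Illustrative MLA attention-backend selection after this PR; names
# are paraphrased from the PR description, not the real server_args API.
def pick_mla_attention_backend(is_sm100: bool, is_sm120: bool) -> str:
    if is_sm100:
        return "trtllm_mla"  # datacenter Blackwell: trtllm-gen cubins exist
    if is_sm120:
        return "flashinfer"  # consumer Blackwell: no trtllm cubins; FA2 MLA path
    return "triton"          # generic fallback

print(pick_mla_attention_backend(is_sm100=False, is_sm120=True))  # → flashinfer
```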

`communicator.py`:

  • Add `is_sm120_supported` to allreduce fusion eligibility check
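A one-line sketch of the widened gate (the flag names mirror this description; the real eligibility check has more conditions):

```python
# Allreduce-fusion eligibility after this PR: SM12x joins SM90/SM100.
# Flag names are assumptions paraphrasing communicator.py.
def allreduce_fusion_eligible(is_sm90: bool, is_sm100: bool, is_sm120: bool) -> bool:
    return is_sm90 or is_sm100 or is_sm120
```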

What this does NOT change

  • No trtllm_mha/trtllm_mla routing for SM120 (prefill uses trtllm-gen which has no SM120 cubins)
  • No flashinfer_trtllm MoE runner for SM120 (unvalidated)
  • No CUTLASS MLA fmha_backend for SM120 (cutlass_mla checks `sm_version == 100`)

Validation — SM121a (DGX Spark GB10)

Hardware: NVIDIA GB10, compute capability 12.1
Model: Qwen3.5-35B-A3B (`Qwen3_5MoeConfig` — a `hybrid_gdn_config` model)
Backend: `--attention-backend flashinfer`

Before fix

```
AssertionError: triton or trtllm_mha backend are the only supported backends
on Blackwell GPUs for hybrid GDN models, got: flashinfer
```

After fix — server boot

```
Load weight end. type=Qwen3_5MoeForConditionalGeneration, dtype=torch.bfloat16,
mem usage=61.92 GB
Mamba Cache is allocated. conv_state=0.26 GB, ssm_state=11.13 GB
KV Cache is allocated. #tokens=661722, K size=6.31 GB, V size=6.31 GB
Using hybrid linear attention backend for hybrid GDN models.
Application startup complete.
The server is fired up and ready to roll!
```

Test inference

```bash
curl -X POST http://127.0.0.1:30000/generate \
  -d '{"text": "What is 2+2?", "sampling_params": {"max_new_tokens": 32, "temperature": 0}}'

# Response: "2+2 equals 4. This is a basic arithmetic operation..."
```

Test plan

  • Validated on DGX Spark (SM121a, CUDA 13.0): `is_sm100_supported()=False`, `is_sm120_supported()=True`
  • Verified all trtllm paths SKIPPED on SM120
  • Verified all flashinfer paths ENABLED on SM120
  • End-to-end model boot + inference: Qwen3.5-35B-A3B on SM121a with flashinfer backend
  • Pre-commit checks pass
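The capability checks above can be sketched like this; the real SGLang helpers read `torch.cuda.get_device_capability()`, which is factored out here so the gating logic runs standalone:

```python
# Hedged sketch of the is_sm100_supported / is_sm120_supported gating.
# The (major, minor) tuple is passed in rather than queried from CUDA.
def is_sm100_supported(cc):
    return cc[0] == 10  # SM100/SM103: datacenter Blackwell

def is_sm120_supported(cc):
    return cc[0] == 12  # SM120/SM121: consumer Blackwell, e.g. GB10

gb10 = (12, 1)  # DGX Spark GB10 compute capability
print(is_sm100_supported(gb10), is_sm120_supported(gb10))  # → False True
```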

Tested on DGX Spark (SM121a, CUDA 13.0, aarch64) provided by Second Nature Computing.

🤖 Generated with Claude Code


Collaborator

@b8zhong b8zhong left a comment


trtllm-gen doesn't support anything other than SM100/103; can you show proof that it actually works?

@blake-snc
Contributor Author

@b8zhong You're right to question this. After deeper investigation:

  1. trtllm-gen cubins: Only SM100/SM103 cubins exist in flashinfer (fmhaSm100aKernel, fmhaSm100fKernel, fmhaSm103aKernel). No SM120 cubins.
  2. Decode: flashinfer auto-routes SM120 to xqa backend ("trtllm-gen" if major == 10 else "xqa"), so decode works.
  3. Prefill: trtllm_batch_context_with_kv_cache and trtllm_ragged_attention_deepseek always use trtllm-gen with no xqa fallback — this would FAIL on SM120.
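The decode dispatch quoted in point 2 can be written out as a tiny sketch (a paraphrase, not flashinfer's actual source):

```python
# flashinfer's decode routing, paraphrased: trtllm-gen only on SM10x
# (major == 10), everything else falls back to the xqa backend.
def fmha_decode_backend(cc_major: int) -> str:
    return "trtllm-gen" if cc_major == 10 else "xqa"

print(fmha_decode_backend(10), fmha_decode_backend(12))  # → trtllm-gen xqa
```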

I'm revising this PR to:

  • Route SM120 to flashinfer instead of trtllm_mha/trtllm_mla (since prefill has no SM120 path)
  • Keep the allreduce fusion and MLA fmha_backend changes (these use flashinfer internally, not trtllm-gen directly)
  • Keep the general attention backend routing to flashinfer (line 1803, 2245)

Also closed #18748 and #18754 — you were correct that SM120 lacks tcgen05/TMEM. Appreciate the thorough review.

From Second Nature Computing — testing on DGX Spark (SM121a, CUDA 13.0).

@blake-snc blake-snc changed the title from "fix: enable optimal backend selection for SM12x Blackwell GPUs" to "fix: enable flashinfer backends and allreduce fusion for SM12x Blackwell GPUs" on Feb 13, 2026
@blake-snc blake-snc force-pushed the fix/sm12x-backend-selection branch from 6ca65fb to ab79ae0 on February 13, 2026 20:58
@blake-snc
Contributor Author

@b8zhong Rebased onto latest main and here's the proof that flashinfer works on SM120 (NVIDIA GB10 / DGX Spark, SM12.1):

This PR routes SM12x to flashinfer (not trtllm) because trtllm-gen only has SM100/SM103 cubins. Here's validation of every flashinfer path SGLang would use:

```
Device: NVIDIA GB10 (SM 12.1)

Test 1: FA2 single prefill (MHA, causal)
  seqlen=1024, heads=32, hd=128
  Max diff vs PyTorch SDPA: 0.015625 — PASS

Test 2: FA2 single prefill (GQA, non-causal)
  seqlen=512, q_heads=32, kv_heads=8, hd=128
  Max diff vs PyTorch SDPA: 0.001953 — PASS

Test 3: XQA decode (GQA, hd=128)
  q_heads=32, kv_heads=8, kv_len=1024 — PASS

Test 4: Single decode (MHA, hd=128)
  heads=16, kv_len=512 — PASS

Test 5: XQA decode (hd=64, Llama-style)
  q_heads=32, kv_heads=8, kv_len=2048 — PASS

5/5 tests passed
```
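The pass criterion in these tests is a max-absolute-difference check against a reference attention. A minimal numpy stand-in for that methodology (illustration only; this is not the flashinfer or PyTorch SDPA kernels):

```python
import numpy as np

# Compute causal attention two ways and compare max absolute difference,
# mirroring the validation method above. The fp16 round-trip stands in
# for a lower-precision kernel.
def sdpa_reference(q, k, v, causal=False):
    d = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (heads, q_len, kv_len)
    if causal:
        q_len, kv_len = scores.shape[-2:]
        mask = np.triu(np.ones((q_len, kv_len), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v

rng = np.random.default_rng(0)
heads, seqlen, hd = 4, 64, 32
q, k, v = (rng.standard_normal((heads, seqlen, hd)).astype(np.float32)
           for _ in range(3))
q16, k16, v16 = (x.astype(np.float16).astype(np.float32) for x in (q, k, v))

max_diff = np.abs(sdpa_reference(q, k, v, causal=True)
                  - sdpa_reference(q16, k16, v16, causal=True)).max()
print("PASS" if max_diff < 0.05 else "FAIL", f"max diff {max_diff:.6f}")
```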

Why each change is correct:

  1. Allreduce fusion (communicator.py): Uses flashinfer internally, which JIT-compiles for SM120
  2. MLA backend (server_args.py:1801): Routes to flashinfer FA2 MLA (not trtllm_mla). flashinfer's determine_mla_backend() returns "fa2" for SM120 which works.
  3. trtllm_mla validation (server_args.py:1856): Changed from is_blackwell_supported() to is_sm100_supported() — correctly excludes SM120 from trtllm_mla
  4. HiCache decode (server_args.py:2245): Routes to flashinfer (XQA decode works on SM120, confirmed above)

Tested on DGX Spark hardware at Second Nature Computing.

@blake-snc
Contributor Author

@b8zhong Just wanted to follow up with the upstream flashinfer work that supports this PR. We've been working on proper SM120 attention kernel paths in flashinfer directly:

All of this has been validated on DGX Spark (SM 12.1, CUDA 13.0). The key point is that this PR routes SM12x to flashinfer instead of trtllm, and the flashinfer PRs above ensure those paths actually work correctly on SM120 hardware.

Happy to provide any additional validation results if that would be helpful.

@blake-snc
Contributor Author

Hi, update from Second Nature Computing: we have now validated the full end-to-end path on DGX Spark.

We ran SGLang serving Qwen2.5-0.5B-Instruct on SM121a (GB10) with the flashinfer attention backend. With our flashinfer PRs applied (flashinfer-ai/flashinfer#2559, flashinfer-ai/flashinfer#2560, flashinfer-ai/flashinfer#2561), the backend works correctly. The server started, served requests, and produced correct outputs with no errors.

Two additional requirements came up during validation:

  1. An sgl-kernel cu130 wheel is needed so that SM121a cubins are available.
  2. TRITON_PTXAS_PATH must point to the system ptxas from CUDA 13.0, otherwise Triton compilation fails.

These are outside the scope of this PR but worth noting for anyone else bringing up SM120-family support.

Let us know if there is anything else needed to move this forward.

blake-snc and others added 2 commits February 20, 2026 23:09
SM12x (consumer Blackwell / DGX Spark) was excluded from optimal
attention and MOE backend selection because server_args.py checked
is_sm100_supported() which only matches SM100 (major=10). SM12x
(major=12) fell through to suboptimal triton/auto backends.

Added is_sm120_supported() alongside is_sm100_supported() at all
backend selection sites, following the existing pattern at line 2678.

Affected models:
- DeepSeek V3: now gets trtllm_mla (was triton)
- GPT-OSS: now gets trtllm_mha (was triton)
- Llama4: now gets trtllm_mha + flashinfer_trtllm MOE (was triton)
- Qwen3/3.5 MoE: now gets flashinfer_trtllm (was auto)
- Glm4 MoE: now gets flashinfer_trtllm (was auto)
- Olmo2: now gets trtllm_mha (was triton)

Also fixed:
- communicator.py: allreduce fusion now includes SM12x
- flashinfer_mla_backend.py: MLA backend uses cutlass on SM12x
- TRTLLM MHA validation gate uses is_blackwell_supported()
- Log messages updated from "sm100" to "Blackwell"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
SM120 does not have trtllm-gen FMHA cubins (only SM100/SM103 exist).
While decode works via xqa fallback, prefill always uses trtllm-gen
with no fallback, causing failures on SM120.

Reverts SM120 routing to trtllm_mha, trtllm_mla, and flashinfer_trtllm
MoE backends. SM120 correctly falls through to flashinfer or triton.

Keeps SM120 support for:
- flashinfer MLA attention (JIT compiled, works on SM120)
- flashinfer allreduce fusion (JIT compiled, works on SM120)
- flashinfer HiCache decode (JIT compiled, works on SM120)
- _handle_mamba_radix_cache (routes to flashinfer/triton, not trtllm)

Also reverts CUTLASS MLA fmha_backend for SM120 since cutlass_mla
checks sm_version == 100.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc blake-snc force-pushed the fix/sm12x-backend-selection branch from ab79ae0 to fbca440 on February 21, 2026 07:10
@blake-snc
Contributor Author

@b8zhong Friendly follow-up — the upstream flashinfer work this PR depends on has progressed:

As shown in our validation above, this PR routes SM12x to flashinfer (not trtllm-gen), and these upstream fixes ensure the flashinfer paths work correctly on SM120.

Also rebased onto latest main and fixed the trtllm_mla error message to correctly say SM100-only (not SM12x). Happy to address any remaining concerns.

The GDN hybrid model assertion assumed all Blackwell GPUs need triton or
trtllm_mha for the full attention layers. This is true for SM100
(datacenter Blackwell) but not SM12x (consumer Blackwell, e.g. RTX 5090,
DGX Spark GB10).

On SM12x:
- trtllm_mha has no SM12x cubins and cannot be used
- triton and flashinfer both work for the full attention layers
- GDNAttnBackend uses Triton kernels regardless of the attention backend

Split the assertion by architecture: SM12x allows triton or flashinfer,
SM10x/SM11x keeps the existing triton or trtllm_mha requirement.

Validated on SM121a (DGX Spark GB10): is_blackwell()=True,
is_sm120_supported()=True; flashinfer correctly passes the assertion,
trtllm_mha correctly blocked.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blake-snc
Contributor Author

Added commit 655f631 to fix the hybrid GDN model assertion for SM12x.

The assertion at attention_registry.py:202 blocks flashinfer on all Blackwell GPUs for hybrid GDN models (e.g. Qwen3.5-35B-A3B), requiring triton or trtllm_mha. On SM12x this causes an immediate AssertionError when using --attention-backend flashinfer since trtllm_mha has no SM12x cubins and flashinfer wasn't in the allowed list.

The fix splits the check by architecture: SM12x allows triton or flashinfer (trtllm_mha excluded since no SM12x cubins exist), SM10x/SM11x keeps the original triton or trtllm_mha requirement unchanged.

Validated on SM121a: assertion logic confirmed correct via is_blackwell()=True, is_sm120_supported()=True. Full model test (Qwen3.5-35B-A3B) pending download.

@blake-snc
Contributor Author

Correction to previous comment — the assertion logic was verified on SM121a hardware (is_blackwell=True, is_sm120_supported=True confirms the right branch is taken), but end-to-end validation with an actual GDN hybrid model (Qwen3.5-35B-A3B) has not been done yet. Downloading the model now to complete validation before this is considered ready.

@blake-snc
Contributor Author

Update: Pre-validation logic confirmed on SM121a (full model validation in progress)

While Qwen3.5-35B-A3B (72 GB, 14 shards) finishes downloading, I've verified the complete code path:

1. Qwen3.5-35B-A3B loads as Qwen3_5MoeConfig (SGLang custom config for model_type: qwen3_5_moe), which triggers hybrid_gdn_config in the model runner.

2. is_blackwell()=True, is_sm120_supported()=True on the DGX Spark (SM121a, compute cap 12.1).

3. Old assertion logic — would fire and block flashinfer:

OLD code with attention_backend=flashinfer:
  AssertionError: triton or trtllm_mha backend are the only supported backends
  on Blackwell GPUs for hybrid GDN models, got: flashinfer

4. New assertion logic (this PR) — passes:

NEW code with attention_backend=flashinfer:
  PASS (assertion allows flashinfer on SM12x)

Full end-to-end server validation (boot + inference) will follow once the model finishes downloading (~50 minutes remaining).

@blake-snc
Contributor Author

Validation Results — SM121a (DGX Spark GB10)

Hardware: NVIDIA GB10, SM121a (compute capability 12.1)
Model: Qwen3.5-35B-A3B (model_type: qwen3_5_moe, loaded as Qwen3_5MoeConfig)
Backend: --attention-backend flashinfer
Commit: 655f631

Before fix (expected failure)

The old assertion triggered on any hybrid_gdn_config model (which includes Qwen3_5MoeConfig) when using the flashinfer backend on Blackwell:

```
AssertionError: triton or trtllm_mha backend are the only supported backends
on Blackwell GPUs for hybrid GDN models, got: flashinfer
```

After fix — server boot log

```
[2026-02-24 14:52:18] Load weight begin. avail mem=104.49 GB
[2026-02-24 14:54:03] Load weight end. elapsed=104.70 s,
                       type=Qwen3_5MoeForConditionalGeneration,
                       dtype=torch.bfloat16,
                       avail mem=42.57 GB, mem usage=61.92 GB
[2026-02-24 14:54:08] Mamba Cache is allocated.
                       max_mamba_cache_size: 189,
                       conv_state size: 0.26 GB, ssm_state size: 11.13 GB
[2026-02-24 14:54:11] KV Cache is allocated.
                       #tokens: 661722, K size: 6.31 GB, V size: 6.31 GB
[2026-02-24 14:54:13] Using hybrid linear attention backend for hybrid GDN models.
[2026-02-24 14:54:22] Application startup complete.
[2026-02-24 14:55:10] The server is fired up and ready to roll!
```

No AssertionError. The model loads and the hybrid linear attention backend initialises correctly.

Test inference

```bash
curl -s -X POST http://127.0.0.1:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "What is 2+2?", "sampling_params": {"max_new_tokens": 32, "temperature": 0}}'
```

Response:

```json
{
  "text": "\n\n2+2 equals 4. This is a basic arithmetic operation where you add two numbers together.\n",
  "meta_info": {
    "prompt_tokens": 7,
    "completion_tokens": 32,
    "e2e_latency": 2.12
  }
}
```

Launch command used

```bash
FLASHINFER_DISABLE_VERSION_CHECK=1 \
TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas \
SGLANG_DISABLE_CUDNN_CHECK=1 \
python3 -m sglang.launch_server \
    --model-path /home/saikambampati/models/Qwen3.5-35B-A3B \
    --attention-backend flashinfer \
    --disable-cuda-graph \
    --host 127.0.0.1 --port 30000
```

(SGLANG_DISABLE_CUDNN_CHECK=1 is needed on this machine due to a PyTorch 2.9.1 + CuDNN 9.13 compatibility check that rejects the version even though inference works fine.)

Validation: PASS — Qwen3.5-35B-A3B boots and serves inference with --attention-backend flashinfer on SM121a without hitting the Blackwell GDN assertion.

@blake-snc
Contributor Author

Community context for reviewers: This was reported in the SGLang Discord today when Qwen3.5-35B-A3B dropped. A user hit the exact assertion this PR fixes on SM12x hardware, and the workaround was to switch to vLLM entirely:

shakh: As usual, we are having our classic sm12x drama - AssertionError: triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend.

eugr: just ran 35B-A3B on a fresh build: vllm serve Qwen/Qwen3.5-35B-A3B --gpu-memory-utilization 0.7 --port 8888 --load-format fastsafetensors

SM12x users on DGX Spark / RTX 5090 currently have to fall back to vLLM to run this model. This PR fixes that.

We've now validated the full fix on SM121a hardware — see the validation comment above. The changes are small and targeted (attention_registry.py: split the Blackwell assertion to allow flashinfer on SM12x, since trtllm_mha has no SM120 cubins anyway).
