Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions python/sglang/srt/layers/attention/attention_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,24 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
LightningAttentionBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
from sglang.srt.utils import is_blackwell, is_npu, is_sm120_supported

check_environments()
if runner.hybrid_gdn_config is not None:
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_sm120_supported():
# SM12x (consumer Blackwell, e.g. RTX 5090, DGX Spark): trtllm_mha has
# no SM12x cubins so it cannot be used here; triton and flashinfer both
# support the full attention layers on SM12x.
assert runner.server_args.attention_backend in (
"triton",
"flashinfer",
), "triton or flashinfer backend are the only supported backends on SM12x GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend flashinfer to specify the backend."
else:
assert (
runner.server_args.attention_backend == "triton"
or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@
is_npu,
is_sm90_supported,
is_sm100_supported,
is_sm120_supported,
)

_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_sm90_supported = _is_cuda and is_sm90_supported()
_is_sm100_supported = _is_cuda and is_sm100_supported()
_is_sm120_supported = _is_cuda and is_sm120_supported()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()
_is_npu = is_npu()
Expand All @@ -92,7 +94,7 @@ def apply_flashinfer_allreduce_fusion(batch_size: int):
return (
# NOTE: flashinfer 0.6.1 caused performance regression on sm100 for allreduce fusion
# Ref: https://github.com/sgl-project/sglang/issues/17237
(_is_sm90_supported or _is_sm100_supported)
(_is_sm90_supported or _is_sm100_supported or _is_sm120_supported)
and _is_flashinfer_available
and batch_size > 0
and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
Expand Down
14 changes: 8 additions & 6 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ def _handle_model_specific_adjustments(self):
"Qwen3MoeForCausalLM",
"KimiK25ForConditionalGeneration",
]
and (is_sm90_supported() or is_sm100_supported())
and (is_sm90_supported() or is_sm100_supported() or is_sm120_supported())
and not self.enable_dp_attention
and self.nnodes == 1
and not is_h20_device
Expand All @@ -1704,7 +1704,7 @@ def _handle_mamba_radix_cache(
sm100_default_attention_backend: str = None,
):
if (
is_sm100_supported()
(is_sm100_supported() or is_sm120_supported())
and self.attention_backend is None
and sm100_default_attention_backend is not None
):
Expand Down Expand Up @@ -1820,7 +1820,7 @@ def _handle_attention_backend_compatibility(self):
# MLA architecture
if is_hopper_with_cuda_12_3():
self.attention_backend = "fa3"
elif is_sm100_supported():
elif is_sm100_supported() or is_sm120_supported():
self.attention_backend = "flashinfer"
elif is_hip():
head_num = model_config.get_num_kv_heads(self.tp_size)
Expand Down Expand Up @@ -1875,9 +1875,9 @@ def _handle_attention_backend_compatibility(self):
self.attention_backend == "trtllm_mla"
or self.decode_attention_backend == "trtllm_mla"
):
if not is_blackwell_supported():
if not is_sm100_supported():
raise ValueError(
"TRTLLM MLA backend is only supported on Blackwell GPUs (SM100/SM12x). Please use a different backend."
"TRTLLM MLA backend is only supported on SM100 Blackwell GPUs. Please use a different backend."
)

if self.page_size not in [32, 64]:
Expand Down Expand Up @@ -2308,7 +2308,9 @@ def _handle_hicache(self):
)
else:
self.decode_attention_backend = (
"flashinfer" if is_sm100_supported() else "triton"
"flashinfer"
if is_sm100_supported() or is_sm120_supported()
else "triton"
)
else:
# If user explicitly requested FA3 decode, fall back to direct IO.
Expand Down
Loading