diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 25ffc3f0b412..9c2274e46e08 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -195,15 +195,24 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac LightningAttentionBackend, Mamba2AttnBackend, ) - from sglang.srt.utils import is_blackwell, is_npu + from sglang.srt.utils import is_blackwell, is_npu, is_sm120_supported check_environments() if runner.hybrid_gdn_config is not None: if is_blackwell(): - assert ( - runner.server_args.attention_backend == "triton" - or runner.server_args.attention_backend == "trtllm_mha" - ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend." + if is_sm120_supported(): + # SM12x (consumer Blackwell, e.g. RTX 5090, DGX Spark): trtllm_mha has + # no SM12x cubins so it cannot be used here; triton and flashinfer both + # support the full attention layers on SM12x. + assert runner.server_args.attention_backend in ( + "triton", + "flashinfer", + ), "triton or flashinfer backend are the only supported backends on SM12x GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend flashinfer to specify the backend." + else: + assert ( + runner.server_args.attention_backend == "triton" + or runner.server_args.attention_backend == "trtllm_mha" + ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend." if is_npu(): assert ( runner.server_args.attention_backend == "ascend" diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 64f2cf66242e..42c7efaa070b 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -65,12 +65,14 @@ is_npu, is_sm90_supported, is_sm100_supported, + is_sm120_supported, ) _is_cuda = is_cuda() _is_flashinfer_available = is_flashinfer_available() _is_sm90_supported = _is_cuda and is_sm90_supported() _is_sm100_supported = _is_cuda and is_sm100_supported() +_is_sm120_supported = _is_cuda and is_sm120_supported() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip() _is_gfx95_supported = is_gfx95_supported() _is_npu = is_npu() @@ -92,7 +94,7 @@ def apply_flashinfer_allreduce_fusion(batch_size: int): return ( # NOTE: flashinfer 0.6.1 caused performance regression on sm100 for allreduce fusion # Ref: https://github.com/sgl-project/sglang/issues/17237 - (_is_sm90_supported or _is_sm100_supported) + (_is_sm90_supported or _is_sm100_supported or _is_sm120_supported) and _is_flashinfer_available and batch_size > 0 and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 952d5356376b..dafd4865d674 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1688,7 +1688,7 @@ def _handle_model_specific_adjustments(self): "Qwen3MoeForCausalLM", "KimiK25ForConditionalGeneration", ] - and (is_sm90_supported() or is_sm100_supported()) + and (is_sm90_supported() or is_sm100_supported() or is_sm120_supported()) and not self.enable_dp_attention and self.nnodes == 1 and not is_h20_device @@ -1704,7 +1704,7 @@ def _handle_mamba_radix_cache( sm100_default_attention_backend: str = None, ): if ( - is_sm100_supported() + (is_sm100_supported() or is_sm120_supported()) and self.attention_backend is None and sm100_default_attention_backend is not None ): @@ -1820,7 +1820,7 @@ def _handle_attention_backend_compatibility(self): # MLA architecture if is_hopper_with_cuda_12_3(): self.attention_backend = "fa3" - elif is_sm100_supported(): + elif is_sm100_supported() or is_sm120_supported(): self.attention_backend = "flashinfer" elif is_hip(): head_num = model_config.get_num_kv_heads(self.tp_size) @@ -1875,9 +1875,9 @@ def _handle_attention_backend_compatibility(self): self.attention_backend == "trtllm_mla" or self.decode_attention_backend == "trtllm_mla" ): - if not is_blackwell_supported(): + if not is_sm100_supported(): raise ValueError( - "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100/SM12x). Please use a different backend." + "TRTLLM MLA backend is only supported on SM100 Blackwell GPUs. Please use a different backend." ) if self.page_size not in [32, 64]: @@ -2308,7 +2308,9 @@ def _handle_hicache(self): ) else: self.decode_attention_backend = ( - "flashinfer" if is_sm100_supported() else "triton" + "flashinfer" + if is_sm100_supported() or is_sm120_supported() + else "triton" ) else: # If user explicitly requested FA3 decode, fall back to direct IO.