[MXFP8 MoE] mx_block_rearrange_2d_M_groups_cuda fails for models with > 32 experts #4163

@yzhautouskay

Description

Problem

mx_block_rearrange_2d_M_groups_cuda hard-checks num_groups <= 32 (mxfp8_extension.cpp:200). In MoE, num_groups equals num_experts (the offsets tensor has one entry per expert), so models with more than 32 experts crash immediately:

 RuntimeError: num_groups must be <= 32

There used to be a use_cuda_kernel_for_blocked_layout flag that allowed falling back to the Triton kernel, but it was removed. In v0.16, KernelPreference.AUTO provides no per-op fallback (it's all-CUDA or all-emulated), so there's no way to selectively use the Triton kernel for this one op.
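To illustrate the failure mode, here is a minimal Python sketch of the guard (names simplified; the real check lives in C++ at mxfp8_extension.cpp:200). Since the offsets tensor has one entry per expert, num_groups tracks num_experts directly:

```python
# Sketch of the hard check in mx_block_rearrange_2d_M_groups_cuda
# (simplified; the real guard is in C++). One offset entry per expert,
# so num_groups == num_experts.
CUDA_MAX_GROUPS = 32

def check_num_groups(offsets):
    num_groups = len(offsets)
    if num_groups > CUDA_MAX_GROUPS:
        raise RuntimeError("num_groups must be <= 32")
    return num_groups
```

A 64-expert model produces 64 offset entries and trips this check before any work is done.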

Suggested fix

It would be great to auto-fall back to triton_mx_block_rearrange_2d_M_groups (which handles arbitrary group counts) when num_groups > 32, instead of hard-failing.
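The proposed behavior could look roughly like the wrapper below. This is a hypothetical sketch, not torchao's actual API: the kernel arguments are stand-ins, and the real dispatch would live inside the op itself.

```python
# Hypothetical fallback dispatch (not torchao's actual API): use the CUDA
# kernel when the group count fits its limit, otherwise fall back to the
# Triton kernel, which handles arbitrary group counts.
CUDA_MAX_GROUPS = 32

def rearrange_with_fallback(tensor, offsets, cuda_kernel, triton_kernel):
    num_groups = len(offsets)  # one offset entry per expert
    if num_groups <= CUDA_MAX_GROUPS:
        return cuda_kernel(tensor, offsets)
    return triton_kernel(tensor, offsets)
```

This keeps the fast CUDA path for the common <= 32 expert case while removing the hard failure for larger models.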

Environment

torchao 0.16.0+git3c1065c
PyTorch 2.10.0a0+a36e1d39eb.nv26.1
SM100 (GB200)
