CUDA -> Helion exploration for token group padding/unpadding kernels #4062

@danielvegamyhre

Description

Context

When we do a grouped GEMM for routed experts in an MoE, the shape is (total_M, K) @ (K, N), where total_M is the sum of all token group sizes. Token groups are the tokens routed to each expert, and they are variable in size.

In the forward pass, the sizes of these token groups don't matter. However, in the backward pass, the gradient of the weight is dW = dO^T @ inputs, with shapes (N, total_M) @ (total_M, K). When we apply MXFP8 quantization here, the 1x32 scaling for the LHS operand runs along the total_M dim, and the 32x1 scaling for the RHS operand runs along the total_M dim as well.
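To make the boundary problem concrete, here is a small plain-Python sketch (with hypothetical group sizes, not taken from the issue) showing that when group sizes are not multiples of 32, a 1x32 scaling block along total_M straddles two groups:

```python
# Hypothetical example: variable token group sizes along total_M.
group_sizes = [40, 17, 75]
total_M = sum(group_sizes)  # 132

# Row offsets where each group starts along the total_M dim.
offsets = [sum(group_sizes[:i]) for i in range(len(group_sizes))]  # [0, 40, 57]

# A 1x32 scaling block covers rows [b*32, (b+1)*32). If any group boundary
# is not 32-aligned, some block spans two logically independent GEMMs.
boundaries = offsets[1:]
straddles = any(b % 32 != 0 for b in boundaries)
print(straddles)  # True: e.g. the block covering rows 32-63 spans groups 0 and 1
```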

We can't let scaling factors cross group boundaries: these are logically independent GEMMs, so sharing scaling factors across them would not be a proper implementation of the accepted definition of MXFP8 quantization, and from an implementation perspective it would be incredibly awkward.

Therefore, we need to pad the token group sizes to the next multiple of 32.
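The padded group sizes can be computed with simple round-up arithmetic; a minimal sketch (helper name is illustrative, not from the codebase):

```python
def pad_to_multiple(sizes, align=32):
    # Round each token group size up to the next multiple of `align`
    # so MXFP8 scaling blocks never cross a group boundary.
    return [((s + align - 1) // align) * align for s in sizes]

print(pad_to_multiple([40, 17, 75]))  # [64, 32, 96]
```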

We have decent CUDA kernels for this (#3998), but we would like to move toward JIT-compiled Python DSLs to avoid C++ ABI compatibility issues with torch, as well as the complexity of shipping binaries.

Torch reference implementation
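A hedged eager-mode sketch of what such a reference could look like (function names and signatures are illustrative, not the actual reference from this issue): each group is copied into a zero-initialized buffer at a 32-aligned offset, and unpadding reverses the copy.

```python
import torch

def pad_token_groups(x: torch.Tensor, group_sizes: torch.Tensor, align: int = 32):
    # Pad each token group along dim 0 to the next multiple of `align`.
    padded_sizes = ((group_sizes + align - 1) // align) * align
    out = x.new_zeros(int(padded_sizes.sum()), x.shape[1])
    src = dst = 0
    for size, padded in zip(group_sizes.tolist(), padded_sizes.tolist()):
        out[dst:dst + size] = x[src:src + size]  # rows beyond `size` stay zero
        src += size
        dst += padded
    return out, padded_sizes

def unpad_token_groups(y: torch.Tensor, group_sizes: torch.Tensor,
                       padded_sizes: torch.Tensor) -> torch.Tensor:
    # Drop the padding rows, recovering the original (total_M, K) layout.
    chunks, dst = [], 0
    for size, padded in zip(group_sizes.tolist(), padded_sizes.tolist()):
        chunks.append(y[dst:dst + size])
        dst += padded
    return torch.cat(chunks)

# Round-trip check with hypothetical sizes: 40 -> 64, 17 -> 32.
x = torch.randn(40 + 17, 8)
group_sizes = torch.tensor([40, 17])
y, padded = pad_token_groups(x, group_sizes)
assert y.shape == (96, 8)
assert torch.equal(unpad_token_groups(y, group_sizes, padded), x)
```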
