Context
When we do a grouped GEMM for routed experts in a MoE, the shape is (total_M, K) @ (K, N), where total_M is the sum of all token group sizes. Token groups are the tokens routed to each expert, and they are variable in size.
In the forward pass, the sizes of these token groups don't matter. However, in the backward pass, the gradient of the weight is dW = dO^T @ inputs, with shapes (N, total_M) @ (total_M, K). When we apply MXFP8 quantization to these shapes, the 1x32 scaling for the LHS operand runs along the total_M dim, and the 32x1 scaling for the RHS runs along the total_M dim as well.
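To make the grouping concrete, here is a toy NumPy sketch (shapes and variable names are illustrative, not from the actual kernels) showing that dW in the backward pass is computed per group, by slicing both operands along the total_M dim:

```python
import numpy as np

# Toy shapes: two token groups of sizes m0=3, m1=5 (total_M=8), K=4, N=2.
# For a grouped GEMM, dW is computed independently per group:
#   dW_g = dO_g^T @ x_g, shapes (N, m_g) @ (m_g, K) -> (N, K).
rng = np.random.default_rng(0)
m_sizes, K, N = [3, 5], 4, 2
total_M = sum(m_sizes)
x = rng.standard_normal((total_M, K))    # inputs, (total_M, K)
dO = rng.standard_normal((total_M, N))   # output grads, (total_M, N)

# Row offsets delimiting each group along total_M.
offsets = np.cumsum([0] + m_sizes)
dWs = [dO[s:e].T @ x[s:e] for s, e in zip(offsets[:-1], offsets[1:])]
for g, dW in enumerate(dWs):
    print(g, dW.shape)                   # each dW_g is (N, K)
```

Because each dW_g only sees its own slice of rows, any scaling block along total_M that straddled an offset boundary would mix values from two logically independent GEMMs.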
We can't have scaling factors cross group boundaries: these are logically independent GEMMs, so sharing scaling factors across tensors would not be a proper implementation of the accepted definition of MXFP8 quantization, and from an implementation perspective it would be incredibly awkward.
Therefore, we need to pad the token group sizes to the next multiple of 32.
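The padding itself is simple round-up arithmetic; a minimal sketch (helper name is hypothetical):

```python
# Hypothetical helper: round each token group size up to the next multiple
# of 32 so that 1x32 / 32x1 MXFP8 scaling blocks never straddle a group
# boundary along the total_M dim.
def pad_group_sizes(group_sizes, multiple=32):
    return [(s + multiple - 1) // multiple * multiple for s in group_sizes]

sizes = [5, 32, 45]           # variable per-expert token counts
print(pad_group_sizes(sizes)) # [32, 32, 64]
```

The padded rows would be filled with zeros so they contribute nothing to the per-group dW.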
We have decent CUDA kernels for this (#3998), but we would like to move toward JIT-compiled Python DSLs to avoid C++ ABI compatibility issues with torch and the complexity of shipping binaries.