The flash expert is an accelerated primitive for Mixture-of-Experts (MoE) expert evaluation when routing has already been computed. Given per-token expert indices and routing weights, it gathers the corresponding expert parameters and computes a fused “expert MLP-like” mapping:
- compute a per-token, per-expert score as the dot product of the hidden state with the expert's down weight,
- apply a SiLU nonlinearity to the score and multiply by the routing weight,
- project back to hidden space via the expert's up weight and sum over the selected experts.
This is implemented as Triton kernels with an autograd-aware wrapper.
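The steps above can be sketched in plain Python. This is a reference sketch of the semantics only, not the Triton implementation; the function name and loop structure are illustrative:

```python
import math

def silu(x):
    # SiLU: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def flash_expert_reference(hidden, down, up, indices, routing):
    # hidden: (num_tokens, hidden_size); down/up: (num_experts, hidden_size)
    # indices/routing: (num_tokens, num_experts_per_tok), as nested lists
    num_tokens, hidden_size = len(hidden), len(hidden[0])
    out = [[0.0] * hidden_size for _ in range(num_tokens)]
    for t in range(num_tokens):
        for k, e in enumerate(indices[t]):
            # per-token, per-expert score: dot product with the down weight
            score = sum(h * d for h, d in zip(hidden[t], down[e]))
            # SiLU nonlinearity scaled by the routing weight
            gate = silu(score) * routing[t][k]
            # project back to hidden space via the up weight and accumulate
            for j in range(hidden_size):
                out[t][j] += gate * up[e][j]
    return out
```

The Triton kernels fuse the gather, score, nonlinearity, and projection into a single pass instead of materializing intermediates as this loop does.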
The primary user-facing API is the autograd-aware wrapper:
```python
from flash_moe.ops.flash_expert import triton_flash_expert_func

expert_states = triton_flash_expert_func(
    hidden_states,
    down_weights,
    up_weights,
    indices,
    routing_weights,
)
```

Arguments:

- `hidden_states` (`torch.Tensor`):
  - shape: `(num_tokens, hidden_size)`
  - dtype: typically `torch.float16`, `torch.bfloat16`, or `torch.float32`
  - device: CUDA tensor (Triton kernels run on GPU)
- `down_weights` (`torch.Tensor`):
  - shape: `(num_experts, hidden_size)`
  - dtype: typically `torch.float16`, `torch.bfloat16`, or `torch.float32`
  - device: CUDA tensor (Triton kernels run on GPU)
- `up_weights` (`torch.Tensor`):
  - shape: `(num_experts, hidden_size)`
  - dtype: typically `torch.float16`, `torch.bfloat16`, or `torch.float32`
  - device: CUDA tensor (Triton kernels run on GPU)
- `indices` (`torch.LongTensor`):
  - shape: `(num_tokens, num_experts_per_tok)`
  - note: each entry is an expert id in `[0, num_experts)`
  - device: CUDA tensor (Triton kernels run on GPU)
- `routing_weights` (`torch.Tensor`):
  - shape: `(num_tokens, num_experts_per_tok)`
  - dtype: typically `torch.float16`, `torch.bfloat16`, or `torch.float32`
  - device: CUDA tensor (Triton kernels run on GPU)

Returns:

- `expert_states` (`torch.Tensor`):
  - shape: `(num_tokens, hidden_size)`
  - dtype: matches `hidden_states.dtype`
  - device: same as the inputs
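The shape contract between these arguments can be made concrete with a small pre-call check. This is a hypothetical helper, not part of the library; it operates on plain shape tuples:

```python
def validate_flash_expert_shapes(hidden_states, down_weights, up_weights,
                                 indices, routing_weights):
    # Each argument is a shape tuple, matching the parameter docs above.
    num_tokens, hidden_size = hidden_states
    num_experts, down_hidden = down_weights
    assert down_hidden == hidden_size, \
        "down_weights must share hidden_size with hidden_states"
    assert up_weights == (num_experts, hidden_size), \
        "up_weights must match down_weights' shape"
    assert indices == routing_weights, \
        "indices and routing_weights must both be (num_tokens, num_experts_per_tok)"
    assert indices[0] == num_tokens, "indices must have one row per token"
    # The output expert_states has the same shape as hidden_states.
    return (num_tokens, hidden_size)
```

For example, `num_tokens=4`, `hidden_size=8`, `num_experts=16`, and `num_experts_per_tok=2` give input shapes `(4, 8)`, `(16, 8)`, `(16, 8)`, `(4, 2)`, `(4, 2)` and output shape `(4, 8)`.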
Expert tests and benchmarks live in `tests/test_expert.py`. They include a PyTorch reference implementation and Triton-based benchmarks of forward and backward throughput.
To run the expert benchmarks on a CUDA-enabled machine:
```shell
pytest tests/test_expert.py -s
```

You can run individual tests with, for example:

```shell
pytest tests/test_expert.py::test_expert_forward_throughput -s
pytest tests/test_expert.py::test_expert_backward_throughput -s
```

Make sure that:
- PyTorch is installed with CUDA support,
- Triton is installed and compatible with your CUDA/PyTorch version,
- the GPU has sufficient memory for the chosen `(num_tokens, hidden_size, num_experts, top_k)` settings.
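The first two items on the checklist can be automated with a small pre-flight check. This is a hypothetical helper, not part of the test suite:

```python
import importlib.util

def benchmark_environment_ready():
    """Return True if PyTorch with a CUDA device and Triton appear available."""
    if importlib.util.find_spec("torch") is None:
        return False  # PyTorch not installed
    import torch
    if not torch.cuda.is_available():
        return False  # PyTorch installed without a usable CUDA device
    if importlib.util.find_spec("triton") is None:
        return False  # Triton not installed
    return True
```

Note that this only checks availability, not version compatibility between Triton and your CUDA/PyTorch build, and it does not estimate GPU memory for the chosen benchmark settings.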