[Feat] Add CUTLASS matmul-epilogue fusion path for sm_120 #30
wtr0504 wants to merge 7 commits into SandAI-org:main from
Conversation
```python
from ..pass_base import MagiInductorPass


class RemoveUselessOpsPass(MagiInductorPass):
```
RemoveUselessOpsPass -> EliminateIdentityViewCastPass
Note: also change the file name
Move this file to the utils directory and update __init__.py
```dockerfile
ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf"

# CUTLASS — source is always cloned (the magi_compiler EVT-fusion path
```
Treat CUTLASS as third-party and provide an install command, since users may install magi_compiler without Docker.
Please also update the commands in README.md~
```dockerfile
# we lower fused chains to a CUTLASS Sm80EVT kernel. Toggled via
# PassConfig.enable_mm_epilogue_fusion (default True). The device
# check is independent — even with the flag on, non-sm_120 hosts
# don't register the pass since its FX walker would just no-op.
```
It seems better to support epilogue fusion on other arches too, even if it achieves no performance gain there, since this fusion can also save memory.
```python
# Propagate but with 128B-aligned stride matching what the
# CUDA impl actually returns.
new_val = val_last.new_empty_strided(
    val_last.shape, (evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype), 1)
)
```
_aligned_n_stride not defined
```python
        f.write(src)
    os.replace(tmp_path, src_path)


cutlass_root = _cutlass_root()
```
Find a better way to find the cutlass_root dir~
```python
        if target in _TYPE_CONV_OPS:
            node_to_ir[curr] = node_to_ir[curr.args[0]]
        walk_seen.append(curr)
        last_node = curr
        last_ir = node_to_ir[curr]
        curr = curr.next
```
If we have a compute chain like `matmul -> to_1(fp32) -> func1 -> to_2(fp8) -> func2`, what happens to type-conversion ops like `to_1` and `to_2`?
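A pure-Python toy (not the real FX walker) of the aliasing the diff performs: if cast nodes simply reuse their input's IR value, both `to_1` and `to_2` vanish from the traced epilogue, so the fp32 intermediate would be lost unless the EVT builder re-inserts a convert node:

```python
_TYPE_CONV_OPS = {"to"}


def trace_chain(chain):
    # chain: list of (name, op, input_name); returns the IR expression
    # string produced for the last node.
    node_to_ir = {}
    for name, op, arg in chain:
        if op == "matmul":
            node_to_ir[name] = "mm"
        elif op in _TYPE_CONV_OPS:
            # Cast nodes alias their input: the conversion disappears.
            node_to_ir[name] = node_to_ir[arg]
        else:
            node_to_ir[name] = f"{op}({node_to_ir[arg]})"
    return node_to_ir[chain[-1][0]]


chain = [
    ("mm0", "matmul", None),
    ("to_1", "to", "mm0"),   # fp32 upcast
    ("f1", "func1", "to_1"),
    ("to_2", "to", "f1"),    # fp8 downcast
    ("f2", "func2", "to_2"),
]
# trace_chain(chain) yields "func2(func1(mm))" — both casts have collapsed out.
```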
🗂️ PR Category
📝 Description
Implements a CUTLASS-based matmul + elementwise fusion path on the RTX 5090 (Blackwell consumer, sm_120). On the GAGA2 swiglu7 site (M=7697, N=27304, K=5120) it gives +11% over torch.compile and +15% over eager bf16 for cp1 prefill.
Pipeline
Caching layers
Wired into `post_grad_pass_manager.configure()` under `torch.cuda.get_device_capability()[0] >= 12`.
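That gating, written as a pure predicate (function name hypothetical; the real code reads the capability from `torch.cuda.get_device_capability()` and checks the config flag described earlier):

```python
def should_register_epilogue_fusion(enable_flag: bool, capability: tuple) -> bool:
    # The pass is registered only when PassConfig.enable_mm_epilogue_fusion
    # is on AND the device major compute capability is >= 12 (sm_120).
    return enable_flag and capability[0] >= 12
```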
Tests
tests/feature_tests/test_matmul_epilogue_fusion.py rewritten — 25 tests covering:
Each positive test instruments `MatmulEvtEpilogueFusionPass.__call__` to assert the expected `kind` and `out_dtype_id` were actually emitted — numerical equivalence alone would silently pass when fusion is skipped (both paths fall back to cuBLAS).
Limitations
`kErrorMisalignedOperand`.
Performance
Notes
MAGI_CUTLASS_ROOT
EVT / swiglu7 JIT builds resolve CUTLASS headers via `torch.utils.cpp_extension.load(..., extra_include_paths=...)`. Set `MAGI_CUTLASS_ROOT` to the top-level CUTLASS checkout (the directory that contains `include/` and e.g. `tools/util/include/`). If unset, the code defaults to `/opt/cutlass`.
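A sketch of how those include paths might be assembled (function name hypothetical; the layout follows the note above):

```python
import os


def cutlass_include_paths() -> list:
    # MAGI_CUTLASS_ROOT should point at the top-level CUTLASS checkout;
    # the default matches the Docker image layout.
    root = os.environ.get("MAGI_CUTLASS_ROOT", "/opt/cutlass")
    return [
        os.path.join(root, "include"),
        os.path.join(root, "tools", "util", "include"),
    ]
```

The returned list would then be passed as `extra_include_paths` to `torch.utils.cpp_extension.load(...)`.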