
[Feat] Add CUTLASS matmul-epilogue fusion path for sm_120 #30

Open
wtr0504 wants to merge 7 commits into SandAI-org:main from wtr0504:feat/matmul_epilogue

Conversation


@wtr0504 wtr0504 commented Apr 28, 2026

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

Implements a CUTLASS-based matmul + elementwise fusion path on RTX 5090 (Blackwell consumer, sm_120). On the GAGA2 swiglu7 site (M=7697, N=27304, K=5120) it gives +11% over torch.compile and +15% over eager bf16 for cp1 prefill.

Pipeline


magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/

    evt_ir.py
      Small dataclass IR for elementwise epilogue chains. 6 primitive
      nodes (Accum / RowBroadcast / ColBroadcast / AuxLoad / Compute /
      Store) + 14 unary, 6 binary, 9 scalar-baked op names. Canonical
      JSON gives a deterministic cache key.

    evt_codegen.py
      Renders a complete .cu source from an IR tree. Each rendered module
      contains N per-tile EvtConfig instantiations and an EvtAutoTuneRunner
      that times every candidate on the first call per (N, K), then
      sticky-caches best_idx. Per-M-bucket candidate sets (small <= 256,
      medium <= 2048, large > 2048) hand-tuned for the 5090's 100 KB SMEM
      / 170 SMs budget.

    evt_runtime.py
      Defines torch.ops.magi_epilogue.matmul_custom_evt + fake impl. Hot
      path collapses to one _DISPATCH_CACHE.get → pre-bound C++ callable
      (~2 us Python overhead / call). JIT compile keyed by (IR, A_dt,
      B_dt, b_layout, m_bucket, N, K) → distinct .cu per weight shape;
      disk cache via cpp_extension build_directory.

    matmul_epilogue_fusion.py
      MatmulEvtEpilogueFusionPass. Walks aten.mm consumers, greedy-builds
      the EVT IR bottom-up. Special-cases the swiglu7 pattern (slice
      stride-2 + dual clamps + scaled SiLU) onto the vendored DualGemm
      kernel. Refuses fusion when intermediate fused nodes have external
      consumers (escape detector — caught the attention->residual->RMSNorm
      spurious fusion in profile). Up-front alignment guards (K % 8 == 0,
      generic N % 4 == 0, swiglu7 N % 8 == 0) ensure CUTLASS AlignmentC
      constraints are always satisfiable so the runtime never has to deal
      with mis-aligned D.

    cutlass_kernels/swiglu7_epi_one_stage.cu + swiglu7_combine.h
      Vendored CUTLASS DualGemm kernel: two GEMMs share A in smem;
      intermediate (M, N/2) accumulators stay in registers; custom
      Swiglu7Combine epilogue writes only D. 10-candidate autotune.
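The evt_ir.py idea — a tiny tree IR whose canonical JSON doubles as a deterministic cache key — can be sketched like this. All names here (`EvtNode`, `canonical_json`, `cache_key`) are illustrative stand-ins, not the PR's actual API:

```python
# Hypothetical sketch: a frozen dataclass IR plus a canonical-JSON cache key.
# sort_keys + compact separators make serialization byte-identical for equal
# trees, so the hash is stable across processes and reorderings.
import dataclasses
import hashlib
import json


@dataclasses.dataclass(frozen=True)
class EvtNode:
    kind: str            # e.g. "Accum", "RowBroadcast", "AuxLoad", "Compute", "Store"
    op: str = ""         # op name for Compute nodes, e.g. "silu"
    inputs: tuple = ()   # child EvtNode instances

    def to_dict(self) -> dict:
        return {
            "kind": self.kind,
            "op": self.op,
            "inputs": [c.to_dict() for c in self.inputs],
        }


def canonical_json(root: EvtNode) -> str:
    return json.dumps(root.to_dict(), sort_keys=True, separators=(",", ":"))


def cache_key(root: EvtNode) -> str:
    return hashlib.sha256(canonical_json(root).encode()).hexdigest()


# A bias-then-activation epilogue: Store(Compute(silu, Accum, RowBroadcast))
tree = EvtNode(
    "Store",
    inputs=(EvtNode("Compute", op="silu",
                    inputs=(EvtNode("Accum"), EvtNode("RowBroadcast"))),),
)
```

Two structurally equal trees always hash to the same key, which is what makes the disk-cache lookup safe.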

Caching layers


  • .cu module cache — (IR, dtypes, layout, bucket, N, K); disk + in-process. Each unique (N, K) gets its own module so autotune state is isolated.
  • C++ runner best_idx_ — single int per module; sticky after first call (no per-shape re-tune).
  • D output buffer — single-entry greedy; evict-and-replace on shape mismatch. Allocated as a true torch.empty((M, n_out), dtype) — no padded stride / scratch route, since the FX guards guarantee n_out % AlignmentC == 0.
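The single-entry D-buffer policy above reduces to a few lines; this is a minimal sketch with an illustrative class name and a caller-supplied allocator standing in for `torch.empty`:

```python
# Hypothetical sketch of the single-entry, evict-and-replace D-buffer cache:
# one buffer keyed by (shape, dtype); any mismatch reallocates and replaces.
class DBufferCache:
    def __init__(self):
        self._key = None
        self._buf = None

    def get(self, shape, dtype, alloc):
        key = (tuple(shape), dtype)
        if key != self._key:       # shape/dtype mismatch -> evict and replace
            self._buf = alloc(shape, dtype)
            self._key = key
        return self._buf           # hit: reuse the previous buffer as-is
```

The greedy policy is safe only under the single-stream assumption listed in the limitations: the previous call's output must be consumed before the next call lands.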

Wired into post_grad_pass_manager.configure() under torch.cuda.get_device_capability()[0] >= 12.
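The gating condition is just a major-version check on the capability tuple; a minimal sketch (the function name is illustrative, the tuple is the `(major, minor)` pair that `torch.cuda.get_device_capability()` returns):

```python
# Hypothetical sketch of the registration gate: only configure the EVT
# fusion pass on compute capability 12.x (sm_120) and newer hosts.
def should_register_evt_pass(capability) -> bool:
    major, _minor = capability   # e.g. (12, 0) on an RTX 5090
    return major >= 12
```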

Tests


tests/feature_tests/test_matmul_epilogue_fusion.py rewritten — 25 tests covering:

  • 5 activations (parametrized: silu / sigmoid / gelu / gelu7 / relu_square) + native relu + swiglu7 special-case path.
  • Binary ops on the mm output: scalar add / sub / mul / div + 1-D RowBroadcast bias + full (M, N) AuxLoad gate.
  • out_dtype matrix: bf16-native, bf16-via-high-precision (cast→fp32→ cast back), fp32 (no final cast), bf16→fp16 cross-precision, pure fp16 path. Each asserts both the emitted op's out_dtype_id (args[6]) and the runtime tensor's dtype.
  • 6 negative cases (intermediate-escape / bare mm / K mis-aligned / generic N mis-aligned / swiglu7 N not %8 / fp32 mm) — every one asserts fused_count == 0 and the chain falls back to torch.compile.
  • IR canonical-json determinism (cache-key stability).

Each positive test instruments MatmulEvtEpilogueFusionPass.call to assert the expected kind and out_dtype_id were actually emitted — numerical equivalence alone would silently pass when fusion is skipped (both paths fall back to cuBLAS).
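The instrumentation pattern — wrap the pass's call so the test can assert on what was actually emitted rather than on numerics alone — looks roughly like this. Everything here is an illustrative stand-in (a `FakePass` instead of `MatmulEvtEpilogueFusionPass`, tuples instead of FX nodes):

```python
# Hedged sketch of the test instrumentation: record the metadata of every
# emitted fused op so a test can assert kind / out_dtype_id explicitly.
emitted = []

class FakePass:
    def call(self, graph):
        # stand-in for the real pass: pretend it fused one silu epilogue
        graph.append(("matmul_custom_evt", {"kind": "silu", "out_dtype_id": 0}))
        return graph

def instrument(pass_obj):
    original = pass_obj.call
    def wrapped(graph):
        result = original(graph)
        emitted.extend(meta for op, meta in result if op == "matmul_custom_evt")
        return result
    pass_obj.call = wrapped      # instance attribute shadows the method
    return pass_obj
```

A negative test is the mirror image: run the pass and assert the recorded list stays empty (`fused_count == 0`).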

Limitations


  • sm_120 only. Older GPUs aren't wired up; the FX pass returns False unconditionally on cap < 12, so they continue to run unfused via cuBLAS.
  • Operand alignment constraints baked into the FX guards: K % 8 == 0 (CUTLASS AlignmentA = 8 for bf16/fp16); generic N % 4 == 0 (AlignmentC = 4); swiglu7 N % 8 == 0 (so n_out = N/2 is itself 4-aligned). Anything below these falls back to torch.compile / cuBLAS and is NOT retried with a slower CUTLASS path. AlignmentB = 8 has a small soft-spot in the evt_row layout (B-stride could be 4-aligned but not 8-aligned for N values like 12) — not exercised by any real workload but a future shape could hit kErrorMisalignedOperand.
  • Input dtypes restricted to {bf16, fp16}. fp32 mm is rejected up front. Mixed-precision (e.g. fp8 inputs) needs a CUTLASS 3.x EVT path that this PR does not include.
  • B operand layout: only contiguous (K, N) row-major or permute([1, 0])-of-contig-(N, K) (== column-major (K, N)). Any other stride pattern aborts fusion.
  • Single-stream inference only. The D / scratch / dispatch caches all assume the previous call's output has been consumed before the next call lands. Multi-stream / overlapped execution would race on the cached buffers — disable with MAGI_EVT_DISABLE_D_CACHE=1.
  • Forward-only. The pass refuses to fuse mm nodes that appear in a joint forward+backward graph (the mm output gets an extra "saved for backward" user that the escape detector correctly rejects).
  • Single best_idx per (N, K) module. The runner times candidates on whatever shape lands first inside a bucket; later shapes inside the same bucket reuse that pick. For very wide M-ranges within one bucket the best config can drift — bucket boundaries (256 / 2048) were chosen so the drift stays < 5 %, but a workload that mixes M = 257 and M = 2047 inside the medium bucket pays a slight sub-optimum for one of them.
  • Cold-start compile cost: a model with k distinct (IR, bucket, N, K) combinations needs k nvcc invocations on first run. Each EVT module takes ~30–60 s; the disk cache amortises this across reruns on the same machine.
  • swiglu7 expects B as the underlying (N, K) row-major weight (the FX pass accepts permute / transpose / t views and recovers the operand); transposed-then-cloned weights aren't handled and will fall back.
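The alignment and B-layout limitations above condense into a couple of pure predicates; this is a sketch with hypothetical helper names, using stride tuples the way `tensor.stride()` reports them:

```python
# Hedged sketch of the up-front FX guards described in the limitations:
# dtype whitelist, K/N alignment, and the two accepted B layouts.
def b_layout_ok(shape, stride):
    k, n = shape
    row_major = stride == (n, 1)   # contiguous (K, N)
    col_major = stride == (1, k)   # permute([1, 0]) of contiguous (N, K)
    return row_major or col_major

def fusion_allowed(k, n, dtype, swiglu7=False):
    if dtype not in ("bf16", "fp16"):
        return False               # fp32 mm rejected up front
    if k % 8 != 0:
        return False               # CUTLASS AlignmentA = 8 for bf16/fp16
    if swiglu7:
        return n % 8 == 0          # n_out = N/2 stays 4-aligned
    return n % 4 == 0              # AlignmentC = 4
```

On the GAGA2 swiglu7 site (K=5120, N=27304) both guards pass, which is why that shape takes the DualGemm path.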

Performance

[benchmark image]

Notes

MAGI_CUTLASS_ROOT

EVT / swiglu7 JIT builds resolve CUTLASS headers via torch.utils.cpp_extension.load(..., extra_include_paths=...). Set MAGI_CUTLASS_ROOT to the top-level CUTLASS checkout (the directory that contains include/ and tools/util/include/). If unset, the code defaults to /opt/cutlass.
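One plausible shape of the resolution logic (an assumption, not the PR's exact code): honor the env var, fall back to the documented default, and derive the two include paths handed to `torch.utils.cpp_extension.load`:

```python
# Hypothetical sketch of CUTLASS root resolution for the JIT builds.
import os

def cutlass_include_paths(env=os.environ):
    root = env.get("MAGI_CUTLASS_ROOT", "/opt/cutlass")
    return [
        os.path.join(root, "include"),               # core CUTLASS headers
        os.path.join(root, "tools", "util", "include"),  # cutlass/util/*
    ]
```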

from ..pass_base import MagiInductorPass


class RemoveUselessOpsPass(MagiInductorPass):
RemoveUselessOpsPass -> EliminateIdentityViewCastPass
Note: also change the file name


Move this file to the utils directory and update __init__.py

Comment thread: Dockerfile

ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf"

# CUTLASS — source is always cloned (the magi_compiler EVT-fusion path

Treat CUTLASS as a third-party dependency and provide an install command, since users may install magi_compiler without Docker.
Please also update the commands in readme.md~

Comment on lines +86 to +89
# we lower fused chains to a CUTLASS Sm80EVT kernel. Toggled via
# PassConfig.enable_mm_epilogue_fusion (default True). The device
# check is independent — even with the flag on, non-sm_120 hosts
# don't register the pass since its FX walker would just no-op.

It seems better to support epilogue fusion for other arches as well, even if it achieves no performance gain, because this fusion can also save memory.

# Propagate but with 128B-aligned stride matching what the
# CUDA impl actually returns.
new_val = val_last.new_empty_strided(
    val_last.shape, (evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype), 1)
)

_aligned_n_stride not defined

@jiahy0825 jiahy0825 linked an issue May 8, 2026 that may be closed by this pull request
f.write(src)
os.replace(tmp_path, src_path)

cutlass_root = _cutlass_root()

Find a better way to find the cutlass_root dir~

Comment on lines +271 to +276
if target in _TYPE_CONV_OPS:
    node_to_ir[curr] = node_to_ir[curr.args[0]]
walk_seen.append(curr)
last_node = curr
last_ir = node_to_ir[curr]
curr = curr.next

If we have a compute chain like this:
matmul -> to_1(fp32) -> func1 -> to_2(fp8) -> func2
What happens to type conversion ops like to_1 and to_2?

