Skip to content

feat: Add ascend backend#168

Open
weilinquan wants to merge 1 commit into
ByteDance-Seed:mainfrom
weilinquan:main
Open

feat: Add ascend backend#168
weilinquan wants to merge 1 commit into
ByteDance-Seed:mainfrom
weilinquan:main

Conversation

@weilinquan
Copy link
Copy Markdown

Summary

  1. Implement ascend backend include python primitives and conversion passes
  2. Adapt setup.py and cmake project for triton-ascend
  3. Add tutorials cases for fusion kernels, allgather+gemm, gemm+reduce_scatter, gemm+allreduce
  4. Add unittest for ascend backend in python/triton_dist/test/ascend
  5. Add triton-ascend and shmem as submodule
  6. Add common gemm-swizzle and comm-swizzle algorithm

@CLAassistant
Copy link
Copy Markdown

CLAassistant commented Apr 24, 2026

CLA assistant check
All committers have signed the CLA.

Comment on lines +269 to +273
def synchronized_print(message, rank, world_size):
"""Print message with rank prefix, synchronized across all ranks"""
dist.barrier() # Ensure all ranks reach here before any printing
print(f"[Rank {rank:02d}/{world_size:02d}] {message}", flush=True)
dist.barrier() # Ensure all ranks finish printing before continuing
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dead code, sorry forgot to remove this 😅

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

Comment on lines +146 to +147
if (ASCEND::pipeMap.contains(llvm::StringRef(customName))) {
customOp.setPipe(ASCEND::pipeMap.at(llvm::StringRef(customName)));
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we are doing lookup twice. You can use the find method and reuse the value to setPipe.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

Comment on lines +89 to +99
if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
if (intTy.getWidth() == 8) {
return "int8";
} else if (intTy.getWidth() == 16) {
return "int16";
} else if (intTy.getWidth() == 32) {
return "int32";
} else if (intTy.getWidth() == 64) {
return "int64";
}
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can getWidth() once and switch on the width

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

Comment on lines +112 to +125
if (std::is_same_v<DistOp, distributed::SymmAtOp> ||
symbolName == "aclshmem_ptr") {
auto elemTy = llvm::cast<triton::PointerType>(op->getOperand(0).getType())
.getPointeeType();
symbolName = "aclshmem_" + getTypeName(elemTy) + "_ptr";
} else if constexpr (std::is_same_v<DistOp, distributed::GetRankOp>) {
symbolName = "aclshmem_my_pe";
} else if constexpr (std::is_same_v<DistOp, distributed::GetNumRanksOp>) {
symbolName = "aclshmem_n_pes";
} else if constexpr (std::is_same_v<DistOp, distributed::NotifyOp>) {
auto notifyOp = llvm::cast<distributed::NotifyOp>(op);
auto typeName = getTypeName(notifyOp.getSignalVal().getType());
symbolName = "aclshmem_" + typeName + "_p";
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend using TypeSwitch here

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants