
[DRAFT]: suggestions for triton kernels #979

Draft
stadlmax wants to merge 3 commits into ecmwf:main from stadlmax:fea-triton-improvements

Conversation

@stadlmax

Description

  • draft to discuss a few suggestions

What problem does this change solve?

  • accelerate conversions between COO and CSR/CSC structures
  • avoid calls to sort operations, when deterministic behavior is not required, via a special Triton kernel
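
The COO-to-CSR direction can be done without a general sort at all: a counting pass plus a prefix sum (effectively a counting sort over rows) yields the row pointer and a COO-to-CSR permutation directly. A minimal NumPy sketch of the idea; all names (`coo_to_csr`, `map_to_coo`) are illustrative and not taken from the PR:

```python
import numpy as np

def coo_to_csr(rows, cols, num_rows):
    """Convert COO edge indices to CSR without a full sort.

    A counting pass plus a prefix sum is enough when only the row
    pointer and a COO->CSR permutation are needed. Illustrative
    sketch; names are not taken from the PR.
    """
    counts = np.bincount(rows, minlength=num_rows)
    indptr = np.zeros(num_rows + 1, dtype=np.int64)
    np.cumsum(counts, out=indptr[1:])
    # scatter each edge into its row bucket, remembering its COO position
    next_slot = indptr[:-1].copy()
    indices = np.empty_like(cols)
    map_to_coo = np.empty(len(rows), dtype=np.int64)
    for e, (r, c) in enumerate(zip(rows, cols)):
        slot = next_slot[r]
        indices[slot] = c
        map_to_coo[slot] = e
        next_slot[r] += 1
    return indptr, indices, map_to_coo
```

The scatter loop is sequential here for clarity; in a kernel it would be parallelized with atomics, which is exactly where the determinism trade-off mentioned above comes in (the within-row order of neighbors then depends on scheduling).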

What issue or task does this change relate to?

TBD

Additional notes

  • makes custom kernels and custom operations compatible with torch.compile and AOT compilation workflows (torch.library.custom_op instead of torch.autograd.Function)
  • minor suggestions here and there

As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/

By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.

This function assumes a matrix layout of shape [H,C] for mask_H_C and H_C_off
"""

# default mask (assume no padded values)
Author

The default mask would be None; AFAIK, True will be interpreted as a tile of size [1] and expanded accordingly. Load/store operations are still faster, as not every thread predicates these operations on its own. The difference should be marginal, though. If the inline utility is preferred, I can also undo the change. For actual None masks, this helper doesn't work, as Triton can't infer the type of a None value for this in-lined function.


@triton.jit
def _gt_fwd(
def _graphtransformer_fwd_kernel(
Author

Names for Triton kernels are propagated into profiles, so I would suggest something more descriptive than just _gt_fwd to keep profiles informative.

C_mask = (C_off < C)
H_C_mask = H_mask & C_mask
H_PAD_VAL = 0.0
H_C_PAD_VAL = 0.0
Author

Explicit padding values likely don't hurt. IIRC, there is no guarantee across Triton versions that padding values are 0, and especially for reductions this can introduce correctness issues.
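
A plain-Python analogue of why the explicit pad value matters for reductions: if the masked-out lanes of a loaded tile hold undefined values (here simulated with NaN) instead of an explicit 0, any reduction over the full tile is poisoned. In Triton this corresponds to `tl.load(ptr, mask=m, other=0.0)` versus leaving `other` unspecified; the array contents below are made up for illustration:

```python
import numpy as np

# One tile row feeding a sum reduction; the last lane is "out of bounds".
tile = np.array([1.0, 2.0, 3.0, np.nan])
mask = np.array([True, True, True, False])

bad = tile.sum()                        # undefined lane poisons the reduction
good = np.where(mask, tile, 0.0).sum()  # explicit pad value: correct sum of 6.0
```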

qk_scale: tl.constexpr = 1.0 / tl.sqrt(float(C))
dst_idx = tl.program_id(0).to(tl.int64)
head_idx = tl.program_id(1)
# option: batch_idx = tl.program_id(2)
Author

Given custom kernels, it could be worth thinking about support for batching (i.e. multiple batches of data tensors for the same graph structure).
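
Batching would mostly be pointer arithmetic: with a 3D grid (destination nodes, heads, batches), each program adds a batch stride to its base offset so the same graph structure indexes into every batch of node features. A hypothetical sketch of the offset computation (the names `tile_offset`, `H`, `C`, `num_dst` are assumptions, not from the PR):

```python
def tile_offset(dst_idx, head_idx, batch_idx, H, C, num_dst):
    """Flattened element offset for a [batch, num_dst, H, C] layout.

    In a Triton kernel, batch_idx would come from tl.program_id(2)
    and this offset would be added to the base pointer.
    """
    batch_stride = num_dst * H * C  # elements per batch of node features
    return batch_idx * batch_stride + dst_idx * H * C + head_idx * C
```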

e = tl.load(edge_ptr, mask=H_C_mask).to(tl.float32).reshape((H_pad, C_pad))
for eidx in range(neigh_start, neigh_end):
src_idx = tl.load(CSC_indices_ptr + eidx)
coo_idx = tl.load(CSC_map_to_coo_ptr + eidx)
Author

Indirect look-up in the main loop avoids explicit permutations of edge_attr for COO structures, without introducing too much overhead for existing CSC structures.
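
The pattern in plain Python: rather than materializing a permuted copy of `edge_attr` in CSC order up front, each neighbor iteration follows the CSC-to-COO map and loads the attribute at its original COO position. The function name and list-based return are illustrative only:

```python
def gather_neighbor_attrs(dst, indptr, indices, map_to_coo, edge_attr):
    """Analogue of the in-loop indirect look-up in the kernel:
    edge_attr stays in COO order; map_to_coo redirects each CSC
    slot to the right COO entry."""
    out = []
    for eidx in range(indptr[dst], indptr[dst + 1]):
        src_idx = indices[eidx]        # neighbor node
        coo_idx = map_to_coo[eidx]     # position of this edge in COO order
        out.append((src_idx, edge_attr[coo_idx]))
    return out
```

For a true CSC input, `map_to_coo` is just the identity, so the extra load is cheap, which matches the "without too much overhead" claim above.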

alpha_ij = tl.exp(qk - m_ij) # attention weight for current edge
correction = tl.exp(m_i - m_ij) # correction factor for previous accumulations
qk = qk - m_ij
if USE_FAST_EXP2:
Author

EXP2 trick following upstream iterations of FlashAttention kernels. Especially on hardware where the gap between memory bandwidth and special-function-unit throughput keeps widening, this could bring minor gains.
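
The underlying identity: exp(x) == 2^(x · log2 e), and the constant log2(e) can be folded into the existing qk scale so the hot loop only evaluates a base-2 exponential (tl.math.exp2 in Triton), which maps to a cheaper SFU instruction than a generic exp. A minimal sketch of the math; the function name is illustrative:

```python
import math

LOG2E = math.log2(math.e)  # constant folded into qk_scale in the kernel

def softmax_weight(qk, m_ij):
    """exp(qk - m_ij) computed via the exp2 rewrite."""
    return 2.0 ** ((qk - m_ij) * LOG2E)
```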

@japols japols self-requested a review March 16, 2026 17:34

Projects

Status: To be triaged


1 participant