[DRAFT]: Suggestions for Triton kernels #979
Conversation
```python
    This function assumes a matrix layout of shape [H,C] for mask_H_C and H_C_off
    """

# default mask (assume no padded values)
```
The default mask would be None; AFAIK, True will be interpreted as a tile of size [1] and expanded accordingly. Load/store operations are still faster this way, as not every thread predicates these operations on its own, though the difference should be marginal. If the inline utility is preferred, I can also undo the change. For actual None masks this helper doesn't work, since Triton can't infer the type of a None value for this in-lined function.
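To make the semantics concrete, here is a plain-Python stand-in for Triton's masked load (not the actual kernel helper): out-of-bounds lanes get a padding value instead of being read, and a scalar True mask broadcasts to every lane, making the load unconditional.

```python
# Plain-Python sketch of masked-load semantics, analogous to
# tl.load(ptr, mask=..., other=...). Names are illustrative.
def masked_load(values, mask, other=0.0):
    # Lanes where mask is False receive the padding value `other`
    # instead of the (possibly out-of-bounds) memory value.
    return [v if m else other for v, m in zip(values, mask)]

tile = [1.0, 2.0, 3.0, 4.0]
bounds_mask = [True, True, True, False]    # last lane is out of bounds
loaded = masked_load(tile, bounds_mask)    # padding replaces the last lane
full_mask = [True] * len(tile)             # scalar True, expanded per lane
loaded_all = masked_load(tile, full_mask)  # effectively an unmasked load
```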
```diff
 @triton.jit
-def _gt_fwd(
+def _graphtransformer_fwd_kernel(
```
Names for Triton kernels are propagated to profiling tools, so I'd suggest something more informative than just `_gt_fwd` to keep profiles readable.
```python
C_mask = (C_off < C)
H_C_mask = H_mask & C_mask
H_PAD_VAL = 0.0
H_C_PAD_VAL = 0.0
```
Explicit padding values likely don't hurt. IIRC, there is no guarantee across Triton versions that padding values are 0, and especially for reductions this can introduce sources of incorrectness.
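A small plain-Python illustration of why the padding value matters for reductions: padding masked lanes with 0.0 is harmless for a sum, but silently wrong for a max-reduction whenever all valid values are negative, which is exactly the situation in a softmax running-max. This is a stand-in for `tl.load(..., other=...)`, not kernel code.

```python
import math

def reduce_max(values, mask, pad):
    # Masked lanes contribute `pad` to the reduction, mimicking a
    # masked tl.load followed by tl.max.
    return max(v if m else pad for v, m in zip(values, mask))

vals = [-3.0, -1.0, -2.0]
mask = [True, True, False]                 # last lane is padding
wrong = reduce_max(vals, mask, 0.0)        # the 0.0 pad wins: incorrect
right = reduce_max(vals, mask, -math.inf)  # -inf pad: true maximum -1.0
```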
```python
qk_scale: tl.constexpr = 1.0 / tl.sqrt(float(C))
dst_idx = tl.program_id(0).to(tl.int64)
head_idx = tl.program_id(1)
# option: batch_idx = tl.program_id(2)
```
Given custom kernels, it could be worth thinking about support for batching (i.e. multiple batches of data tensors for the same graph structure).
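One way this could look: a third launch-grid axis indexes the batch while the graph structure is shared. The names below (`num_dst_nodes`, `num_heads`, `batch_size`) are illustrative, not taken from the kernel.

```python
def launch_grid(num_dst_nodes, num_heads, batch_size):
    # One program per (destination node, head, batch) triple. Inside the
    # kernel, batch_idx = tl.program_id(2) would offset into the data
    # tensors, while the CSC structure is reused across all batches.
    return (num_dst_nodes, num_heads, batch_size)

grid = launch_grid(1024, 8, 4)  # 1024 * 8 * 4 programs in total
```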
```python
e = tl.load(edge_ptr, mask=H_C_mask).to(tl.float32).reshape((H_pad, C_pad))
for eidx in range(neigh_start, neigh_end):
    src_idx = tl.load(CSC_indices_ptr + eidx)
    coo_idx = tl.load(CSC_map_to_coo_ptr + eidx)
```
This adds an indirect look-up in the main loop to avoid explicit permutations of edge_attr for COO structures, without introducing too much overhead for existing CSC structures.
```python
alpha_ij = tl.exp(qk - m_ij)  # attention weight for current edge
correction = tl.exp(m_i - m_ij)  # correction factor for previous accumulations
qk = qk - m_ij
if USE_FAST_EXP2:
```
EXP2 trick following upstream iterations of the FlashAttention kernels. Especially on hardware where the gap between memory bandwidth and special-function-unit throughput keeps widening, this could bring minor gains.
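The identity behind the trick: exp(x) = 2^(x · log2(e)), so every `tl.exp` in the online-softmax update can be replaced by the faster `tl.exp2` by folding log2(e) into the QK scale. A plain-Python check of the identity:

```python
import math

LOG2_E = 1.4426950408889634  # log2(e); folded into qk_scale in the kernel

def exp_via_exp2(x):
    # exp(x) rewritten as a base-2 exponential, the form that maps to
    # the GPU's fast exp2 special-function unit.
    return math.pow(2.0, x * LOG2_E)

x = -1.5
diff = abs(exp_via_exp2(x) - math.exp(x))  # identical up to rounding
```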
Description
What problem does this change solve?
What issue or task does this change relate to?
TBD
Additional notes
(`torch.library.custom_op` instead of `torch.autograd.Function`)

As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/
By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.