Skip to content

fix(models): revisions triton gt kernel to improve stability#1027

Draft
ssmmnn11 wants to merge 2 commits intomainfrom
fix/triton_gt_d
Draft

fix(models): revisions triton gt kernel to improve stability#1027
ssmmnn11 wants to merge 2 commits intomainfrom
fix/triton_gt_d

Conversation

@ssmmnn11
Copy link
Copy Markdown
Member

@ssmmnn11 ssmmnn11 commented Apr 6, 2026

Improve numerical stability of the Triton GT kernel. However, this results in a ~ 10% memory increase and a similar % in terms of performance degradation (pure kernel - less in a real training run). Tests.

blue: current triton kernel bf16
red: current triton kernel bf16, second try
green: resume blue curve with pyg compiled bf16
orange: pyg compiled bf16
violet: pyg compiled bf16, cast to float32
brown: revised triton kernel bf16

updated

test:

  • make the attention logits very large -> softmax becomes very sharp
  • every neighbor sends the same message -> output should no longer depend on the attention weights
  • the gradients dq and dk should be zero or very close to zero
  • check whether the backward pass still creates nonzero gradients

|dq| / |dk|

  • float32
    • GT before change: 0.000078 / 0.000703
    • GT now: 0.000091 / 0.000454
  • bfloat16
    • GT before change: 1.468 / 18.760
    • GT now: 0.321 / 3.615

@ssmmnn11 ssmmnn11 added models ATS Approval Not Needed No approval needed by ATS labels Apr 6, 2026
@github-project-automation github-project-automation bot moved this to To be triaged in Anemoi-dev Apr 6, 2026
@ssmmnn11 ssmmnn11 marked this pull request as draft April 6, 2026 18:54
@ssmmnn11 ssmmnn11 self-assigned this Apr 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ATS Approval Not Needed No approval needed by ATS models

Projects

Status: To be triaged

Development

Successfully merging this pull request may close these issues.

1 participant