Skip to content

[rmsnorm] Add backward pass + forward store_rstd for training (PR 1/3, #769)#795

Open
jhinpan wants to merge 2 commits into
ROCm:mainfrom
jhinpan:feat/rmsnorm-backward-769
Open

[rmsnorm] Add backward pass + forward store_rstd for training (PR 1/3, #769)#795
jhinpan wants to merge 2 commits into
ROCm:mainfrom
jhinpan:feat/rmsnorm-backward-769

Conversation

@jhinpan

@jhinpan jhinpan commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

Summary

PR 1 of 3 for #769 — makes RMSNorm training-capable. This PR adds the core rmsnorm backward pass, the forward rstd output it needs, and a quack-aligned autograd wrapper.

Follow-ups (tracked in #769): PR 2 = fused-add/residual backward; PR 3 = layernorm backward (stretch).

What's included

  • Forward saves rstd: build_rmsnorm_module(N, dtype, store_rstd=False). When enabled, writes per-row rstd (1/RMS, fp32, (M,)) needed by backward. Default off ⇒ existing launcher signature and all callers byte-for-byte unchanged. Covers fast, generic, and small-N (N≤2048) paths.
  • Backward kernel: build_rmsnorm_bwd_module — fused single kernel, one block per row. Pass 1 computes c1 = mean_N(x_hat·wdy); pass 2 stores dx = (wdy − x_hat·c1)·rstd and atomicAdds dw = dy·x_hat into an fp32 DWeight[N]. All reduction + weight-grad accumulation in fp32; only dx cast back to I/O dtype.
  • Autograd wrapper: RMSNormFunction + public rmsnorm(x, weight, eps), quack-aligned, in the kernel layer (per review discussion). Math matches quack rmsnorm_bwd_ref.

Design note: dweight reduction = fp32 atomicAdd

The cross-row dweight = Σ_rows(dy·x_hat) reduction was chosen by experiment, not assumption. We implemented both fp32 atomicAdd and a two-pass scratch+finalizer variant and benchmarked on MI355X (full data in #769):

  • atomic never blows up (caps ~3.3ms even at M=131072) and wins the large-N regime (4096–8192, i.e. real LLM hidden sizes);
  • scratch wins mid-M/small-N but collapses at large M·N with a fixed grid.

Shipping atomic as the default (simplest — no scratch buffer, no 2nd kernel). Scratch can be added later as an opt-in for determinism if needed.

Testing (MI355X, gfx950)

  • Forward regression: unchanged, all pass.
  • test_rmsnorm_backward: dx + dweight + kernel-rstd vs torch.autograd.grad, across f32/f16/bf16, both N paths (incl. unaligned 64×2000 and N≤2048).
  • test_rmsnorm_autograd: end-to-end public rmsnorm() path — grads on x + weight, and batched (>2D) reshape.
  • Full non-large suite: 8/8 passed. Verified dweight has no accumulation leak (stable across repeated calls; only atomic-order fp32 noise ~1e-5 f32).

Related

Part of #769. RFC #749.

🤖 Generated with Claude Code

Makes RMSNorm training-capable (PR 1 of 3 for ROCm#769):

- build_rmsnorm_module gains store_rstd=False: when enabled, writes per-row
  rstd (1/RMS, fp32, shape (M,)) needed by backward. Default off keeps the
  existing launcher signature and all callers byte-for-byte unchanged; covers
  fast, generic, and small-N (N<=2048) paths.
- build_rmsnorm_bwd_module: fused single kernel, one block per row. Pass 1
  computes c1 = mean_N(x_hat*wdy); pass 2 stores dx = (wdy - x_hat*c1)*rstd
  and atomicAdds dw = dy*x_hat into an fp32 DWeight[N]. All reduction and
  weight-grad accumulation in fp32; only dx cast back to I/O dtype.
- fp32 atomicAdd chosen for the cross-row dweight reduction after
  benchmarking it against a two-pass scratch+finalizer variant on MI355X:
  atomic never blows up and wins the large-N (real LLM hidden size) regime.
- RMSNormFunction (torch.autograd.Function) + public rmsnorm(x, weight, eps),
  quack-aligned, in the kernel layer. Math matches quack rmsnorm_bwd_ref.
- Tests: gradient checks vs torch.autograd across f32/f16/bf16, both N paths
  (incl. unaligned and N<=2048), plus an end-to-end rmsnorm() autograd test
  covering batched (>2D) reshape and grads on x + weight.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings July 3, 2026 04:50

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot was unable to review this pull request because the user who requested the review has reached their quota limit.

Fixes from the PR ROCm#795 code review:

- eps is now baked into the forward kernel (build_rmsnorm_module /
  small-N builder gain an `eps` param) instead of being silently
  ignored in favor of the module EPS=1e-5. rmsnorm(x, w, eps=1e-6) now
  produces correct numerics; eps is part of the fwd compile-cache key.
- Multi-GPU correctness: compiled-fn cache keys now include x.device,
  and compile+launch run under `with torch.cuda.device(x.device)` so a
  kernel built on cuda:0 is never launched on cuda:1 (previously faulted
  with hipErrorInvalidDevice + memory access fault).
- Deduplicate the LLVM-ptr helper: hoist get_llvm_ptr into
  kernels_common.py (was a byte-for-byte copy of hgemm_splitk's helper,
  which CLAUDE.md says belongs in kernels_common) and use it.
- Public rmsnorm() now asserts x.shape[-1] == weight length; rmsnorm_fwd/
  rmsnorm_bwd assert contiguous inputs (make_buffer_tensor assumes
  row-major contiguous).
- Document the backward kernel's deferred perf follow-ups (vectorized
  fast path + pass-1/pass-2 caching) instead of leaving them implicit.
- Tests: add test_rmsnorm_eps_honored (eps must change output / match a
  torch ref at 1e-5/1e-6/1e-2) and test_rmsnorm_multi_gpu (marked
  multi_gpu; runs rmsnorm fwd+bwd on cuda:0 and cuda:1).

Deferred (perf/altitude, not correctness; noted in code): vectorized
backward fast path, pass-2 load caching, and the store_rstd launcher
duplication (kept to preserve the byte-for-byte store_rstd=False path).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@jhinpan

jhinpan commented Jul 3, 2026

Copy link
Copy Markdown
Contributor Author

Ran a self-review pass and pushed fixes in 0c5f94c:

  • eps was silently ignored (kernel hardcoded 1e-5) → now baked into the forward kernel and part of the compile-cache key. rmsnorm(x, w, eps=1e-6) now matches a torch ref (err ~2e-7). (verified on MI355X)
  • Multi-GPU fault: compiled-fn cache keys omitted device, so a cuda:0 kernel reused on cuda:1 crashed with hipErrorInvalidDevice + memory access fault → cache key now includes x.device and compile/launch run under torch.cuda.device(x.device). (verified cuda:0 → cuda:1)
  • Dedup: hoisted get_llvm_ptr into kernels_common.py (was a byte-for-byte copy of hgemm_splitk's; CLAUDE.md says it belongs there).
  • Guards: public rmsnorm() now asserts x.shape[-1] == weight length; rmsnorm_fwd/bwd assert contiguous inputs.
  • Tests: added test_rmsnorm_eps_honored + test_rmsnorm_multi_gpu (marked multi_gpu). Full non-large suite: 10/10 pass.

Deferred as noted in code (perf/altitude, not correctness): vectorized backward fast path, pass-2 load caching, and the store_rstd launcher duplication (kept to preserve the byte-for-byte store_rstd=False path). These are follow-ups.

Proceeding to PR 2 (fused-add/residual backward) and PR 3 (layernorm backward).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants