[rmsnorm] Add backward pass + forward store_rstd for training (PR 1/3, #769)#795
Open
jhinpan wants to merge 2 commits into
Open
[rmsnorm] Add backward pass + forward store_rstd for training (PR 1/3, #769)#795jhinpan wants to merge 2 commits into
jhinpan wants to merge 2 commits into
Conversation
Makes RMSNorm training-capable (PR 1 of 3 for ROCm#769): - build_rmsnorm_module gains store_rstd=False: when enabled, writes per-row rstd (1/RMS, fp32, shape (M,)) needed by backward. Default off keeps the existing launcher signature and all callers byte-for-byte unchanged; covers fast, generic, and small-N (N<=2048) paths. - build_rmsnorm_bwd_module: fused single kernel, one block per row. Pass 1 computes c1 = mean_N(x_hat*wdy); pass 2 stores dx = (wdy - x_hat*c1)*rstd and atomicAdds dw = dy*x_hat into an fp32 DWeight[N]. All reduction and weight-grad accumulation in fp32; only dx cast back to I/O dtype. - fp32 atomicAdd chosen for the cross-row dweight reduction after benchmarking it against a two-pass scratch+finalizer variant on MI355X: atomic never blows up and wins the large-N (real LLM hidden size) regime. - RMSNormFunction (torch.autograd.Function) + public rmsnorm(x, weight, eps), quack-aligned, in the kernel layer. Math matches quack rmsnorm_bwd_ref. - Tests: gradient checks vs torch.autograd across f32/f16/bf16, both N paths (incl. unaligned and N<=2048), plus an end-to-end rmsnorm() autograd test covering batched (>2D) reshape and grads on x + weight. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Fixes from the PR ROCm#795 code review: - eps is now baked into the forward kernel (build_rmsnorm_module / small-N builder gain an `eps` param) instead of being silently ignored in favor of the module EPS=1e-5. rmsnorm(x, w, eps=1e-6) now produces correct numerics; eps is part of the fwd compile-cache key. - Multi-GPU correctness: compiled-fn cache keys now include x.device, and compile+launch run under `with torch.cuda.device(x.device)` so a kernel built on cuda:0 is never launched on cuda:1 (previously faulted with hipErrorInvalidDevice + memory access fault). - Deduplicate the LLVM-ptr helper: hoist get_llvm_ptr into kernels_common.py (was a byte-for-byte copy of hgemm_splitk's helper, which CLAUDE.md says belongs in kernels_common) and use it. - Public rmsnorm() now asserts x.shape[-1] == weight length; rmsnorm_fwd/ rmsnorm_bwd assert contiguous inputs (make_buffer_tensor assumes row-major contiguous). - Document the backward kernel's deferred perf follow-ups (vectorized fast path + pass-1/pass-2 caching) instead of leaving them implicit. - Tests: add test_rmsnorm_eps_honored (eps must change output / match a torch ref at 1e-5/1e-6/1e-2) and test_rmsnorm_multi_gpu (marked multi_gpu; runs rmsnorm fwd+bwd on cuda:0 and cuda:1). Deferred (perf/altitude, not correctness; noted in code): vectorized backward fast path, pass-2 load caching, and the store_rstd launcher duplication (kept to preserve the byte-for-byte store_rstd=False path). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Contributor
Author
|
Ran a self-review pass and pushed fixes in
Deferred as noted in code (perf/altitude, not correctness): vectorized backward fast path, pass-2 load caching, and the Proceeding to PR 2 (fused-add/residual backward) and PR 3 (layernorm backward). |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
PR 1 of 3 for #769 — makes RMSNorm training-capable. This PR adds the core rmsnorm backward pass, the forward
rstdoutput it needs, and a quack-aligned autograd wrapper.Follow-ups (tracked in #769): PR 2 = fused-add/residual backward; PR 3 = layernorm backward (stretch).
What's included
rstd:build_rmsnorm_module(N, dtype, store_rstd=False). When enabled, writes per-rowrstd(1/RMS, fp32,(M,)) needed by backward. Default off ⇒ existing launcher signature and all callers byte-for-byte unchanged. Covers fast, generic, and small-N (N≤2048) paths.build_rmsnorm_bwd_module— fused single kernel, one block per row. Pass 1 computesc1 = mean_N(x_hat·wdy); pass 2 storesdx = (wdy − x_hat·c1)·rstdand atomicAddsdw = dy·x_hatinto an fp32DWeight[N]. All reduction + weight-grad accumulation in fp32; onlydxcast back to I/O dtype.RMSNormFunction+ publicrmsnorm(x, weight, eps), quack-aligned, in the kernel layer (per review discussion). Math matches quackrmsnorm_bwd_ref.Design note: dweight reduction = fp32 atomicAdd
The cross-row
dweight = Σ_rows(dy·x_hat)reduction was chosen by experiment, not assumption. We implemented both fp32atomicAddand a two-pass scratch+finalizer variant and benchmarked on MI355X (full data in #769):Shipping atomic as the default (simplest — no scratch buffer, no 2nd kernel). Scratch can be added later as an opt-in for determinism if needed.
Testing (MI355X, gfx950)
test_rmsnorm_backward: dx + dweight + kernel-rstdvstorch.autograd.grad, across f32/f16/bf16, both N paths (incl. unaligned64×2000andN≤2048).test_rmsnorm_autograd: end-to-end publicrmsnorm()path — grads on x + weight, and batched (>2D) reshape.Related
Part of #769. RFC #749.
🤖 Generated with Claude Code