Add numerical-correctness test for Muon under ZeRO-1/2#8091
Open
whycoming wants to merge 1 commit into
Open
Conversation
…epspeedai#7807 The existing Muon tests only assert that parameters changed, which cannot detect a wrong-but-nonzero update -- exactly the failure mode of deepspeedai#7807, where reduce_scatter feeds Newton-Schulz orthogonalization a partition slice instead of the full DP-averaged gradient. This test runs the supported reduce_scatter=False path on 2 ranks for ZeRO stage 1 and 2 and both ns_method values (gram, standard), sized (via the actual flattened ZeRO partition, accounting for alignment padding) so a 2D weight straddles the rank boundary. It compares the applied Muon weight update against an independent reference that applies the real muon_update to the full averaged gradient. With gradient_clipping=0 and loss_scale=1 the applied update is exactly -lr * muon_update(grad); a partition-then-orthogonalize bug diverges by O(1) on the cross-partition weight (measured ~0.6-0.67), far above the fp16 Newton-Schulz rounding of a correct update (measured up to ~0.07 for gram, ~0.22 for standard), so the assertion uses a 0.40 relative-error threshold. Requires >=2 GPUs (fp16). Signed-off-by: whycoming <alwaysxd666@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a numerical-correctness regression test for the Muon optimizer under ZeRO-1/2. The existing Muon tests only assert that parameters changed, which cannot detect a wrong-but-nonzero update — exactly the failure mode of #7807, where
reduce_scatterfed Newton-Schulz orthogonalization a partition slice instead of the full DP-averaged gradient. This complements the guard in #8090 by verifying the supportedreduce_scatter: falsepath is actually numerically correct.What the test does
TestMuonZero12NumericalCorrectness(intests/unit/ops/muon/test_muon.py),world_size=2, parametrized over ZeRO stage[1, 2]andns_method ['gram', 'standard']:optimizer.bit16_groups/bit16_groups_flat, accounting for alignment padding) — the exact case [BUG] Cross-partition parameters incorrectly updated when using ZeRO-1/ZeRO-2 with reduce_scatter=true and Muon optimizer #7807 corrupts.reduce_scatter: falsepath withgradient_clipping=0andloss_scale=1, so the applied master-weight update is exactly-lr * muon_update(grad).muon_updateto the full DP-averaged gradient (using the library function, so Newton-Schulz rounding cancels), via relative Frobenius error.A correct update differs from the reference by only a few percent; the partition-then-orthogonalize bug diverges by O(1) on the cross-partition weight, so the assertion uses a 0.40 threshold that cleanly separates the two.
Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2)
Relative error of the applied Muon update vs the full-gradient reference:
The cross-partition weight is the only one affected; wholly-owned weights are identical on both paths. With
reduce_scatter: falsethe test passes for both stages and both ns_methods; injecting the bug (reduce_scatter: true, pre-guard) makes the cross-partition assertion fail by a wide margin — i.e. this test would have caught #7807.Notes
Follow-up to #8090 (which adds the guard and closes #7807). Kept in the existing
test_muon.py. Requires >=2 GPUs (fp16).Refs #7807
cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, #7919) @tohtana