Skip to content

Add numerical-correctness test for Muon under ZeRO-1/2#8091

Open
whycoming wants to merge 1 commit into
deepspeedai:masterfrom
whycoming:test/muon-zero12-numerical-correctness
Open

Add numerical-correctness test for Muon under ZeRO-1/2#8091
whycoming wants to merge 1 commit into
deepspeedai:masterfrom
whycoming:test/muon-zero12-numerical-correctness

Conversation

@whycoming

Copy link
Copy Markdown

Summary

Adds a numerical-correctness regression test for the Muon optimizer under ZeRO-1/2. The existing Muon tests only assert that parameters changed, which cannot detect a wrong-but-nonzero update — exactly the failure mode of #7807, where reduce_scatter fed Newton-Schulz orthogonalization a partition slice instead of the full DP-averaged gradient. This complements the guard in #8090 by verifying the supported reduce_scatter: false path is actually numerically correct.

What the test does

TestMuonZero12NumericalCorrectness (in tests/unit/ops/muon/test_muon.py), world_size=2, parametrized over ZeRO stage [1, 2] and ns_method ['gram', 'standard']:

  1. Builds a model sized so a 2-D weight's flattened gradient straddles the rank-0/rank-1 partition boundary, and asserts this from the actual flattened ZeRO partition (optimizer.bit16_groups / bit16_groups_flat, accounting for alignment padding) — the exact case [BUG] Cross-partition parameters incorrectly updated when using ZeRO-1/ZeRO-2 with reduce_scatter=true and Muon optimizer #7807 corrupts.
  2. Runs one step on the supported reduce_scatter: false path with gradient_clipping=0 and loss_scale=1, so the applied master-weight update is exactly -lr * muon_update(grad).
  3. Compares that applied update against an independent reference that applies the real muon_update to the full DP-averaged gradient (using the library function, so Newton-Schulz rounding cancels), via relative Frobenius error.

A correct update differs from the reference by only a few percent; the partition-then-orthogonalize bug diverges by O(1) on the cross-partition weight, so the assertion uses a 0.40 threshold that cleanly separates the two.

Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2)

Relative error of the applied Muon update vs the full-gradient reference:

ns_method correct path (max over weights) buggy path (cross-partition weight)
gram (default) 0.068 0.603
standard 0.216 0.673

The cross-partition weight is the only one affected; wholly-owned weights are identical on both paths. With reduce_scatter: false the test passes for both stages and both ns_methods; injecting the bug (reduce_scatter: true, pre-guard) makes the cross-partition assertion fail by a wide margin — i.e. this test would have caught #7807.

Notes

Follow-up to #8090 (which adds the guard and closes #7807). Kept in the existing test_muon.py. Requires >=2 GPUs (fp16).

Refs #7807

cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, #7919) @tohtana

…epspeedai#7807

The existing Muon tests only assert that parameters changed, which cannot detect
a wrong-but-nonzero update -- exactly the failure mode of deepspeedai#7807, where
reduce_scatter feeds Newton-Schulz orthogonalization a partition slice instead of
the full DP-averaged gradient.

This test runs the supported reduce_scatter=False path on 2 ranks for ZeRO stage
1 and 2 and both ns_method values (gram, standard), sized (via the actual
flattened ZeRO partition, accounting for alignment padding) so a 2D weight
straddles the rank boundary. It compares the applied Muon weight update against an
independent reference that applies the real muon_update to the full averaged
gradient. With gradient_clipping=0 and loss_scale=1 the applied update is exactly
-lr * muon_update(grad); a partition-then-orthogonalize bug diverges by O(1) on
the cross-partition weight (measured ~0.6-0.67), far above the fp16 Newton-Schulz
rounding of a correct update (measured up to ~0.07 for gram, ~0.22 for standard),
so the assertion uses a 0.40 relative-error threshold.

Requires >=2 GPUs (fp16).

Signed-off-by: whycoming <alwaysxd666@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] Cross-partition parameters incorrectly updated when using ZeRO-1/ZeRO-2 with reduce_scatter=true and Muon optimizer

1 participant