Warn on mixed local/distributed variable shardings in JAX trainer #22256
amitsrivastava78 wants to merge 2 commits into keras-team:master
Conversation
Summary of Changes
Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the JAX trainer by proactively identifying and warning users about inconsistent variable sharding configurations. By detecting mixed local and distributed shardings early, it helps prevent runtime errors that occur when models are built outside a distribution scope but compiled within one, thereby improving the robustness and user experience of distributed JAX training.
Code Review
This is a great pull request that improves the developer experience for users working with JAX distribution strategies. The new _check_sharding_consistency helper provides an early and actionable warning for a common misconfiguration, which will save users significant debugging time. The warning message itself is a model of clarity, explaining the problem, its cause, and the solution with a code example. The accompanying tests are thorough, covering both DataParallel and ModelParallel strategies for both the warning and no-warning cases.
I have one minor suggestion to make the implementation slightly more memory-efficient and to clarify the intent of the code.
Add _check_sharding_consistency() helper to _get_state_sharding_spec() that detects when a model has a mix of SingleDeviceSharding (local) and mesh-aware variables. This happens when the model is built outside distribution.scope() and then compiled/trained inside it, leading to 'ValueError: Received incompatible devices for jitted computation'. The helper short-circuits on the first mismatch and emits an actionable warning with the offending variable name and instructions to fix. Also add parameterized tests covering DataParallel and ModelParallel for both the warning and no-warning cases.
Force-pushed b19842e to 1b5e615.
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## master #22256 +/- ##
==========================================
+ Coverage 82.77% 83.02% +0.25%
==========================================
Files 593 594 +1
Lines 63867 65046 +1179
Branches 10040 10180 +140
==========================================
+ Hits 52866 54006 +1140
+ Misses 8424 8418 -6
- Partials 2577 2622 +45
Flags with carried forward coverage won't be shown.
- Update warning to recommend set_distribution() as the primary fix, with distribution.scope() as an alternative
- Rename trainer_sharding_fix_test.py -> trainer_test.py to match the implementation file naming convention
- Convert tests to class-based (JAXTrainerTest) with self.assertXXX
- Make tests self-contained: inline model creation, adaptive device count
- Remove dead XLA_FLAGS code that ran too late to create fake devices
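The class-based tests mentioned above might take a shape like the following sketch. It is self-contained and illustrative only: `_check` is a trivial stand-in for the real consistency check, and the test names are assumptions, not the actual test suite.

```python
import unittest
import warnings

class JAXTrainerTest(unittest.TestCase):
    """Illustrative shape of class-based tests with self.assertXXX:
    one case expecting a warning on mixed shardings, one expecting none."""

    @staticmethod
    def _check(sharding_kinds):
        # Stand-in for the real consistency check: warn if the variable
        # shardings are not all of the same kind.
        if len(set(sharding_kinds)) > 1:
            warnings.warn("mixed local/distributed variable shardings")

    def test_mixed_sharding_warns(self):
        with self.assertWarns(UserWarning):
            self._check(["SingleDeviceSharding", "NamedSharding"])

    def test_consistent_sharding_no_warning(self):
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # any warning fails the test
            self._check(["NamedSharding", "NamedSharding"])
```

The real tests are parameterized over DataParallel and ModelParallel; the two methods here correspond to the warning and no-warning cases.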
Create model and call model.fit() inside dist.scope() so every variable
(trainable weights, optimizer slots, lazily-built metrics vars) is
initialized with NamedSharding tied to the full device mesh.
Previously we used set_distribution(dist) + model._symbolic_build() as
a workaround. The problem: metrics variables (total/count) are still
created lazily during the first JIT execution. If that happens outside
the distribution scope they get SingleDeviceSharding(device=0), which
causes _get_state_sharding_spec() to mix NamedSharding and
SingleDeviceSharding, setting JAX's JIT context mesh to {device 0} and
resulting in:
ValueError: Received incompatible devices for jitted computation.
Got argument with device ids [0,1,2,3] and jit context mesh [0].
Using dist.scope() as a context manager for both model creation and
model.fit() is the canonical Keras distribution pattern (documented in
keras.distribution and demonstrated in PR keras-team#22256). It guarantees all
variables are created inside the scope.
Also:
- Add a non-None object so this helper handles both cases uniformly.
- Remove TestMixin (tearDown with jax.clear_caches() is sufficient; scope() restores the previous distribution itself).
- Remove unused set_distribution import side-effect from test bodies (still imported at module level for _dist_scope helper).