
Warn on mixed local/distributed variable shardings in JAX trainer #22256

Open
amitsrivastava78 wants to merge 2 commits into keras-team:master from amitsrivastava78:fix/warn-mixed-sharding-devices

Conversation

@amitsrivastava78
Collaborator

Add _check_sharding_consistency() helper to _get_state_sharding_spec() that detects when a model has a mix of SingleDeviceSharding (local) and mesh-aware variables. This happens when the model is built outside distribution.scope() and then compiled/trained inside it, leading to 'ValueError: Received incompatible devices for jitted computation'.
Also add parameterized tests covering DataParallel and ModelParallel for both the warning and no-warning cases.
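The mismatch the helper looks for can be sketched in a few lines of self-contained Python. The class names below are mock stand-ins for `jax.sharding.SingleDeviceSharding` and `jax.sharding.NamedSharding`, and `check_sharding_consistency` is a simplified illustration of the idea, not the actual Keras implementation:

```python
import warnings

# Mock stand-ins for jax.sharding.SingleDeviceSharding / NamedSharding.
class SingleDeviceSharding:
    def __init__(self, device):
        self.device = device

class NamedSharding:
    def __init__(self, mesh):
        self.mesh = mesh

def check_sharding_consistency(named_shardings):
    """Warn if variables mix local and mesh-aware shardings.

    `named_shardings` is a list of (variable_name, sharding) pairs.
    Returns True when consistent, False (plus a warning) when mixed.
    """
    has_local = any(isinstance(s, SingleDeviceSharding) for _, s in named_shardings)
    has_mesh = any(isinstance(s, NamedSharding) for _, s in named_shardings)
    if has_local and has_mesh:
        warnings.warn(
            "Mixed local/distributed variable shardings detected. "
            "Build the model inside distribution.scope() or call "
            "keras.distribution.set_distribution() before building it."
        )
        return False
    return True

# A model built outside the scope: one variable stayed on a single device.
shardings = [
    ("kernel", NamedSharding(mesh="2x2")),
    ("metrics/total", SingleDeviceSharding(device=0)),
]
```

Calling `check_sharding_consistency(shardings)` on the list above warns and returns False; a list containing only `NamedSharding` entries returns True silently.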

@gemini-code-assist
Contributor

Summary of Changes

Hello @amitsrivastava78, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the JAX trainer by proactively identifying and warning users about inconsistent variable sharding configurations. By detecting mixed local and distributed shardings early, it helps prevent runtime errors that occur when models are built outside a distribution scope but compiled within one, thereby improving the robustness and user experience of distributed JAX training.

Highlights

  • Sharding Consistency Check: Introduced a new helper function, _check_sharding_consistency, within the JAX trainer to detect when a model has a mix of SingleDeviceSharding (local) and mesh-aware variables. This check is crucial for preventing ValueError during JAX compilation.
  • Warning Mechanism: The new consistency check emits an actionable warning when a mixed sharding configuration is detected, guiding users to build their models inside distribution.scope() to resolve the issue.
  • Comprehensive Testing: Added a new test file with parameterized tests covering both DataParallel and ModelParallel distribution strategies. These tests verify that the warning is correctly issued in inconsistent scenarios and suppressed when shardings are consistent.


Changelog
  • keras/src/backend/jax/trainer.py
    • Introduced _check_sharding_consistency to identify inconsistent sharding configurations.
    • Integrated the sharding consistency check into _get_state_sharding_spec.
  • keras/src/backend/jax/trainer_sharding_fix_test.py
    • Added trainer_sharding_fix_test.py to test the new sharding consistency warning.
    • Included parameterized tests for DataParallel and ModelParallel strategies.
    • Verified warning emission when models are built outside the distribution scope.
    • Confirmed no warning when models are built inside the distribution scope.
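The test shape described in the changelog might look roughly like the self-contained sketch below. It uses `unittest` subtests in place of a parameterized decorator, and `mock_check` is a hypothetical stand-in for the real trainer helper, with sharding kinds reduced to plain strings:

```python
import unittest
import warnings

def mock_check(shardings):
    # Stand-in for the trainer's consistency check: warn when the
    # (name, kind) pairs mix "local" and "mesh" sharding kinds.
    kinds = {kind for _, kind in shardings}
    if {"local", "mesh"} <= kinds:
        warnings.warn("Mixed local/distributed shardings detected.")

class ShardingWarningTest(unittest.TestCase):
    def test_warns_on_mixed_shardings(self):
        # One subtest per strategy name, mirroring the parameterized cases.
        for strategy in ("DataParallel", "ModelParallel"):
            with self.subTest(strategy=strategy):
                with self.assertWarns(UserWarning):
                    mock_check([("kernel", "mesh"), ("total", "local")])

    def test_no_warning_when_consistent(self):
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # any warning would fail the test
            mock_check([("kernel", "mesh"), ("bias", "mesh")])
```

In the real test file the strategies would be actual `keras.distribution` objects and the model would be built inside or outside the scope; this sketch only mirrors the warning/no-warning structure.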
Activity
  • The author, amitsrivastava78, created this pull request to address potential ValueError issues in JAX distributed training by adding a sharding consistency check and corresponding tests.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This is a great pull request that improves the developer experience for users working with JAX distribution strategies. The new _check_sharding_consistency helper provides an early and actionable warning for a common misconfiguration, which will save users significant debugging time. The warning message itself is a model of clarity, explaining the problem, its cause, and the solution with a code example. The accompanying tests are thorough, covering both DataParallel and ModelParallel strategies for both the warning and no-warning cases.

I have one minor suggestion to make the implementation slightly more memory-efficient and to clarify the intent of the code.

Add _check_sharding_consistency() helper to _get_state_sharding_spec()
that detects when a model has a mix of SingleDeviceSharding (local) and
mesh-aware variables. This happens when the model is built outside
distribution.scope() and then compiled/trained inside it, leading to
'ValueError: Received incompatible devices for jitted computation'.

The helper short-circuits on the first mismatch and emits an actionable
warning with the offending variable name and instructions to fix.

Also add parameterized tests covering DataParallel and ModelParallel
for both the warning and no-warning cases.
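The short-circuit behaviour described in the commit message can be sketched as follows. This is a simplified stand-in with sharding kinds as plain strings; `find_first_mismatch` is a hypothetical name, not the trainer's actual helper:

```python
import warnings

def find_first_mismatch(shardings, reference_kind="mesh"):
    """Return the first variable whose sharding kind differs from the
    reference, short-circuiting instead of scanning every variable.

    `shardings` is a list of (variable_name, kind) pairs.
    """
    for name, kind in shardings:
        if kind != reference_kind:
            warnings.warn(
                f"Variable '{name}' has a local ({kind}) sharding while "
                "other variables are mesh-aware. Build the model inside "
                "distribution.scope() to fix this."
            )
            return name  # stop at the first offender
    return None

mixed = [("kernel", "mesh"), ("metrics/total", "local"), ("bias", "local")]
```

Running `find_first_mismatch(mixed)` warns once, naming `metrics/total`, and never inspects `bias`; a fully mesh-aware list returns None without warning.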
amitsrivastava78 force-pushed the fix/warn-mixed-sharding-devices branch from b19842e to 1b5e615 on February 22, 2026 at 15:07
@codecov-commenter

codecov-commenter commented Feb 22, 2026

Codecov Report

❌ Patch coverage is 88.23529% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 83.02%. Comparing base (0238793) to head (6b45a16).
⚠️ Report is 55 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/trainer.py 88.23% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22256      +/-   ##
==========================================
+ Coverage   82.77%   83.02%   +0.25%     
==========================================
  Files         593      594       +1     
  Lines       63867    65046    +1179     
  Branches    10040    10180     +140     
==========================================
+ Hits        52866    54006    +1140     
+ Misses       8424     8418       -6     
- Partials     2577     2622      +45     
Flag Coverage Δ
keras 82.85% <88.23%> (+0.25%) ⬆️
keras-jax 61.45% <88.23%> (-0.80%) ⬇️
keras-numpy 55.61% <0.00%> (-0.77%) ⬇️
keras-openvino 38.71% <0.00%> (+1.17%) ⬆️
keras-tensorflow 62.68% <0.00%> (-0.80%) ⬇️
keras-torch 61.53% <0.00%> (-0.78%) ⬇️

Flags with carried forward coverage won't be shown.

- Update warning to recommend set_distribution() as primary fix,
  with distribution.scope() as alternative
- Rename trainer_sharding_fix_test.py -> trainer_test.py to match
  implementation file naming convention
- Convert tests to class-based (JAXTrainerTest) with self.assertXXX
- Make tests self-contained: inline model creation, adaptive device count
- Remove dead XLA_FLAGS code that ran too late to create fake devices
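The dead-code point about XLA_FLAGS reflects a general JAX constraint: the flag is read once, when `jax` is first imported, so setting it inside a test body runs too late to create fake devices. A minimal sketch (`--xla_force_host_platform_device_count` is the standard flag for simulating multiple host devices):

```python
import os

# Must run before `import jax` anywhere in the process: JAX reads
# XLA_FLAGS once, at import time, so setting it later has no effect.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

# import jax  # only now would jax.devices() report 8 fake CPU devices
```

In a test suite this typically means setting the variable at module import time (or in the shell that launches pytest), never inside setUp or a test method.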
amitsrivastava78 added a commit to amitsrivastava78/keras that referenced this pull request Feb 24, 2026
Create model and call model.fit() inside dist.scope() so every variable
(trainable weights, optimizer slots, lazily-built metrics vars) is
initialized with NamedSharding tied to the full device mesh.

Previously we used set_distribution(dist) + model._symbolic_build() as
a workaround. The problem: metrics variables (total/count) are still
created lazily during the first JIT execution. If that happens outside
the distribution scope they get SingleDeviceSharding(device=0), which
causes _get_state_sharding_spec() to mix NamedSharding and
SingleDeviceSharding, setting JAX's JIT context mesh to {device 0} and
resulting in:

  ValueError: Received incompatible devices for jitted computation.
  Got argument with device ids [0,1,2,3] and jit context mesh [0].

Using dist.scope() as a context manager for both model creation and
model.fit() is the canonical Keras distribution pattern (documented in
keras.distribution and demonstrated in PR keras-team#22256). It guarantees all
variables are created inside the scope.

Also:
- Pass a non-None object so this helper handles both cases uniformly.
- Remove the TestMixin (tearDown with jax.clear_caches() is sufficient;
  scope() restores the previous distribution itself).
- Remove unused set_distribution import side-effect from test bodies
  (still imported at module level for _dist_scope helper).
5 participants