Skip to content

Commit d46d5db

Browse files
[JAX] Handle meshes set with jax.set_mesh (#2532)
* Handle meshes set with jax.set_mesh Signed-off-by: Jeremy Berchtold <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6fd6209 commit d46d5db

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

transformer_engine/jax/sharding.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
W_JOINED_AXES = "nvte_w_joined"
3838

3939

40+
def _get_mesh():
    """
    Return the mesh that is currently active.

    A non-empty physical mesh installed via `with mesh:` takes precedence.
    Otherwise the result of `jax.sharding.get_abstract_mesh()` is returned,
    which covers meshes installed via `jax.set_mesh(mesh)`.
    """
    physical_mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if physical_mesh is None or physical_mesh.empty:
        # No usable `with mesh:` context, so fall back to the mesh
        # registered through `jax.set_mesh(mesh)`.
        return jax.sharding.get_abstract_mesh()
    return physical_mesh
47+
48+
4049
def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
    """
    Look up the size of the mesh axis named `resource`.

    Asserts that `resource` is one of `mesh.axis_names`, then returns the
    tuple `(mesh.shape[resource], resource)`.
    """
    axis_is_known = resource in mesh.axis_names
    assert axis_is_known, f"{resource} is not in the axis_names of Mesh {mesh}."
    axis_size = mesh.shape[resource]
    return axis_size, resource
@@ -63,15 +72,15 @@ def is_mesh_available() -> bool:
6372
"""
6473
Check if a physical mesh is available.
6574
"""
66-
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
75+
mesh = _get_mesh()
6776
return mesh is not None and not mesh.empty
6877

6978

7079
def get_sharding_map_logic_axis_to_mesh_axis():
7180
"""
7281
Generate a dict to map logical axes to mesh axes.
7382
"""
74-
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
83+
mesh = _get_mesh()
7584
if mesh is None or mesh.empty:
7685
# If no mesh is defined, return an empty dict and do not require a MeshResource context to be present
7786
return {}
@@ -130,7 +139,7 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
130139
if pspec is None:
131140
return x
132141

133-
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
142+
mesh = _get_mesh()
134143
if mesh.empty:
135144
return x
136145

@@ -211,7 +220,7 @@ def get_all_mesh_axes():
211220
"""
212221
Get all name of mesh axes
213222
"""
214-
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
223+
mesh = _get_mesh()
215224
return mesh.axis_names
216225

217226

@@ -251,7 +260,7 @@ def get_num_devices_in_mesh(mesh=None):
251260
by the global mesh.
252261
"""
253262
if mesh is None:
254-
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
263+
mesh = _get_mesh()
255264
if mesh.empty:
256265
return 1
257266
return np.prod(list(mesh.shape.values()))
@@ -264,7 +273,7 @@ def get_mesh_axis_size(axis, mesh=None):
264273
by the global mesh.
265274
"""
266275
if mesh is None:
267-
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
276+
mesh = _get_mesh()
268277

269278
if axis is None:
270279
return 1

0 commit comments

Comments
 (0)