Description
Hi team,
I noticed that the XProf overview page does not appear to report MXU utilization correctly during TPU inference.
Below is a minimal example that reproduces the issue, followed by a screenshot of the XProf output.
import os

import jax
import jax.numpy as jnp

@jax.jit
def matrix_multiply(a, b):
    return jnp.dot(a, b)

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)
b = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16)

# Warm-up call so that compilation is excluded from the profiled run.
old_result = matrix_multiply(a, b)
old_result.block_until_ready()

profile_dir = '/tmp/jax_profile'
os.makedirs(profile_dir, exist_ok=True)

# Profile a single execution of the already-compiled matmul.
jax.profiler.start_trace(profile_dir)
result = matrix_multiply(a, b)
result.block_until_ready()
jax.profiler.stop_trace()
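For a rough cross-check that does not depend on XProf, the expected utilization can be estimated from the matmul's FLOP count and a measured wall-clock time. This is only a minimal sketch: `matmul_flops` and `estimated_utilization` are hypothetical helpers, not part of JAX or XProf, and the peak FLOP/s figure is an assumption that has to be taken from the published spec of the TPU in use.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m, k) x (k, n) matmul: one multiply and one add per term."""
    return 2 * m * k * n


def estimated_utilization(flops: int, seconds: float, peak_flops_per_s: float) -> float:
    """Fraction of peak throughput achieved.

    peak_flops_per_s is chip-specific and must be supplied by the caller
    (assumption: taken from the TPU's published bf16 peak).
    """
    return flops / (seconds * peak_flops_per_s)


# FLOP count for the 1024 x 1024 x 1024 matmul used in the repro above.
flops = matmul_flops(1024, 1024, 1024)
print(flops)  # 2147483648
```

Dividing this FLOP count by the measured execution time of `matrix_multiply` and by the chip's peak gives a figure to compare against what the overview page shows.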
Is this a known limitation? Are there any recommended fixes or workarounds?