We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 34db360 commit d313046Copy full SHA for d313046
tests/jax/test_permutation.py
@@ -98,7 +98,7 @@ def reference_make_row_id_map(
98
# Compute total tokens per expert and expert offsets
99
tokens_per_expert = jnp.sum(routing_map, axis=0)
100
expert_offsets = jnp.concatenate(
101
- [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1]]
+ [jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
102
)
103
104
# Compute destination rows for all (token, expert) pairs
0 commit comments