Skip to content

Commit d313046

Browse files
committed
address greptile comment
Signed-off-by: tdophung <tdophung@nvidia.com>
1 parent 34db360 commit d313046

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/jax/test_permutation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def reference_make_row_id_map(
9898
# Compute total tokens per expert and expert offsets
9999
tokens_per_expert = jnp.sum(routing_map, axis=0)
100100
expert_offsets = jnp.concatenate(
101-
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1]]
101+
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
102102
)
103103

104104
# Compute destination rows for all (token, expert) pairs

0 commit comments

Comments
 (0)