-
Notifications
You must be signed in to change notification settings - Fork 601
Fix 50% comparison mismatch in sort_chunks_by_index #2566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: tdophung <[email protected]>
Greptile SummaryFixed platform-dependent dtype ambiguity in reference test functions by explicitly specifying The fix targets three key functions:
Remaining inconsistencies: Five additional Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Test as test_permutation.py
participant JAX as JAX Runtime
participant Ref as reference functions
Note over Test,Ref: Test Setup Phase
Test->>Test: Generate routing_map/split_sizes
Test->>JAX: Create arrays with jnp.array/arange
Note over JAX: Platform-dependent dtype (int32 vs int64)
Note over Test,Ref: Reference Implementation
Test->>Ref: Call reference_make_row_id_map()
Ref->>JAX: jnp.array([0], dtype=jnp.int32)
Note over JAX: Explicit int32 dtype ✓
Ref->>JAX: jnp.cumsum(...).astype(jnp.int32)
Note over JAX: Explicit cast to int32 ✓
Ref->>JAX: jnp.arange(num_tokens, dtype=jnp.int32)
Note over JAX: Explicit int32 dtype ✓
Ref-->>Test: Returns row_id_map (consistent int32)
Note over Test,Ref: Chunk Sort Map Reference
Test->>Ref: Call reference_make_chunk_sort_map()
Ref->>JAX: jnp.array([0], dtype=jnp.int32)
Note over JAX: Explicit int32 dtype ✓
Ref->>JAX: jnp.cumsum(...).astype(jnp.int32)
Note over JAX: Explicit cast to int32 ✓
Ref->>JAX: jnp.arange(num_tokens, dtype=jnp.int32)
Note over JAX: Explicit int32 dtype ✓
Ref-->>Test: Returns chunk_sort_map (consistent int32)
Note over Test,Ref: Test Execution
Test->>Test: Compare kernel output vs reference
Note over Test: Consistent dtypes prevent 50% mismatch
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (2)
-
tests/jax/test_permutation.py, line 381 (link)syntax: the cumsum result should also be cast to
int32here for consistencyNote: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
-
tests/jax/test_permutation.py, line 100-102 (link)logic:
jnp.cumsum(tokens_per_expert)[:-1]should also be cast toint32for consistency with the fix
1 file reviewed, 2 comments
Signed-off-by: tdophung <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (5)
-
tests/jax/test_permutation.py, line 176 (link)style: same dtype ambiguity here -
jnp.arange(num_experts)should specifydtype=jnp.int32for consistencyNote: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
-
tests/jax/test_permutation.py, line 182 (link)style: same dtype ambiguity -
jnp.arange(num_tokens)should specifydtype=jnp.int32 -
tests/jax/test_permutation.py, line 263 (link)style: same dtype ambiguity -
jnp.arange(num_experts)should specifydtype=jnp.int32 -
tests/jax/test_permutation.py, line 275 (link)style: same dtype ambiguity -
jnp.arange(num_tokens)should specifydtype=jnp.int32 -
tests/jax/test_permutation.py, line 287 (link)style: same dtype ambiguity -
jnp.arange(num_tokens)should specifydtype=jnp.int32
1 file reviewed, 5 comments
|
/te-ci L0 jax |
|
These comemnts by Greptile is not necessary in this change :
|
|
/te-ci jax L0 |
|
/te-ci L0 jax |
1 similar comment
|
/te-ci L0 jax |
|
CI passed on these 2 manual triggering runs: |
Description
when using jnp.arange to initialize array, it could be ambiguous depending on the platform to use int64 or int32, this is to explicitly specify the dtype to eliminate ambiguity, which might have caused data to be misinterpreted and read as int32 when initialized as int64, causing a 50% data mismatch
Fixes # (issue)
Type of change
Changes
Make array init and ops specify int32 dtype for test_sort_chunk_by_index
Checklist: