Skip to content

Conversation

@tdophung
Copy link
Collaborator

@tdophung tdophung commented Jan 6, 2026

Description

when using jnp.arange to initialize array, it could be ambiguous depending on the platform to use int64 or int32, this is to explicitly specify the dtype to eliminate ambiguity, which might have caused data to be misinterpreted and read as int32 when initialized as int64, causing a 50% data mismatch

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Make array init and ops specify int32 dtype for test_sort_chunk_by_index

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 6, 2026

Greptile Summary

Fixed platform-dependent dtype ambiguity in reference test functions by explicitly specifying dtype=jnp.int32 for array initialization and adding .astype(jnp.int32) casts to cumsum operations. This prevents JAX from defaulting to int64 on some platforms while using int32 on others, which was causing a 50% comparison mismatch between test references and kernel outputs.

The fix targets three key functions:

  • reference_make_row_id_map: Added explicit int32 dtypes to expert_offsets array concatenation and token_idx array creation
  • reference_make_chunk_sort_map: Added explicit int32 dtypes to src_cumsum, dest_cumsum, and position_indices array operations

Remaining inconsistencies: Five additional jnp.arange() calls in helper functions (_reference_permute_impl and _reference_unpermute_impl) still lack explicit dtype specifications at lines 176, 182, 263, 275, and 287. While these are less critical (used for indexing operations where dtype matters less), they should be fixed for complete consistency.

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk
  • The changes correctly address the root cause of platform-dependent dtype ambiguity that was causing test failures. However, the fix is incomplete - five similar patterns remain unfixed in helper functions. While these remaining instances are less critical (used for indexing rather than data storage), they represent the same underlying issue and should be addressed for complete consistency.
  • tests/jax/test_permutation.py lines 176, 182, 263, 275, 287 need the same dtype fix applied

Important Files Changed

Filename Overview
tests/jax/test_permutation.py Added explicit dtype=jnp.int32 to array initialization and cumsum operations to fix platform-dependent type ambiguity. Some similar patterns remain without explicit dtype specifications.

Sequence Diagram

sequenceDiagram
    participant Test as test_permutation.py
    participant JAX as JAX Runtime
    participant Ref as reference functions
    
    Note over Test,Ref: Test Setup Phase
    Test->>Test: Generate routing_map/split_sizes
    Test->>JAX: Create arrays with jnp.array/arange
    Note over JAX: Platform-dependent dtype (int32 vs int64)
    
    Note over Test,Ref: Reference Implementation
    Test->>Ref: Call reference_make_row_id_map()
    Ref->>JAX: jnp.array([0], dtype=jnp.int32)
    Note over JAX: Explicit int32 dtype ✓
    Ref->>JAX: jnp.cumsum(...).astype(jnp.int32)
    Note over JAX: Explicit cast to int32 ✓
    Ref->>JAX: jnp.arange(num_tokens, dtype=jnp.int32)
    Note over JAX: Explicit int32 dtype ✓
    Ref-->>Test: Returns row_id_map (consistent int32)
    
    Note over Test,Ref: Chunk Sort Map Reference
    Test->>Ref: Call reference_make_chunk_sort_map()
    Ref->>JAX: jnp.array([0], dtype=jnp.int32)
    Note over JAX: Explicit int32 dtype ✓
    Ref->>JAX: jnp.cumsum(...).astype(jnp.int32)
    Note over JAX: Explicit cast to int32 ✓
    Ref->>JAX: jnp.arange(num_tokens, dtype=jnp.int32)
    Note over JAX: Explicit int32 dtype ✓
    Ref-->>Test: Returns chunk_sort_map (consistent int32)
    
    Note over Test,Ref: Test Execution
    Test->>Test: Compare kernel output vs reference
    Note over Test: Consistent dtypes prevent 50% mismatch
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (2)

  1. tests/jax/test_permutation.py, line 381 (link)

    syntax: the cumsum result should also be cast to int32 here for consistency

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  2. tests/jax/test_permutation.py, line 100-102 (link)

    logic: jnp.cumsum(tokens_per_expert)[:-1] should also be cast to int32 for consistency with the fix

1 file reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (5)

  1. tests/jax/test_permutation.py, line 176 (link)

    style: same dtype ambiguity here - jnp.arange(num_experts) should specify dtype=jnp.int32 for consistency

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  2. tests/jax/test_permutation.py, line 182 (link)

    style: same dtype ambiguity - jnp.arange(num_tokens) should specify dtype=jnp.int32

  3. tests/jax/test_permutation.py, line 263 (link)

    style: same dtype ambiguity - jnp.arange(num_experts) should specify dtype=jnp.int32

  4. tests/jax/test_permutation.py, line 275 (link)

    style: same dtype ambiguity - jnp.arange(num_tokens) should specify dtype=jnp.int32

  5. tests/jax/test_permutation.py, line 287 (link)

    style: same dtype ambiguity - jnp.arange(num_tokens) should specify dtype=jnp.int32

1 file reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

@tdophung
Copy link
Collaborator Author

tdophung commented Jan 6, 2026

/te-ci L0 jax

@tdophung
Copy link
Collaborator Author

tdophung commented Jan 6, 2026

These comemnts by Greptile is not necessary in this change :

Additional Comments (5)

  1. tests/jax/test_permutation.py, line 176 (link)
    style: same dtype ambiguity here - jnp.arange(num_experts) should specify dtype=jnp.int32 for consistency
    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
  2. tests/jax/test_permutation.py, line 182 (link)
    style: same dtype ambiguity - jnp.arange(num_tokens) should specify dtype=jnp.int32
  3. tests/jax/test_permutation.py, line 263 (link)
    style: same dtype ambiguity - jnp.arange(num_experts) should specify dtype=jnp.int32
  4. tests/jax/test_permutation.py, line 275 (link)
    style: same dtype ambiguity - jnp.arange(num_tokens) should specify dtype=jnp.int32
  5. tests/jax/test_permutation.py, line 287 (link)
    style: same dtype ambiguity - jnp.arange(num_tokens) should specify dtype=jnp.int32

1 file reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

  1. They are addressing other tests aside from test_sort_chunks_by_map which are not ffailing
  2. They are not failing because they all use the row_od_map from the kernel output (make_row_id_map), which is already int32, so there is no ambiguity here. While test_sort_chunks_by_index row_id_map could be int64 on certain platforms

@tdophung
Copy link
Collaborator Author

tdophung commented Jan 6, 2026

/te-ci jax L0

@tdophung
Copy link
Collaborator Author

tdophung commented Jan 6, 2026

/te-ci L0 jax

1 similar comment
@tdophung
Copy link
Collaborator Author

tdophung commented Jan 6, 2026

/te-ci L0 jax

@tdophung tdophung changed the title [Draft] Fix 50% comparison mismatch in sort_chunks_by_index Fix 50% comparison mismatch in sort_chunks_by_index Jan 7, 2026
@tdophung
Copy link
Collaborator Author

tdophung commented Jan 7, 2026

@tdophung tdophung merged commit 702fc5e into NVIDIA:main Jan 7, 2026
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants