Skip to content

fix: vectorize get_shuffle_matrix_a_row_indices to eliminate CPU contention#2935

Merged
aleozlx merged 1 commit intoflashinfer-ai:mainfrom
youkaichao:fix/vectorize-shuffle-matrix-row-indices
Apr 1, 2026
Merged

fix: vectorize get_shuffle_matrix_a_row_indices to eliminate CPU contention#2935
aleozlx merged 1 commit intoflashinfer-ai:mainfrom
youkaichao:fix/vectorize-shuffle-matrix-row-indices

Conversation

@youkaichao
Copy link
Copy Markdown
Collaborator

@youkaichao youkaichao commented Apr 1, 2026

Problem

Closes #2934.

get_shuffle_matrix_a_row_indices in flashinfer/utils.py used a Python for loop iterating over all M rows of the weight matrix to build the shuffle index permutation. For large models this loop is slow (~0.5s per call on large weight matrices), and when multiple tensor-parallel ranks finish loading their weight shards at the same time, all ranks hit this CPU-bound loop simultaneously. This causes severe CPU contention — ranks that happen to start earlier monopolize the CPU cores and the stragglers are delayed by up to ~30 minutes, even though the final permutation is identical across all ranks (same M, same shuffle_block_size, same row_map).

The false dependency between ranks compounds the problem: all ranks are computing the same result independently, yet they serialize on CPU.

Fix

Replace the Python for-loop with vectorized PyTorch tensor operations:

# Before: O(M) Python loop — slow and causes CPU contention
for old_row in range(M):
    block_idx = old_row // shuffle_block_size
    row_in_block = old_row % shuffle_block_size
    mapped_row_in_block = row_map[row_in_block]
    new_row = block_idx * shuffle_block_size + mapped_row_in_block
    row_indices[new_row] = old_row

# After: vectorized — equivalent logic, no Python loop
old_rows = torch.arange(M, dtype=torch.long)
row_map_tensor = torch.tensor(row_map, dtype=torch.long)
mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
row_indices[new_rows] = old_rows

The logic is identical — only the implementation changes from a scalar Python loop to bulk tensor operations, which is orders of magnitude faster for large M. After the fix, the function takes around 0.05s.

Summary by CodeRabbit

  • Refactor
    • Optimized internal tensor operations to improve performance efficiency.

…loop

Replace the Python for-loop over M rows with vectorized PyTorch tensor
operations. The old implementation iterated row-by-row in Python, causing
severe CPU contention when multiple tensor-parallel ranks simultaneously
called this function during model weight loading on large matrices.

Fixes flashinfer-ai#2934

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 1, 2026

📝 Walkthrough

Walkthrough

The get_shuffle_matrix_a_row_indices function in flashinfer/utils.py replaces a Python for-loop with vectorized PyTorch tensor operations. The function behavior and output remain identical; only the implementation mechanism changes from iterative row-by-row processing to batch tensor indexing, eliminating the loop overhead.

Changes

Cohort / File(s) Summary
Vectorization Refactor
flashinfer/utils.py
Replaced Python for-loop over M rows with fully vectorized tensor operations using torch.arange, modular arithmetic, and indexed assignment to compute shuffle permutation indices.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Poem

🐰 Hops of joy! No loops to tire,
Tensors dance in vectorized fire,
Shuffled rows fly swift and free,
PyTorch magic—fast as can be! 🚀

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: vectorizing get_shuffle_matrix_a_row_indices to eliminate CPU contention, which is the core objective of the PR.
Linked Issues check ✅ Passed The PR directly addresses issue #2934 by replacing the Python for-loop with vectorized PyTorch tensor operations (torch.arange, modulo/division, tensor indexing) as proposed, eliminating the CPU-bound bottleneck.
Out of Scope Changes check ✅ Passed The changes are narrowly scoped to vectorizing get_shuffle_matrix_a_row_indices in flashinfer/utils.py. The file summary (+7/-9 lines) reflects only the necessary refactoring to replace the loop with tensor operations.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Description check ✅ Passed The PR description comprehensively covers the problem, solution, and justification, following the template structure with clear sections and actionable detail.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces a slow Python loop with a vectorized PyTorch implementation in get_shuffle_matrix_a_row_indices to mitigate CPU contention. Feedback suggests explicitly specifying the device for the newly created tensors to match the input tensor and using torch.as_tensor for better efficiency.

Comment on lines +845 to +850
old_rows = torch.arange(M, dtype=torch.long)
row_map_tensor = torch.tensor(row_map, dtype=torch.long)
mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
row_indices = torch.empty(M, dtype=torch.long)

for old_row in range(M):
block_idx = old_row // shuffle_block_size
row_in_block = old_row % shuffle_block_size
mapped_row_in_block = row_map[row_in_block]

new_row = block_idx * shuffle_block_size + mapped_row_in_block

row_indices[new_row] = old_row
row_indices[new_rows] = old_rows
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The vectorized implementation currently creates tensors on the default device. Explicitly using input_tensor.device ensures that the computation happens on the same device as the input data. If the input is on a GPU, this offloads the work from the CPU entirely, which further helps with the CPU contention issue described. Additionally, torch.as_tensor is generally preferred for converting sequences to tensors.

Suggested change
old_rows = torch.arange(M, dtype=torch.long)
row_map_tensor = torch.tensor(row_map, dtype=torch.long)
mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
row_indices = torch.empty(M, dtype=torch.long)
for old_row in range(M):
block_idx = old_row // shuffle_block_size
row_in_block = old_row % shuffle_block_size
mapped_row_in_block = row_map[row_in_block]
new_row = block_idx * shuffle_block_size + mapped_row_in_block
row_indices[new_row] = old_row
row_indices[new_rows] = old_rows
old_rows = torch.arange(M, dtype=torch.long, device=input_tensor.device)
row_map_tensor = torch.as_tensor(row_map, dtype=torch.long, device=input_tensor.device)
mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
row_indices = torch.empty(M, dtype=torch.long, device=input_tensor.device)
row_indices[new_rows] = old_rows

@yzh119 yzh119 added the run-ci label Apr 1, 2026
@yzh119
Copy link
Copy Markdown
Collaborator

yzh119 commented Apr 1, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !485 has been created, and the CI pipeline #47430477 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[CANCELING] Pipeline #47430477: canceled

Copy link
Copy Markdown
Collaborator

@aleozlx aleozlx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good

@aleozlx aleozlx merged commit 5a906be into flashinfer-ai:main Apr 1, 2026
36 of 60 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Accelerate get_shuffle_matrix_a_row_indices (Python for-loop is a runtime bottleneck)

4 participants