fix: vectorize get_shuffle_matrix_a_row_indices to eliminate CPU contention #2935
Conversation
Replace the Python for-loop over M rows with vectorized PyTorch tensor operations. The old implementation iterated row-by-row in Python, causing severe CPU contention when multiple tensor-parallel ranks simultaneously called this function during model weight loading on large matrices.

Fixes flashinfer-ai#2934

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Code Review
This pull request replaces a slow Python loop with a vectorized PyTorch implementation in get_shuffle_matrix_a_row_indices to mitigate CPU contention. Feedback suggests explicitly specifying the device for the newly created tensors to match the input tensor and using torch.as_tensor for better efficiency.
```diff
+    old_rows = torch.arange(M, dtype=torch.long)
+    row_map_tensor = torch.tensor(row_map, dtype=torch.long)
+    mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
+    new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
+    row_indices = torch.empty(M, dtype=torch.long)
-    for old_row in range(M):
-        block_idx = old_row // shuffle_block_size
-        row_in_block = old_row % shuffle_block_size
-        mapped_row_in_block = row_map[row_in_block]
-        new_row = block_idx * shuffle_block_size + mapped_row_in_block
-        row_indices[new_row] = old_row
+    row_indices[new_rows] = old_rows
```
The vectorized implementation currently creates tensors on the default device. Explicitly using input_tensor.device ensures that the computation happens on the same device as the input data. If the input is on a GPU, this offloads the work from the CPU entirely, which further helps with the CPU contention issue described. Additionally, torch.as_tensor is generally preferred for converting sequences to tensors.
Suggested change:

```diff
-old_rows = torch.arange(M, dtype=torch.long)
-row_map_tensor = torch.tensor(row_map, dtype=torch.long)
+old_rows = torch.arange(M, dtype=torch.long, device=input_tensor.device)
+row_map_tensor = torch.as_tensor(row_map, dtype=torch.long, device=input_tensor.device)
 mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
 new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
-row_indices = torch.empty(M, dtype=torch.long)
+row_indices = torch.empty(M, dtype=torch.long, device=input_tensor.device)
 row_indices[new_rows] = old_rows
```
/bot run

[CANCELING] Pipeline #47430477: canceled
Problem
Closes #2934.
`get_shuffle_matrix_a_row_indices` in `flashinfer/utils.py` used a Python `for` loop iterating over all `M` rows of the weight matrix to build the shuffle index permutation. For large models this loop is slow (~0.5s per call on large weight matrices), and when multiple tensor-parallel ranks finish loading their weight shards at the same time, all ranks hit this CPU-bound loop simultaneously. This causes severe CPU contention: ranks that happen to start earlier monopolize the CPU cores, and the stragglers are delayed by up to ~30 minutes, even though the final permutation is identical across all ranks (same `M`, same `shuffle_block_size`, same `row_map`). The false dependency between ranks compounds the problem: all ranks compute the same result independently, yet they serialize on the CPU.
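The loop in question can be sketched as follows. This is a minimal plain-Python reconstruction based on the PR diff, not the full function from `flashinfer/utils.py` (which also allocates the output tensor and validates its inputs); the `row_map` values in the example are made up for illustration:

```python
# Sketch of the pre-fix row-index construction. Each block of
# shuffle_block_size rows is permuted internally via row_map;
# the blocks themselves stay in order.
def shuffle_row_indices_loop(M, shuffle_block_size, row_map):
    row_indices = [0] * M
    for old_row in range(M):
        block_idx = old_row // shuffle_block_size
        row_in_block = old_row % shuffle_block_size
        mapped_row_in_block = row_map[row_in_block]
        # Destination of old_row after shuffling; storing old_row there
        # builds the gather permutation used to reorder the weight matrix.
        new_row = block_idx * shuffle_block_size + mapped_row_in_block
        row_indices[new_row] = old_row
    return row_indices

# Illustrative example: blocks of 4, rows within a block swapped pairwise.
print(shuffle_row_indices_loop(8, 4, [1, 0, 3, 2]))
# → [1, 0, 3, 2, 5, 4, 7, 6]
```

Because the body runs once per row in the Python interpreter, the cost scales linearly with `M` at interpreter speed, which is what every rank pays simultaneously during weight loading.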
Fix
Replace the Python for-loop with vectorized PyTorch tensor operations:
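As a standalone sketch of the vectorized approach (following the lines shown in the PR diff; the actual function in `flashinfer/utils.py` may differ in signature and surrounding validation):

```python
import torch

def shuffle_row_indices_vectorized(M, shuffle_block_size, row_map):
    old_rows = torch.arange(M, dtype=torch.long)
    row_map_tensor = torch.tensor(row_map, dtype=torch.long)
    # Row-within-block remapped through row_map, block index unchanged.
    mapped_rows = row_map_tensor[old_rows % shuffle_block_size]
    new_rows = (old_rows // shuffle_block_size) * shuffle_block_size + mapped_rows
    # One indexed write scatters all old row ids to their shuffled slots,
    # replacing the M-iteration Python loop with bulk tensor ops.
    row_indices = torch.empty(M, dtype=torch.long)
    row_indices[new_rows] = old_rows
    return row_indices
```

The indexed assignment `row_indices[new_rows] = old_rows` is well defined here because `new_rows` contains each destination index exactly once whenever `row_map` is a permutation of `range(shuffle_block_size)`.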
The logic is identical; only the implementation changes from a scalar Python loop to bulk tensor operations, which is orders of magnitude faster for large `M`. After the fix, the function takes around 0.05s.