bugfix: make CP compatible with MTP #1150

Open
shifengmin wants to merge 3 commits into jd-opensource:release/v0.9.0 from shifengmin:fengmin/cp_mtp_pr

Conversation

@shifengmin
Collaborator

No description provided.

@shifengmin changed the title from "bugfix: CP compatibility with MTP" to "bugfix: make CP compatible with MTP" on Mar 31, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for Multi-Token Prediction (MTP) combined with Context Parallelism (CP) by adding mtp_shifted_token_ids to the input pipeline and establishing a dedicated cp_group for runtime collectives. Feedback indicates that the logic for constructing shifted token IDs is incorrect when padding is involved, and the cp_group initialization may lead to incorrect group memberships. Additionally, the embedding gather logic in the worker implementation is flawed for CP+MTP scenarios, and synchronous CPU transfers in the scheduler thread should be avoided to prevent performance degradation.

state_.mrope_positions_vec.reserve(sequences.size());
state_.block_tables_vec.reserve(sequences.size());
state_.acc_logprob_vec.reserve(sequences.size());
state_.mtp_shifted_token_ids.reserve(1000);
Collaborator

Magic number: 1000.

Collaborator Author

Replaced with a predefined constant.
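
For readers following this thread, here is a minimal sketch of what replacing the literal with a named constant could look like. The constant name `kMtpShiftedTokenIdsReserve` and the helper `make_shifted_token_buffer` are hypothetical and do not come from the PR; the value 1000 is simply the literal quoted in the review comment, and the element type is an assumption.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical name; the PR's actual constant is not shown in this thread.
// The value mirrors the literal flagged in the review.
constexpr std::size_t kMtpShiftedTokenIdsReserve = 1000;

// Hypothetical helper: builds an empty buffer with capacity reserved up
// front, so the reservation size is documented in one named place.
std::vector<int32_t> make_shifted_token_buffer() {
  std::vector<int32_t> buffer;
  buffer.reserve(kMtpShiftedTokenIdsReserve);
  return buffer;
}
```

Naming the value keeps the intent of the reservation visible at the call site and gives a single place to tune it later.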

Collaborator

Magic number.

Collaborator Author

Replaced with a predefined constant.

CHECK(input_params.mtp_shifted_token_ids.defined());
CHECK_EQ(input_params.mtp_shifted_token_ids.numel(),
         prefill_input.token_ids.numel());
prefill_input.token_ids = input_params.mtp_shifted_token_ids.clone();
Collaborator

Why do we need clone() here?

Collaborator Author

It was unnecessary; removed.

state.extra_token_ids.end());
state_.mtp_shifted_token_ids.insert(state_.mtp_shifted_token_ids.end(),
                                    state.mtp_shifted_token_ids.begin(),
                                    state.mtp_shifted_token_ids.end());
Collaborator

What's the difference between extra_token_ids and mtp_shifted_token_ids?

Collaborator Author

mtp_shifted_token_ids holds the sequence after it has been left-shifted and padded with -1, following the MTP prefill input preparation rules.
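
To make the description above concrete, here is a rough sketch of a left shift with -1 padding on a single sequence. It assumes a shift of exactly one position and padding at the tail, neither of which is confirmed in this thread; the helper name `shift_left_pad` is hypothetical and is not the PR's implementation. The output keeps the input length, which is consistent with the `CHECK_EQ` on `numel()` in the diff above.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not taken from the PR. Assumes a one-position shift:
// element i of the result is element i + 1 of the input, and the vacated
// last slot is filled with pad_id (-1 by default).
std::vector<int32_t> shift_left_pad(const std::vector<int32_t>& token_ids,
                                    int32_t pad_id = -1) {
  // Start with every slot set to the pad value, same length as the input.
  std::vector<int32_t> shifted(token_ids.size(), pad_id);
  for (std::size_t i = 1; i < token_ids.size(); ++i) {
    shifted[i - 1] = token_ids[i];
  }
  return shifted;
}
```

Under these assumptions, the input `{10, 11, 12, 13}` becomes `{11, 12, 13, -1}`. How the shift interacts with batch padding and with CP partitioning is exactly what the bot review questions, and this sketch does not address it.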

3 participants