
[feat] Support multimodel mtp#14

Open
Jintao-Huang wants to merge 4 commits into modelscope:main from Jintao-Huang:support_multimodel_mtp
Conversation

@Jintao-Huang
Collaborator

No description provided.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces decoder_input propagation through the Multi-Token Prediction (MTP) layers and adds a _get_embeddings helper method to handle sequence rolling and embedding computation. Several critical issues need to be addressed: the forward method signature in patcher.py is missing the decoder_input parameter, which will cause a NameError; the MultiTokenPrediction container requires patching to accept the new argument and avoid a TypeError; the make_viewless_tensor function is used without being imported; and the logic in _get_embeddings should be updated to preserve pre-computed decoder_input values instead of unconditionally overwriting them.

packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
decoder_input=decoder_input,


high

Passing decoder_input to self.mtp will likely result in a TypeError if self.mtp is the standard MultiTokenPrediction container from Megatron-Core, as its forward method does not accept this keyword argument. You need to ensure the container is also patched in patcher.py to accept and pass this argument to the individual MTP layers.

embedding=embedding,
packed_seq_params=packed_seq_params,
hidden_states=hidden_states,
decoder_input=decoder_input,


high

The variable decoder_input is used here as an argument, but it is not defined in the forward method's signature (lines 381-397). This will cause a NameError. You should update the forward signature to include decoder_input: Optional[torch.Tensor] = None. Additionally, as noted in the gpt_model.py review, the MultiTokenPrediction container's forward method must also be patched to propagate this argument.
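A hedged sketch of the signature fix, with illustrative parameter names rather than the real `patcher.py` signature: declaring `decoder_input` as an optional keyword parameter binds the name in the function scope, so referencing it later in the body no longer raises a `NameError` even when the caller omits it.

```python
from typing import Any, Optional


def forward(self: Any, input_ids, position_ids, attention_mask=None,
            *, decoder_input: Optional[Any] = None, **kwargs):
    # Because decoder_input is now a declared parameter, it is always bound
    # (defaulting to None), so passing it downstream is safe.
    call_kwargs = dict(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        decoder_input=decoder_input,
        **kwargs,
    )
    return call_kwargs
```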

packed_seq_params=packed_seq_params,
)
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)


high

The decoder_input argument is unconditionally overwritten by the result of embedding(...). To support multimodal models where decoder_input contains pre-computed embeddings (e.g., visual features), you should use the provided decoder_input if it is not None. If provided, it should be rolled to align with the shifted sequence; otherwise, fall back to computing it from input_ids.

Suggested change

    decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)

becomes:

    if decoder_input is not None:
        decoder_input, _ = roll_tensor(
            decoder_input,
            shifts=-1,
            dims=0,
            cp_group=self.cp_group,
            packed_seq_params=packed_seq_params,
        )
    else:
        decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
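A toy illustration of the alignment the suggestion performs: rolling a pre-computed `decoder_input` left by one position so that position t holds the embedding that lines up with the shifted (next-token) sequence. Megatron's `roll_tensor` is replaced here by a list-based stand-in, since the real helper also handles context parallelism and packed sequences.

```python
def roll_left(seq, shifts=-1):
    """Stand-in for roll_tensor(..., shifts=-1, dims=0): shift left with wrap-around."""
    k = -shifts % len(seq)
    # Elements move toward index 0; the first k elements wrap to the end.
    return seq[k:] + seq[:k]
```

In the real code the wrapped-around tail is irrelevant (those positions correspond to labels beyond the sequence and are masked out), which is why a plain roll suffices.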

# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)

hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)


high

make_viewless_tensor is used here but it is not imported in this file. You should add it to the imports from megatron.core.utils at the top of the file (e.g., at line 20).
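The fix is a one-line import; sketched here with a guard so the snippet also runs in environments where Megatron-Core is not installed (the guard is only for this sketch, not for the actual file).

```python
# The import the review asks for; make_viewless_tensor lives in megatron.core.utils.
try:
    from megatron.core.utils import make_viewless_tensor
except ImportError:
    # Megatron-Core not installed in this environment; real code would not guard this.
    make_viewless_tensor = None
```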
