Conversation
Code Review
This pull request introduces decoder_input propagation through the Multi-Token Prediction (MTP) layers and adds a _get_embeddings helper method to handle sequence rolling and embedding computation. Several critical issues need to be addressed:

- the forward method signature in patcher.py is missing the decoder_input parameter, which will cause a NameError;
- the MultiTokenPrediction container must also be patched to accept the new argument, otherwise passing it raises a TypeError;
- the make_viewless_tensor function is used without being imported;
- the logic in _get_embeddings should preserve a pre-computed decoder_input instead of unconditionally overwriting it.
```python
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
embedding=self.embedding,
decoder_input=decoder_input,
```
Passing decoder_input to self.mtp will likely result in a TypeError if self.mtp is the standard MultiTokenPrediction container from Megatron-Core, as its forward method does not accept this keyword argument. You need to ensure the container is also patched in patcher.py to accept and pass this argument to the individual MTP layers.
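A minimal sketch of what the container patch could look like. This is a pure-Python stand-in, not Megatron-Core's implementation: the real MultiTokenPrediction.forward takes many more arguments, and the `layers` attribute name here is an assumption.

```python
# Illustrative sketch: thread decoder_input through to each MTP layer.
# Only the arguments relevant to this review are shown; **kwargs stands
# in for everything else the real forward accepts.
def patched_mtp_forward(self, input_ids, position_ids, hidden_states,
                        attention_mask=None, decoder_input=None, **kwargs):
    for layer in self.layers:
        hidden_states = layer(
            input_ids=input_ids,
            position_ids=position_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            decoder_input=decoder_input,  # new: propagate to each MTP layer
            **kwargs,
        )
    return hidden_states

# Applied in patcher.py, e.g.:
# MultiTokenPrediction.forward = patched_mtp_forward
```

The key point is that the container merely forwards the keyword; the per-layer logic (rolling, embedding fallback) lives in the patched layer forward.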
```python
embedding=embedding,
packed_seq_params=packed_seq_params,
hidden_states=hidden_states,
decoder_input=decoder_input,
```
The variable decoder_input is used here as an argument, but it is not defined in the forward method's signature (lines 381-397). This will cause a NameError. You should update the forward signature to include decoder_input: Optional[torch.Tensor] = None. Additionally, as noted in the gpt_model.py review, the MultiTokenPrediction container's forward method must also be patched to propagate this argument.
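For illustration, the patched signature might look like the following. Parameter names other than `decoder_input` are placeholders for the existing ones, and `Any` stands in for `torch.Tensor` to keep the sketch dependency-free:

```python
from typing import Any, Optional

# Hypothetical shape of the patched MTP-layer forward in patcher.py.
def mtp_layer_forward(
    self,
    input_ids,
    position_ids,
    hidden_states,
    attention_mask=None,
    embedding=None,
    packed_seq_params=None,
    decoder_input: Optional[Any] = None,  # new keyword, defaults to None
):
    # ... existing body, now free to reference decoder_input without a
    # NameError ...
    return hidden_states
```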
```python
packed_seq_params=packed_seq_params,
)
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
```
The decoder_input argument is unconditionally overwritten by the result of embedding(...). To support multimodal models where decoder_input contains pre-computed embeddings (e.g., visual features), you should use the provided decoder_input if it is not None. If provided, it should be rolled to align with the shifted sequence; otherwise, fall back to computing it from input_ids.
```suggestion
if decoder_input is not None:
    decoder_input, _ = roll_tensor(
        decoder_input,
        shifts=-1,
        dims=0,
        cp_group=self.cp_group,
        packed_seq_params=packed_seq_params,
    )
else:
    decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
```
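The alignment the roll achieves can be sketched without Megatron: a shift of -1 along the sequence dimension moves each pre-computed embedding one position to the left, so position t lines up with the t+1 target the MTP layer predicts. `roll_seq` below is a hypothetical stand-in for `roll_tensor`, ignoring cp_group and packed sequences:

```python
def roll_seq(seq, shifts=-1):
    """Circularly shift a sequence; shifts=-1 moves every element one
    position to the left, matching torch.roll(x, shifts=-1, dims=0)."""
    k = -shifts % len(seq)
    return seq[k:] + seq[:k]

emb = ["e0", "e1", "e2", "e3"]  # e.g. pre-computed visual embeddings
rolled = roll_seq(emb)          # ["e1", "e2", "e3", "e0"]
```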
```python
# embedding
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)

hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
```
make_viewless_tensor is called here but is never imported in this module, which will raise a NameError at runtime. Add the import alongside the others in this file; in Megatron-Core it lives in `megatron.core.utils` (i.e. `from megatron.core.utils import make_viewless_tensor`).