feat: implement column parallel for lm head to improve performance. #1145
wxh571001500 wants to merge 3 commits into jd-opensource:main from
Conversation
Code Review
This pull request refactors the LmHead to utilize ColumnParallelLinearImpl and introduces support for vocabulary padding to ensure alignment during tensor parallel operations. The review feedback identifies several critical issues: a regression caused by hardcoding quantization arguments in a general-purpose linear layer constructor, memory inefficiencies when sharding padded tensors, and a potential bug in state dict lookups using incorrect keys. Additionally, there is a recommendation to deduplicate the vocabulary padding calculation logic into a shared utility to improve maintainability.
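The deduplication the review asks for could look like the following minimal sketch of a shared utility. The helper name, signature, and default alignment value are illustrative assumptions, not code from this PR:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical shared helper: round the vocabulary size up to a multiple of
// world_size * alignment so that every tensor-parallel rank receives an
// equally sized shard. Name and default alignment are assumptions.
inline int64_t compute_padded_vocab_size(int64_t vocab_size,
                                         int64_t world_size,
                                         int64_t alignment = 64) {
  const int64_t multiple = world_size * alignment;
  return (vocab_size + multiple - 1) / multiple * multiple;
}
```

For example, with `vocab_size = 128256`, `world_size = 8`, and the assumed alignment of 64, this yields a padded vocabulary of 128512 rows, i.e. 16064 rows per rank.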
xllm/core/layers/common/linear.cpp
Outdated
      /*bias=*/false,
      /*gather_output=*/true,
-     context.get_quant_args(),
+     QuantArgs{},
Hardcoding QuantArgs{} in this constructor is a regression. This constructor is a general-purpose helper for ColumnParallelLinearImpl and should respect the quantization settings provided in the ModelContext. By forcing it to be unquantized, you break quantization for any other model layers that might use this constructor. Instead, keep this constructor generic and have the LmHead class explicitly pass QuantArgs{} when it instantiates the linear layer.
Suggested change:
-     QuantArgs{},
+     context.get_quant_args(),
Because LmHead does not support quantization.
-     QuantArgs{},  // do not use quantization for lm_head!
-     context.get_parallel_args().tp_group_,
-     context.get_tensor_options())) {}
+     : ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(context)) {}
Since lm_head must remain unquantized for numerical stability, you should explicitly pass QuantArgs{} to the ColumnParallelLinearImpl constructor here, rather than relying on a modified generic constructor in linear.cpp.
    : ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(
          context.get_model_args().hidden_size(),
          context.get_model_args().vocab_size(),
          /*bias=*/false,
          /*gather_output=*/true,
          QuantArgs{},  // Explicitly disable quantization for lm_head
          context.get_parallel_args().tp_group_,
          context.get_tensor_options())) {}

The corresponding constructor function has been modified.
  if (padded_vocab_size > tensor.size(0)) {
    mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
    mutable_tensor =
        shard_padded_tensor(mutable_tensor, dim, rank, world_size);
This implementation is highly memory-inefficient. It allocates a full padded vocabulary tensor on every rank before sharding it. For large vocabularies (e.g., 128k), this causes a significant and unnecessary host memory spike. When sharding on the vocabulary dimension (dim == 0), you should calculate the specific row range required for the current rank and only slice/pad that portion.
if (padded_vocab_size > tensor.size(0)) {
  if (dim == 0) {
    // Sharding on the vocabulary dimension: compute only this rank's row
    // range instead of materializing the full padded tensor.
    int64_t shard_size = padded_vocab_size / world_size;
    int64_t start_idx = rank * shard_size;
    int64_t end_idx = (rank + 1) * shard_size;
    if (start_idx >= tensor.size(0)) {
      // This rank's shard lies entirely in the padded region: all zeros.
      mutable_tensor = torch::zeros({shard_size, tensor.size(1)}, tensor.options());
    } else {
      auto valid_part = tensor.slice(0, start_idx, std::min(end_idx, tensor.size(0)));
      if (valid_part.size(0) < shard_size) {
        // Shard straddles the padding boundary: copy the valid rows, zero the rest.
        mutable_tensor = torch::zeros({shard_size, tensor.size(1)}, tensor.options());
        mutable_tensor.slice(0, 0, valid_part.size(0)).copy_(valid_part);
      } else {
        mutable_tensor = valid_part.clone();
      }
    }
  } else {
    // Non-vocab dimensions are unaffected by padding size; keep the old path.
    mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
    mutable_tensor = shard_padded_tensor(mutable_tensor, dim, rank, world_size);
  }
}
xllm/core/layers/common/linear.cpp
Outdated
      /*bias=*/false,
      /*gather_output=*/true,
-     context.get_quant_args(),
+     QuantArgs{},
Why change context.get_quant_args() to QuantArgs{}?
Because LmHead does not support quantization.
Comments have been added.
LmHead performance improved by 3%.