
feat: implement column parallel for lm head to improve performance. #1145

Open

wxh571001500 wants to merge 3 commits into jd-opensource:main from wxh571001500:main

Conversation

@wxh571001500 (Contributor):

LmHead performance improved by 3%.

@gemini-code-assist bot left a comment:

Code Review

This pull request refactors the LmHead to utilize ColumnParallelLinearImpl and introduces support for vocabulary padding to ensure alignment during tensor parallel operations. The review feedback identifies several critical issues: a regression caused by hardcoding quantization arguments in a general-purpose linear layer constructor, memory inefficiencies when sharding padded tensors, and a potential bug in state dict lookups using incorrect keys. Additionally, there is a recommendation to deduplicate the vocabulary padding calculation logic into a shared utility to improve maintainability.
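For reference, the padding logic the review suggests factoring into a shared utility amounts to rounding the vocabulary size up to a multiple of the tensor-parallel world size, so every rank receives an equally sized shard. A minimal sketch (the helper name and the exact alignment rule are assumptions, not taken from the PR):

```cpp
#include <cstdint>

// Hypothetical helper: round vocab_size up to the nearest multiple of
// world_size so each tensor-parallel rank gets an equally sized shard.
// The PR's actual utility may additionally align to a hardware-friendly
// multiple (e.g. 64 or 128).
int64_t padded_vocab_size(int64_t vocab_size, int64_t world_size) {
  return (vocab_size + world_size - 1) / world_size * world_size;
}
```

With world_size = 8, for example, a 151937-entry vocabulary would pad to 151944 rows (18993 per rank), while an already-divisible 151936 stays unchanged.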

            /*bias=*/false,
            /*gather_output=*/true,
-           context.get_quant_args(),
+           QuantArgs{},
gemini-code-assist bot (Contributor):

Severity: high

Hardcoding QuantArgs{} in this constructor is a regression. This constructor is a general-purpose helper for ColumnParallelLinearImpl and should respect the quantization settings provided in the ModelContext. By forcing it to be unquantized, you break quantization for any other model layers that might use this constructor. Instead, keep this constructor generic and have the LmHead class explicitly pass QuantArgs{} when it instantiates the linear layer.

Suggested change
-           QuantArgs{},
+           context.get_quant_args(),

@wxh571001500 (Author):

Because LmHead does not support quantization.

QuantArgs{}, // do not use quantization for lm_head!
context.get_parallel_args().tp_group_,
context.get_tensor_options())) {}
: ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(context)) {}
gemini-code-assist bot (Contributor):

Severity: high

Since lm_head must remain unquantized for numerical stability, you should explicitly pass QuantArgs{} to the ColumnParallelLinearImpl constructor here, rather than relying on a modified generic constructor in linear.cpp.

      : ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(
            context.get_model_args().hidden_size(),
            context.get_model_args().vocab_size(),
            /*bias=*/false,
            /*gather_output=*/true,
            QuantArgs{},  // Explicitly disable quantization for lm_head
            context.get_parallel_args().tp_group_,
            context.get_tensor_options())) {}

@wxh571001500 (Author):

The corresponding constructor function has been modified.

Comment on lines +245 to +248
      if (padded_vocab_size > tensor.size(0)) {
        mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
        mutable_tensor =
            shard_padded_tensor(mutable_tensor, dim, rank, world_size);
gemini-code-assist bot (Contributor):

Severity: high

This implementation is highly memory-inefficient. It allocates a full padded vocabulary tensor on every rank before sharding it. For large vocabularies (e.g., 128k), this causes a significant and unnecessary host memory spike. When sharding on the vocabulary dimension (dim == 0), you should calculate the specific row range required for the current rank and only slice/pad that portion.

      if (padded_vocab_size > tensor.size(0)) {
        if (dim == 0) {
          int64_t shard_size = padded_vocab_size / world_size;
          int64_t start_idx = rank * shard_size;
          int64_t end_idx = (rank + 1) * shard_size;
          if (start_idx >= tensor.size(0)) {
            mutable_tensor = torch::zeros({shard_size, tensor.size(1)}, tensor.options());
          } else {
            auto valid_part = tensor.slice(0, start_idx, std::min(end_idx, tensor.size(0)));
            if (valid_part.size(0) < shard_size) {
              mutable_tensor = torch::zeros({shard_size, tensor.size(1)}, tensor.options());
              mutable_tensor.slice(0, 0, valid_part.size(0)).copy_(valid_part);
            } else {
              mutable_tensor = valid_part.clone();
            }
          }
        } else {
          mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
          mutable_tensor = shard_padded_tensor(mutable_tensor, dim, rank, world_size);
        }
      }
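The saving in the suggested change comes from computing each rank's row range up front instead of materializing the full padded tensor on every rank. That index arithmetic can be isolated and unit-tested on its own; the sketch below mirrors the math in the suggestion (the struct and function names are illustrative, not from the PR):

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative helper mirroring the suggested change's index math.
// For a vocabulary padded to padded_size and split evenly across
// world_size ranks, return the [start, end) row range of the ORIGINAL
// (unpadded) tensor that this rank must copy; any shard rows past `end`
// are zero-filled rather than sliced from a padded copy.
struct ShardRange {
  int64_t start;  // first valid source row for this rank
  int64_t end;    // one past the last valid source row
  int64_t size;   // rows per shard, including zero padding
};

ShardRange vocab_shard_range(int64_t orig_rows,
                             int64_t padded_size,
                             int64_t rank,
                             int64_t world_size) {
  const int64_t shard = padded_size / world_size;
  const int64_t begin = rank * shard;
  // Clamp to the original row count; a rank entirely past the end of
  // the real vocabulary gets an empty valid range (all zeros).
  const int64_t start = std::min(begin, orig_rows);
  const int64_t end = std::min(begin + shard, orig_rows);
  return {start, end, shard};
}
```

Only ranks whose range overlaps the original rows copy data; the trailing rows of the last shard are simply zero-filled, so peak host memory per rank is one shard rather than the entire padded vocabulary.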

RobbieLeung previously approved these changes Mar 31, 2026
            /*bias=*/false,
            /*gather_output=*/true,
-           context.get_quant_args(),
+           QuantArgs{},
@XuZhang99 (Collaborator), Mar 31, 2026:

Why change context.get_quant_args() to QuantArgs{}?

@wxh571001500 (Author):

Because LmHead does not support quantization.

@wxh571001500 (Author):

Comments have been added.
