Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions xllm/core/layers/common/linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,14 @@ torch::Tensor fp8_linear_forward(
} // namespace

// Convenience ctor used for the lm_head: projects hidden_size -> vocab_size.
// NOTE: the lm_head must stay unquantized (QuantArgs{}) so output logits
// remain in high precision for numerically stable loss/prediction.
ColumnParallelLinearImpl::ColumnParallelLinearImpl(const ModelContext& context)
    : ColumnParallelLinearImpl(
          context.get_model_args().hidden_size(),
          context.get_model_args().vocab_size(),
          /*bias=*/false,
          /*gather_output=*/true,
          QuantArgs{},  // do not use quantization for lm_head
          context.get_parallel_args().tp_group_,
          context.get_tensor_options()) {}

// Linear layer with column parallelism.
ColumnParallelLinearImpl::ColumnParallelLinearImpl(
Expand Down Expand Up @@ -667,17 +668,6 @@ std::optional<torch::Tensor> QKVParallelLinearImpl::get_input_scale() const {
return std::nullopt;
}

// Linear layer with row parallelism.
RowParallelLinearImpl::RowParallelLinearImpl(const ModelContext& context)
: RowParallelLinearImpl(context.get_model_args().hidden_size(),
context.get_model_args().vocab_size(),
/*bias=*/false,
/*input_is_parallelized=*/false,
/*enable_result_reduction=*/true,
context.get_quant_args(),
context.get_parallel_args().tp_group_,
context.get_tensor_options()) {}

// Linear layer with row parallelism.
RowParallelLinearImpl::RowParallelLinearImpl(
int64_t in_features,
Expand Down
2 changes: 0 additions & 2 deletions xllm/core/layers/common/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ TORCH_MODULE(QKVParallelLinear);
// - -
class RowParallelLinearImpl : public torch::nn::Module {
public:
RowParallelLinearImpl(const ModelContext& context);

RowParallelLinearImpl(
int64_t in_features,
int64_t out_features,
Expand Down
21 changes: 4 additions & 17 deletions xllm/core/layers/common/lm_head.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,13 @@ limitations under the License.
namespace xllm {
namespace layer {

// Module holder for the final language-modeling head.
// NOTE: the lm_head is a column-parallel projection (hidden_size ->
// vocab_size); the ColumnParallelLinearImpl(const ModelContext&) ctor it
// delegates to deliberately passes QuantArgs{} so logits stay unquantized.
class LmHead : public torch::nn::ModuleHolder<ColumnParallelLinearImpl> {
 public:
  using torch::nn::ModuleHolder<ColumnParallelLinearImpl>::ModuleHolder;
  using Impl __attribute__((__unused__)) = ColumnParallelLinearImpl;

  LmHead(const ModelContext& context)
      : ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(context)) {}
};

} // namespace layer
Expand Down
117 changes: 116 additions & 1 deletion xllm/core/layers/npu/loader/base_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,120 @@ torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
LOG(FATAL) << "Unsupported dtype string: " << dtype_str;
}

// Zero-pads `tensor` along dim 0 (the vocabulary dimension) up to
// `padded_vocab_size` rows. Returns the input unchanged when it already
// has at least that many rows.
at::Tensor BaseLoader::pad_vocab_tensor(const at::Tensor& tensor,
                                        int64_t padded_vocab_size) const {
  if (tensor.size(0) >= padded_vocab_size) {
    return tensor;
  }
  // Derive the padded shape from the source shape so this works for any
  // rank, not just 2-D weight matrices.
  auto padded_sizes = tensor.sizes().vec();
  padded_sizes[0] = padded_vocab_size;
  at::Tensor padded_tensor = torch::zeros(padded_sizes, tensor.options());
  padded_tensor.slice(/*dim=*/0, 0, tensor.size(0)) = tensor;
  return padded_tensor;
}

// Returns the `rank`-th of `world_size` equal chunks of `padded_tensor`
// along `dim`, or the whole tensor when there is no parallelism.
// Callers pad first, so size(dim) is expected to divide evenly.
at::Tensor BaseLoader::shard_padded_tensor(const at::Tensor& padded_tensor,
                                           int dim,
                                           int rank,
                                           int world_size) const {
  if (world_size <= 1) {
    return padded_tensor;
  }
  return padded_tensor.chunk(world_size, dim)[rank];
}

// Loads the tensor whose name ends with `tensor_name` from `state_dict`,
// zero-pads its vocab (dim-0) extent up to `padded_vocab_size`, fixes its
// dtype, and stores it at `weight_position` on host (to_host) or device_.
// `dim` is intentionally unused here — nothing is sharded in this overload;
// it is kept so both set_weight_with_padding overloads stay signature-parallel.
void BaseLoader::set_weight_with_padding(const StateDict& state_dict,
                                         const std::string& tensor_name,
                                         int weight_position,
                                         [[maybe_unused]] int dim,
                                         int64_t padded_vocab_size,
                                         bool to_host) {
  auto device = to_host ? at::kCPU : device_;
  for (const auto& [name, tensor] : state_dict) {
    if (!absl::EndsWith(name, tensor_name)) {
      continue;
    }
    // at::Tensor is a refcounted handle; this assignment does not copy data.
    at::Tensor mutable_tensor = tensor;
    if (padded_vocab_size > tensor.size(0)) {
      mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
    }
    correct_tensor_dtype(mutable_tensor, tensor_name);
    if (to_host) {
      at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
    } else {
      at_weight_tensors_[weight_position] = mutable_tensor.to(device);
    }
  }
}

// Loads the tensor whose name ends with `tensor_name`, pads dim 0 up to
// `padded_vocab_size`, and shards it along `dim` across `world_size` ranks,
// storing only this `rank`'s shard at `weight_position`.
// Precondition for dim == 0: padded_vocab_size % world_size == 0
// (guaranteed when the size comes from get_padded_vocab_size()).
void BaseLoader::set_weight_with_padding(const StateDict& state_dict,
                                         const std::string& tensor_name,
                                         int weight_position,
                                         int dim,
                                         int rank,
                                         int world_size,
                                         int64_t padded_vocab_size,
                                         bool to_host) {
  auto device = to_host ? at::kCPU : device_;
  if (world_size <= 1) {
    // No sharding needed: delegate to the pad-only overload.
    set_weight_with_padding(state_dict,
                            tensor_name,
                            weight_position,
                            dim,
                            padded_vocab_size,
                            to_host);
    return;
  }
  for (const auto& [name, tensor] : state_dict) {
    if (absl::EndsWith(name, tensor_name)) {
      at::Tensor mutable_tensor = tensor;
      if (padded_vocab_size > tensor.size(0)) {
        if (dim == 0) {
          // Memory-optimized path for vocabulary-dimension sharding:
          // materialize only this rank's shard instead of building the
          // fully padded tensor first.
          const int64_t shard_size = padded_vocab_size / world_size;
          const int64_t start_idx = rank * shard_size;
          const int64_t end_idx = (rank + 1) * shard_size;
          // Shard shape: identical to the source except along dim 0,
          // so any-rank tensors (not just 2-D) are handled.
          auto shard_sizes = tensor.sizes().vec();
          shard_sizes[0] = shard_size;
          if (start_idx >= tensor.size(0)) {
            // Shard lies entirely inside the padded region: all zeros.
            mutable_tensor = torch::zeros(shard_sizes, tensor.options());
          } else {
            auto valid_part =
                tensor.slice(0, start_idx, std::min(end_idx, tensor.size(0)));
            if (valid_part.size(0) < shard_size) {
              // Shard straddles the real/padded boundary: copy the valid
              // rows and leave the tail zero.
              mutable_tensor = torch::zeros(shard_sizes, tensor.options());
              mutable_tensor.slice(0, 0, valid_part.size(0)).copy_(valid_part);
            } else {
              // Fully valid shard; clone so the checkpoint tensor's full
              // storage can be released.
              mutable_tensor = valid_part.clone();
            }
          }
        } else {
          // Non-vocabulary dimension: pad first, then take this rank's chunk.
          mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
          mutable_tensor =
              shard_padded_tensor(mutable_tensor, dim, rank, world_size);
        }
      } else {
        // No padding required; let the state dict shard directly.
        mutable_tensor =
            state_dict.get_sharded_tensor(tensor_name, dim, rank, world_size);
      }
      correct_tensor_dtype(mutable_tensor, tensor_name);
      if (to_host) {
        at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
      } else {
        at_weight_tensors_[weight_position] = mutable_tensor.to(device);
      }
    }
  }
}

// Returns the vocab size rounded up to the nearest multiple of the local
// tensor-parallel size, so dim-0 sharding divides evenly. The size is
// returned unchanged when it is non-positive, TP is 1, or it already
// divides evenly.
int64_t BaseLoader::get_padded_vocab_size(const ModelContext& context) const {
  const int64_t vocab_size = context.get_model_args().vocab_size();
  const int32_t local_tp_size = dp_local_tp_size_;
  if (vocab_size <= 0 || local_tp_size <= 1) {
    return vocab_size;
  }
  const int64_t remainder = vocab_size % local_tp_size;
  return remainder == 0 ? vocab_size
                        : vocab_size + (local_tp_size - remainder);
}

} // namespace layer
} // namespace xllm
} // namespace xllm
28 changes: 27 additions & 1 deletion xllm/core/layers/npu/loader/base_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,33 @@ class BaseLoader {
int rank,
int world_size,
bool to_host = false);

void set_weight_with_padding(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position,
int dim,
int64_t padded_vocab_size,
bool to_host = false);

void set_weight_with_padding(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position,
int dim,
int rank,
int world_size,
int64_t padded_vocab_size,
bool to_host = false);

at::Tensor pad_vocab_tensor(const at::Tensor& tensor,
int64_t padded_vocab_size) const;

at::Tensor shard_padded_tensor(const at::Tensor& padded_tensor,
int dim,
int rank,
int world_size) const;

int64_t get_padded_vocab_size(const ModelContext& context) const;
};

} // namespace layer
} // namespace xllm
} // namespace xllm
29 changes: 22 additions & 7 deletions xllm/core/layers/npu/loader/lm_head_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,32 @@ LmHeadLoader::LmHeadLoader(uint64_t weight_count, const ModelContext& context)
: BaseLoader(weight_count, context) {
auto options = context.get_tensor_options();
at_weight_tensors_[0] = torch::zeros({1}).to(options);
vocab_size_ = context.get_model_args().vocab_size();
padded_vocab_size_ = get_padded_vocab_size(context);
}

// Loads the lm_head weight, padding the vocab dimension to
// padded_vocab_size_ and sharding dim 0 across the appropriate
// parallel group (local TP group under CP/DP, otherwise the world group).
void LmHeadLoader::load_state_dict(const StateDict& state_dict) {
  if (cp_size_ > 1 || dp_size_ > 1) {
    // Context/data parallel: shard over the DP-local TP group.
    set_weight_with_padding(state_dict,
                            "weight",
                            0,
                            0,
                            dp_local_tp_rank_,
                            dp_local_tp_size_,
                            padded_vocab_size_,
                            false);
  } else if (parallel_args_.world_size() > 1) {
    // Plain tensor parallel: shard over the full world group.
    set_weight_with_padding(state_dict,
                            "weight",
                            0,
                            0,
                            parallel_args_.rank(),
                            parallel_args_.world_size(),
                            padded_vocab_size_,
                            false);
  } else {
    // Single device: pad only, no sharding.
    set_weight_with_padding(
        state_dict, "weight", 0, 0, padded_vocab_size_, false);
  }
}

Expand Down
4 changes: 4 additions & 0 deletions xllm/core/layers/npu/loader/lm_head_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class LmHeadLoader : public BaseLoader {

void load_state_dict(const StateDict& state_dict) override;
void verify_loaded_weights(const std::string& weight_str) const override;

private:
int64_t vocab_size_ = -1;
int64_t padded_vocab_size_ = -1;
};
} // namespace layer
} // namespace xllm
29 changes: 22 additions & 7 deletions xllm/core/layers/npu/loader/lm_head_manual_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,32 @@ LmHeadManualLoader::LmHeadManualLoader(uint64_t weight_count,
: BaseManualLoader(weight_count, context) {
auto options = context.get_tensor_options();
at_weight_tensors_[0] = torch::zeros({1}).to(options);
vocab_size_ = context.get_model_args().vocab_size();
padded_vocab_size_ = get_padded_vocab_size(context);
}

// Loads the lm_head weight onto host memory (to_host=true), padding the
// vocab dimension to padded_vocab_size_ and sharding dim 0 across the
// appropriate parallel group (local TP group under CP/DP, otherwise the
// world group).
void LmHeadManualLoader::load_state_dict(const StateDict& state_dict) {
  if (cp_size_ > 1 || dp_size_ > 1) {
    // Context/data parallel: shard over the DP-local TP group.
    set_weight_with_padding(state_dict,
                            "weight",
                            0,
                            0,
                            dp_local_tp_rank_,
                            dp_local_tp_size_,
                            padded_vocab_size_,
                            true);
  } else if (parallel_args_.world_size() > 1) {
    // Plain tensor parallel: shard over the full world group.
    set_weight_with_padding(state_dict,
                            "weight",
                            0,
                            0,
                            parallel_args_.rank(),
                            parallel_args_.world_size(),
                            padded_vocab_size_,
                            true);
  } else {
    // Single device: pad only, no sharding.
    set_weight_with_padding(
        state_dict, "weight", 0, 0, padded_vocab_size_, true);
  }
}

Expand Down
6 changes: 5 additions & 1 deletion xllm/core/layers/npu/loader/lm_head_manual_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class LmHeadManualLoader : public BaseManualLoader {

protected:
void merge_host_at_weights() override;

private:
int64_t vocab_size_ = -1;
int64_t padded_vocab_size_ = -1;
};
} // namespace layer
} // namespace xllm
} // namespace xllm
Loading
Loading