diff --git a/xllm/core/framework/hf_model_loader.cpp b/xllm/core/framework/hf_model_loader.cpp index 425555108..ff460ae37 100644 --- a/xllm/core/framework/hf_model_loader.cpp +++ b/xllm/core/framework/hf_model_loader.cpp @@ -54,6 +54,17 @@ namespace xllm { namespace { +JsonReader normalize_config_torch_dtype(const JsonReader& reader) { + auto config = reader.data(); + if (!config.contains("torch_dtype") && config.contains("dtype")) { + config["torch_dtype"] = config["dtype"]; + } + + JsonReader normalized_reader; + normalized_reader.parse_text(config.dump()); + return normalized_reader; +} + bool is_compressed_tensors_fp8_scheme(const nlohmann::json& config) { auto type_it = config.find("type"); auto num_bits_it = config.find("num_bits"); @@ -725,7 +736,8 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) { << model_type; return false; } - model_args_loader(reader, &args_); + const JsonReader config_reader = normalize_config_torch_dtype(reader); + model_args_loader(config_reader, &args_); return true; } @@ -738,16 +750,20 @@ bool HFModelLoader::load_quant_args(const std::string& model_weights_path) { return false; } - if (!load_quant_cfg(reader, quant_args_)) { + const JsonReader config_reader = normalize_config_torch_dtype(reader); + + if (!load_quant_cfg(config_reader, quant_args_)) { return false; } // load quantization args for npu if exists - if (reader.contains("quantize")) { - quant_args_.quantize_type() = reader.value_or("quantize", ""); + if (config_reader.contains("quantize")) { + quant_args_.quantize_type() = + config_reader.value_or("quantize", ""); } - if (reader.contains("torch_dtype")) { - quant_args_.torch_dtype() = reader.value_or("torch_dtype", ""); + if (config_reader.contains("torch_dtype")) { + quant_args_.torch_dtype() = + config_reader.value_or("torch_dtype", ""); } // awq quantization args diff --git a/xllm/core/framework/hf_model_loader_test.cpp b/xllm/core/framework/hf_model_loader_test.cpp index 
b20c54a9e..3323dcbfb 100644 --- a/xllm/core/framework/hf_model_loader_test.cpp +++ b/xllm/core/framework/hf_model_loader_test.cpp @@ -17,6 +17,9 @@ limitations under the License. #include <gtest/gtest.h> +#include <filesystem> +#include <fstream> + #include "core/platform/device.h" namespace xllm { @@ -71,4 +74,44 @@ TEST(HFModelLoaderTest, KeepLegacyFp8ConfigUnchanged) { EXPECT_FALSE(quant_args.activation_dynamic()); } +class HFModelLoaderConfigTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = std::filesystem::temp_directory_path() / "hf_model_loader_test"; + if (std::filesystem::exists(test_dir_)) { + std::filesystem::remove_all(test_dir_); + } + std::filesystem::create_directories(test_dir_); + } + + void TearDown() override { + if (std::filesystem::exists(test_dir_)) { + std::filesystem::remove_all(test_dir_); + } + } + + void write_file(const std::filesystem::path& path, const std::string& data) { + std::ofstream file(path); + ASSERT_TRUE(file.is_open()); + file << data; + file.close(); + } + + std::filesystem::path test_dir_; +}; + +TEST_F(HFModelLoaderConfigTest, TreatsTopLevelDtypeAsTorchDtype) { + write_file(test_dir_ / "config.json", R"json( + { + "model_type": "qwen3", + "dtype": "bfloat16" + } + )json"); + write_file(test_dir_ / "model.safetensors", ""); + + HFModelLoader loader(test_dir_.string()); + EXPECT_EQ(loader.model_args().dtype(), "bfloat16"); + EXPECT_EQ(loader.quant_args().torch_dtype(), "bfloat16"); +} + } // namespace xllm