Skip to content

Commit 02b747e

Browse files
authored
Merge pull request #1057 from Todysheep/dev
feat: 更新LLMRequest类以支持自定义参数,更新payload键值添加逻辑,兼容不支持某些键值的api
2 parents a3a3d87 + 7961a1f commit 02b747e

File tree

1 file changed: 13 additions (+13) and 30 deletions (-30).

src/llm_models/utils_model.py

Lines changed: 13 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -124,14 +124,21 @@ def __init__(self, model: dict, **kwargs):
124124
self.model_name: str = model["name"]
125125
self.params = kwargs
126126

127-
self.enable_thinking = model.get("enable_thinking", False)
127+
self.enable_thinking = model.get("enable_thinking", None)
128128
self.temp = model.get("temp", 0.7)
129-
self.thinking_budget = model.get("thinking_budget", 4096)
129+
self.thinking_budget = model.get("thinking_budget", None)
130130
self.stream = model.get("stream", False)
131131
self.pri_in = model.get("pri_in", 0)
132132
self.pri_out = model.get("pri_out", 0)
133133
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
134134
# print(f"max_tokens: {self.max_tokens}")
135+
custom_params_str = model.get("custom_params", "{}")
136+
try:
137+
self.custom_params = json.loads(custom_params_str)
138+
except json.JSONDecodeError as e:
139+
logger.error(f"Invalid JSON in custom_params for model '{self.model_name}': {custom_params_str}")
140+
self.custom_params = {}
141+
135142

136143
# 获取数据库实例
137144
self._init_database()
@@ -249,28 +256,6 @@ async def _prepare_request(
249256
elif payload is None:
250257
payload = await self._build_payload(prompt)
251258

252-
if stream_mode:
253-
payload["stream"] = stream_mode
254-
255-
if self.temp != 0.7:
256-
payload["temperature"] = self.temp
257-
258-
# 添加enable_thinking参数(如果不是默认值False)
259-
if not self.enable_thinking:
260-
payload["enable_thinking"] = False
261-
262-
if self.thinking_budget != 4096:
263-
payload["thinking_budget"] = self.thinking_budget
264-
265-
if self.max_tokens:
266-
payload["max_tokens"] = self.max_tokens
267-
268-
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
269-
# payload["max_tokens"] = global_config.model.model_max_output_length
270-
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
271-
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
272-
payload["max_completion_tokens"] = payload.pop("max_tokens")
273-
274259
return {
275260
"policy": policy,
276261
"payload": payload,
@@ -670,18 +655,16 @@ async def _build_payload(self, prompt: str, image_base64: str = None, image_form
670655
if self.temp != 0.7:
671656
payload["temperature"] = self.temp
672657

673-
# 添加enable_thinking参数(如果不是默认值False)
674-
if not self.enable_thinking:
675-
payload["enable_thinking"] = False
658+
# 仅当配置文件中存在参数时,添加对应参数
659+
if self.enable_thinking is not None:
660+
payload["enable_thinking"] = self.enable_thinking
676661

677-
if self.thinking_budget != 4096:
662+
if self.thinking_budget is not None:
678663
payload["thinking_budget"] = self.thinking_budget
679664

680665
if self.max_tokens:
681666
payload["max_tokens"] = self.max_tokens
682667

683-
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
684-
# payload["max_tokens"] = global_config.model.model_max_output_length
685668
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
686669
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
687670
payload["max_completion_tokens"] = payload.pop("max_tokens")

0 commit comments

Comments (0)