Skip to content

Commit c20931e

Browse files
fix: 修复 Gemini 模型分页遍历并补充异常分支测试
1 parent 647c798 commit c20931e

3 files changed

Lines changed: 62 additions & 3 deletions

File tree

astrbot/core/provider/sources/gemini_embedding_source.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,10 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
8080

8181
async def get_models(self) -> list[str]:
8282
try:
83-
models = await self.client.models.list()
8483
all_model_ids: list[str] = []
8584
embedding_model_ids: list[str] = []
8685

87-
for model in getattr(models, "page", []):
86+
async for model in await self.client.models.list():
8887
model_id = self._extract_model_id(model)
8988
if not model_id:
9089
continue

tests/test_dashboard.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,25 @@ async def terminate(self):
12431243
type(self).terminate_calls += 1
12441244

12451245

1246+
class _ErrorEmbeddingProvider(EmbeddingProvider):
1247+
terminate_calls = 0
1248+
1249+
async def get_embedding(self, text: str) -> list[float]:
1250+
return [0.1]
1251+
1252+
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
1253+
return [[0.1] for _ in text]
1254+
1255+
def get_dim(self) -> int:
1256+
return 1
1257+
1258+
async def get_models(self) -> list[str]:
1259+
raise RuntimeError("boom")
1260+
1261+
async def terminate(self):
1262+
type(self).terminate_calls += 1
1263+
1264+
12461265
@pytest.mark.asyncio
12471266
async def test_get_embedding_models_success_and_terminate(
12481267
app: Quart,
@@ -1308,3 +1327,37 @@ async def test_get_embedding_models_unsupported_returns_error(
13081327
data = await response.get_json()
13091328
assert data["status"] == "error"
13101329
assert _UnsupportedEmbeddingProvider.terminate_calls == 1
1330+
1331+
1332+
@pytest.mark.asyncio
1333+
async def test_get_embedding_models_runtime_error_returns_error_and_terminate(
1334+
app: Quart,
1335+
authenticated_header: dict,
1336+
monkeypatch,
1337+
):
1338+
from astrbot.core.provider.register import provider_cls_map
1339+
1340+
_ErrorEmbeddingProvider.terminate_calls = 0
1341+
monkeypatch.setitem(
1342+
provider_cls_map,
1343+
"test_embedding_runtime_error",
1344+
SimpleNamespace(cls_type=_ErrorEmbeddingProvider),
1345+
)
1346+
1347+
test_client = app.test_client()
1348+
response = await test_client.post(
1349+
"/api/config/provider/get_embedding_models",
1350+
headers=authenticated_header,
1351+
json={
1352+
"provider_config": {
1353+
"id": "test-embedding-provider",
1354+
"type": "test_embedding_runtime_error",
1355+
}
1356+
},
1357+
)
1358+
1359+
assert response.status_code == 200
1360+
data = await response.get_json()
1361+
assert data["status"] == "error"
1362+
assert "获取嵌入模型列表失败" in data["message"]
1363+
assert _ErrorEmbeddingProvider.terminate_calls == 1

tests/test_gemini_embedding_source.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99

1010
class _FakeModelsPager:
1111
def __init__(self, models) -> None:
12-
self.page = models
12+
self._models = models
13+
14+
def __aiter__(self):
15+
async def _gen():
16+
for model in self._models:
17+
yield model
18+
19+
return _gen()
1320

1421

1522
class _FakeModelsAPI:

0 commit comments

Comments
 (0)