Skip to content

Commit 647c798

Browse files
feat: 支持 Gemini 嵌入模型列表发现
1 parent 8568f3e commit 647c798

2 files changed

Lines changed: 152 additions & 1 deletion

File tree

astrbot/core/provider/sources/gemini_embedding_source.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import cast
1+
from typing import Any, cast
22

33
from google import genai
44
from google.genai import types
@@ -78,10 +78,58 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
7878
except APIError as e:
7979
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
8080

81+
async def get_models(self) -> list[str]:
82+
try:
83+
models = await self.client.models.list()
84+
all_model_ids: list[str] = []
85+
embedding_model_ids: list[str] = []
86+
87+
for model in getattr(models, "page", []):
88+
model_id = self._extract_model_id(model)
89+
if not model_id:
90+
continue
91+
all_model_ids.append(model_id)
92+
if self._supports_embedding(model, model_id):
93+
embedding_model_ids.append(model_id)
94+
95+
all_model_ids = sorted(dict.fromkeys(all_model_ids))
96+
embedding_model_ids = sorted(dict.fromkeys(embedding_model_ids))
97+
98+
return embedding_model_ids or all_model_ids
99+
except Exception as e:
100+
raise Exception(f"获取 Gemini 嵌入模型列表失败: {e!s}") from e
101+
81102
def get_dim(self) -> int:
82103
"""获取向量的维度"""
83104
return int(self.provider_config.get("embedding_dimensions", 768))
84105

85106
async def terminate(self):
86107
if self.client:
87108
await self.client.aclose()
109+
110+
@staticmethod
111+
def _extract_model_id(model: Any) -> str:
112+
model_name = getattr(model, "name", "") or getattr(model, "model", "")
113+
if not model_name:
114+
return ""
115+
return str(model_name).removeprefix("models/")
116+
117+
@classmethod
118+
def _supports_embedding(cls, model: Any, model_id: str) -> bool:
119+
supported_actions = getattr(model, "supported_actions", None) or getattr(
120+
model, "supported_generation_methods", []
121+
)
122+
if isinstance(supported_actions, list):
123+
normalized_actions = {
124+
str(action).lower().replace("_", "").replace("-", "")
125+
for action in supported_actions
126+
}
127+
if "embedcontent" in normalized_actions:
128+
return True
129+
130+
return cls._looks_like_embedding_model(model_id)
131+
132+
@staticmethod
133+
def _looks_like_embedding_model(model_id: str) -> bool:
134+
normalized_model_id = model_id.lower()
135+
return "embedding" in normalized_model_id or "embed" in normalized_model_id
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from types import SimpleNamespace
2+
3+
import pytest
4+
5+
from astrbot.core.provider.sources.gemini_embedding_source import (
6+
GeminiEmbeddingProvider,
7+
)
8+
9+
10+
class _FakeModelsPager:
11+
def __init__(self, models) -> None:
12+
self.page = models
13+
14+
15+
class _FakeModelsAPI:
16+
def __init__(self, models) -> None:
17+
self._models = models
18+
19+
async def list(self):
20+
return _FakeModelsPager(self._models)
21+
22+
23+
class _FakeClient:
24+
def __init__(self, models) -> None:
25+
self.models = _FakeModelsAPI(models)
26+
self.closed = False
27+
28+
async def aclose(self):
29+
self.closed = True
30+
31+
32+
def _make_provider() -> GeminiEmbeddingProvider:
33+
provider_config = {
34+
"id": "test-gemini-embedding",
35+
"type": "gemini_embedding",
36+
"embedding_api_key": "test-key",
37+
"embedding_api_base": "https://generativelanguage.googleapis.com",
38+
"embedding_model": "gemini-embedding-exp-03-07",
39+
}
40+
return GeminiEmbeddingProvider(
41+
provider_config=provider_config,
42+
provider_settings={},
43+
)
44+
45+
46+
@pytest.mark.asyncio
47+
async def test_gemini_embedding_get_models_prefers_embedcontent_capability():
48+
provider = _make_provider()
49+
try:
50+
provider.client = _FakeClient(
51+
[
52+
SimpleNamespace(
53+
name="models/gemini-2.5-flash",
54+
supported_actions=["generateContent"],
55+
),
56+
SimpleNamespace(
57+
name="models/gemini-embedding-001",
58+
supported_actions=["embedContent"],
59+
),
60+
SimpleNamespace(
61+
name="models/text-embedding-preview",
62+
supported_generation_methods=["embedContent"],
63+
),
64+
]
65+
)
66+
models = await provider.get_models()
67+
assert models == ["gemini-embedding-001", "text-embedding-preview"]
68+
finally:
69+
await provider.terminate()
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_gemini_embedding_get_models_falls_back_to_name_matching():
74+
provider = _make_provider()
75+
try:
76+
provider.client = _FakeClient(
77+
[
78+
SimpleNamespace(name="models/chat-pro"),
79+
SimpleNamespace(name="models/embed-lite"),
80+
SimpleNamespace(name="models/text-embedding-preview"),
81+
]
82+
)
83+
models = await provider.get_models()
84+
assert models == ["embed-lite", "text-embedding-preview"]
85+
finally:
86+
await provider.terminate()
87+
88+
89+
@pytest.mark.asyncio
90+
async def test_gemini_embedding_get_models_falls_back_to_all_when_no_match():
91+
provider = _make_provider()
92+
try:
93+
provider.client = _FakeClient(
94+
[
95+
SimpleNamespace(name="models/gemini-2.5-pro"),
96+
SimpleNamespace(name="models/gemini-2.5-flash"),
97+
SimpleNamespace(name="models/gemini-2.5-pro"),
98+
]
99+
)
100+
models = await provider.get_models()
101+
assert models == ["gemini-2.5-flash", "gemini-2.5-pro"]
102+
finally:
103+
await provider.terminate()

0 commit comments

Comments
 (0)