Skip to content

Commit 5054f92

Browse files
author
jiangpeiling
committed
♻️ refactor me model module.
1 parent 6c6f5cf commit 5054f92

File tree

4 files changed

+67
-66
lines changed

4 files changed

+67
-66
lines changed

backend/apps/me_model_managment_app.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from fastapi import APIRouter, Query
44
from fastapi.responses import JSONResponse
55

6-
from consts.model import ModelResponse
76
from services.me_model_management_service import get_me_models_impl
87
from services.model_health_service import check_me_model_connectivity, check_me_connectivity_impl
98

@@ -20,7 +19,6 @@ async def get_me_models(
2019
"""
2120
Get list of models from model engine API
2221
"""
23-
# Call service function to get business logic result
2422
code, message, data = await get_me_models_impl(timeout=timeout, type=type)
2523
return JSONResponse(
2624
status_code=HTTPStatus.OK,
@@ -37,7 +35,6 @@ async def check_me_connectivity(timeout: int = Query(default=2, description="Tim
3735
"""
3836
Health check from model engine API
3937
"""
40-
# Call service function to health check
4138
code, message, data = await check_me_connectivity_impl(timeout)
4239
return JSONResponse(
4340
status_code=HTTPStatus.OK,
@@ -49,8 +46,15 @@ async def check_me_connectivity(timeout: int = Query(default=2, description="Tim
4946
)
5047

5148

52-
@router.get("/model/healthcheck", response_model=ModelResponse)
49+
@router.get("/model/healthcheck")
5350
async def check_me_model_healthcheck(
5451
model_name: str = Query(..., description="Model name to check")
5552
):
56-
return await check_me_model_connectivity(model_name)
53+
code, message, data = await check_me_model_connectivity(model_name)
54+
return JSONResponse(
55+
status_code=HTTPStatus.OK,
56+
content={
57+
"code": code,
58+
"message": message,
59+
"data": data
60+
})

backend/services/model_health_service.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ async def _embedding_dimension_check(
2525
model_type: str,
2626
model_base_url: str,
2727
model_api_key: str):
28-
2928
# Test connectivity based on different model types
3029
if model_type == "embedding":
3130
embedding = await OpenAICompatibleEmbedding(
@@ -50,11 +49,11 @@ async def _embedding_dimension_check(
5049

5150

5251
async def _perform_connectivity_check(
53-
model_name: str,
54-
model_type: str,
55-
model_base_url: str,
56-
model_api_key: str,
57-
embedding_dim: int = 1024
52+
model_name: str,
53+
model_type: str,
54+
model_base_url: str,
55+
model_api_key: str,
56+
embedding_dim: int = 1024
5857
) -> bool:
5958
"""
6059
Perform specific model connectivity check
@@ -183,8 +182,9 @@ async def check_me_model_connectivity(model_name: str):
183182
model_data = next(
184183
(item for item in result if item['id'] == model_name), None)
185184
if not model_data:
186-
return ModelResponse(code=404, message="Specified model not found",
187-
data={"connectivity": False, "message": "Specified model not found", "connect_status": ""})
185+
return HTTPStatus.NOT_FOUND, "Specified model not found", {"connectivity": False,
186+
"message": "Specified model not found",
187+
"connect_status": ""}
188188

189189
model_type = model_data['type']
190190

@@ -205,27 +205,28 @@ async def check_me_model_connectivity(model_name: str):
205205
json=payload
206206
)
207207
else:
208-
return ModelResponse(code=400, message=f"Health check not supported for {model_type} type models",
209-
data={"connectivity": False, "message": f"Health check not supported for {model_type} type models",
210-
"connect_status": ModelConnectStatusEnum.UNAVAILABLE.value})
208+
return HTTPStatus.BAD_REQUEST, f"Health check not supported for {model_type} type models", {
209+
"connectivity": False, "message": f"Health check not supported for {model_type} type models",
210+
"connect_status": ModelConnectStatusEnum.UNAVAILABLE.value}
211211

212212
status_code = api_response.status_code
213213
response_text = api_response.text
214214

215-
if status_code == 200:
215+
if status_code == HTTPStatus.OK:
216216
connect_status = ModelConnectStatusEnum.AVAILABLE.value
217-
return ModelResponse(code=200, message=f"Model {model_name} responded normally",
218-
data={"connectivity": True, "message": f"Model {model_name} responded normally", "connect_status": connect_status})
217+
return HTTPStatus.OK, f"Model {model_name} responded normally", {"connectivity": True,
218+
"message": f"Model {model_name} responded normally",
219+
"connect_status": connect_status}
219220
else:
220221
connect_status = ModelConnectStatusEnum.UNAVAILABLE.value
221-
return ModelResponse(code=status_code, message=f"Model {model_name} response failed",
222-
data={"connectivity": False, "message": f"Model {model_name} response failed: {response_text}",
223-
"connect_status": connect_status})
222+
return status_code, f"Model {model_name} response failed", {"connectivity": False,
223+
"message": f"Model {model_name} response failed: {response_text}",
224+
"connect_status": connect_status}
224225

225226
except Exception as e:
226-
return ModelResponse(code=500, message=f"Unknown error occurred: {str(e)}",
227-
data={"connectivity": False, "message": f"Unknown error occurred: {str(e)}",
228-
"connect_status": ModelConnectStatusEnum.UNAVAILABLE.value})
227+
return HTTPStatus.INTERNAL_SERVER_ERROR, f"Unknown error occurred: {str(e)}", {"connectivity": False,
228+
"message": f"Unknown error occurred: {str(e)}",
229+
"connect_status": ModelConnectStatusEnum.UNAVAILABLE.value}
229230

230231

231232
async def check_me_connectivity_impl(timeout: int):
@@ -240,13 +241,13 @@ async def check_me_connectivity_impl(timeout: int):
240241
headers = {'Authorization': f'Bearer {MODEL_ENGINE_APIKEY}'}
241242

242243
async with aiohttp.ClientSession(
243-
timeout=aiohttp.ClientTimeout(total=timeout),
244-
connector=aiohttp.TCPConnector(ssl=False)
244+
timeout=aiohttp.ClientTimeout(total=timeout),
245+
connector=aiohttp.TCPConnector(ssl=False)
245246
) as session:
246247
try:
247248
async with session.get(
248-
f"{MODEL_ENGINE_HOST}/open/router/v1/models",
249-
headers=headers
249+
f"{MODEL_ENGINE_HOST}/open/router/v1/models",
250+
headers=headers
250251
) as response:
251252
if response.status == HTTPStatus.OK:
252253
return (
@@ -282,10 +283,10 @@ async def check_me_connectivity_impl(timeout: int):
282283
except Exception as e:
283284
return (
284285
HTTPStatus.INTERNAL_SERVER_ERROR,
285-
f"Connection failed: {str(e)}",
286+
f"Unknown error occurred: {str(e)}",
286287
{
287288
"status": "Disconnected",
288-
"desc": f"Connection failed: {str(e)}",
289+
"desc": f"Unknown error occurred: {str(e)}",
289290
"connect_status": ModelConnectStatusEnum.UNAVAILABLE.value
290291
}
291292
)

test/backend/app/test_me_model_managment_app.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,7 @@ async def mock_list_models_not_found(type: str = Query(None)):
220220

221221
# Filter models if type is provided
222222
if type:
223-
filtered_models = [
224-
model for model in all_models if model["type"] == type]
223+
filtered_models = [model for model in all_models if model["type"] == type]
225224

226225
# Return 404 if no models found with this type
227226
if not filtered_models:

test/backend/services/test_model_health_service.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ async def test_perform_connectivity_check_embedding():
130130
# Setup
131131
with mock.patch("backend.services.model_health_service.OpenAICompatibleEmbedding") as mock_embedding:
132132
mock_embedding_instance = mock.MagicMock()
133-
mock_embedding_instance.dimension_check = mock.AsyncMock(return_value=[
134-
1])
133+
mock_embedding_instance.dimension_check = mock.AsyncMock(return_value=[1])
135134
mock_embedding.return_value = mock_embedding_instance
136135

137136
# Execute
@@ -192,8 +191,7 @@ async def test_perform_connectivity_check_llm():
192191
mock_observer.return_value = mock_observer_instance
193192

194193
mock_model_instance = mock.MagicMock()
195-
mock_model_instance.check_connectivity = mock.AsyncMock(
196-
return_value=True)
194+
mock_model_instance.check_connectivity = mock.AsyncMock(return_value=True)
197195
mock_model.return_value = mock_model_instance
198196

199197
# Execute
@@ -224,8 +222,7 @@ async def test_perform_connectivity_check_vlm():
224222
mock_observer.return_value = mock_observer_instance
225223

226224
mock_model_instance = mock.MagicMock()
227-
mock_model_instance.check_connectivity = mock.AsyncMock(
228-
return_value=True)
225+
mock_model_instance.check_connectivity = mock.AsyncMock(return_value=True)
229226
mock_model.return_value = mock_model_instance
230227

231228
# Execute
@@ -520,12 +517,12 @@ async def test_check_me_model_connectivity_llm_success():
520517
mock_client_instance.post.return_value = mock_test_response
521518

522519
# Execute
523-
response = await check_me_model_connectivity("gpt-4")
520+
code, _, data = await check_me_model_connectivity("gpt-4")
524521

525522
# Assert
526-
assert response.code == 200
527-
assert response.data["connectivity"] is True
528-
assert response.data["connect_status"] == "available"
523+
assert code == 200
524+
assert data["connectivity"] is True
525+
assert data["connect_status"] == "available"
529526

530527
# Verify API calls
531528
mock_client_instance.get.assert_called_once_with(
@@ -573,12 +570,12 @@ async def test_check_me_model_connectivity_embedding_success():
573570
mock_client_instance.post.return_value = mock_test_response
574571

575572
# Execute
576-
response = await check_me_model_connectivity("text-embedding-ada-002")
573+
code, _, data = await check_me_model_connectivity("text-embedding-ada-002")
577574

578575
# Assert
579-
assert response.code == 200
580-
assert response.data["connectivity"] is True
581-
assert response.data["connect_status"] == "available"
576+
assert code == 200
577+
assert data["connectivity"] is True
578+
assert data["connect_status"] == "available"
582579

583580
# Verify API calls
584581
mock_client_instance.get.assert_called_once_with(
@@ -614,12 +611,12 @@ async def test_check_me_model_connectivity_model_not_found():
614611
mock_client_instance.get.return_value = mock_response_obj
615612

616613
# Execute
617-
response = await check_me_model_connectivity("nonexistent-model")
614+
code, _, data = await check_me_model_connectivity("nonexistent-model")
618615

619616
# Assert
620-
assert response.code == 404
621-
assert response.data["connectivity"] is False
622-
assert response.data["message"] == "Specified model not found"
617+
assert code == 404
618+
assert data["connectivity"] is False
619+
assert data["message"] == "Specified model not found"
623620

624621

625622
@pytest.mark.asyncio
@@ -647,13 +644,13 @@ async def test_check_me_model_connectivity_unsupported_type():
647644
mock_client_instance.get.return_value = mock_response_obj
648645

649646
# Execute
650-
response = await check_me_model_connectivity("unsupported-model")
647+
code, _, data = await check_me_model_connectivity("unsupported-model")
651648

652649
# Assert
653-
assert response.code == 400
654-
assert response.data["connectivity"] is False
655-
assert response.data["connect_status"] == "unavailable"
656-
assert "Health check not supported" in response.data["message"]
650+
assert code == 400
651+
assert data["connectivity"] is False
652+
assert data["connect_status"] == "unavailable"
653+
assert "Health check not supported" in data["message"]
657654

658655

659656
@pytest.mark.asyncio
@@ -687,13 +684,13 @@ async def test_check_me_model_connectivity_api_error():
687684
mock_client_instance.post.return_value = mock_test_response
688685

689686
# Execute
690-
response = await check_me_model_connectivity("gpt-4")
687+
code, _, data = await check_me_model_connectivity("gpt-4")
691688

692689
# Assert
693-
assert response.code == 500
694-
assert response.data["connectivity"] is False
695-
assert response.data["connect_status"] == "unavailable"
696-
assert "response failed" in response.data["message"]
690+
assert code == 500
691+
assert data["connectivity"] is False
692+
assert data["connect_status"] == "unavailable"
693+
assert "response failed" in data["message"]
697694

698695

699696
@pytest.mark.asyncio
@@ -711,13 +708,13 @@ async def test_check_me_model_connectivity_exception():
711708
"Connection error")
712709

713710
# Execute
714-
response = await check_me_model_connectivity("gpt-4")
711+
code, _, data = await check_me_model_connectivity("gpt-4")
715712

716713
# Assert
717-
assert response.code == 500
718-
assert response.data["connectivity"] is False
719-
assert response.data["connect_status"] == "unavailable"
720-
assert "Unknown error" in response.data["message"]
714+
assert code == 500
715+
assert data["connectivity"] is False
716+
assert data["connect_status"] == "unavailable"
717+
assert "Unknown error" in data["message"]
721718

722719

723720
@pytest.mark.asyncio

0 commit comments

Comments
 (0)