Skip to content

Commit b95674a

Browse files
committed
add more tests
1 parent 6369c61 commit b95674a

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

tests/unit_tests/test_core.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,39 @@ def test_llm_factory_openai(self, mock_settings, mock_chat_openai):
177177
llm_factory("default")
178178
mock_chat_openai.assert_called_once()
179179

180+
@patch("core.llm.AzureChatOpenAI")
181+
@patch("core.llm.settings")
182+
def test_llm_factory_azure(self, mock_settings, mock_azure):
183+
mock_settings.llm.profiles = {
184+
"azure": MagicMock(
185+
provider="azure",
186+
model="gpt-4",
187+
api_key="key",
188+
api_base="endpoint",
189+
api_version="2023-05-15"
190+
)
191+
}
192+
llm_factory("azure")
193+
mock_azure.assert_called_once()
194+
195+
@patch("core.llm.ChatAnthropic")
196+
@patch("core.llm.settings")
197+
def test_llm_factory_anthropic(self, mock_settings, mock_anthropic):
198+
mock_settings.llm.profiles = {
199+
"claude": MagicMock(provider="anthropic", model="claude-3", api_key="key")
200+
}
201+
llm_factory("claude")
202+
mock_anthropic.assert_called_once()
203+
204+
@patch("core.llm.ChatGoogleGenerativeAI")
205+
@patch("core.llm.settings")
206+
def test_llm_factory_google(self, mock_settings, mock_google):
207+
mock_settings.llm.profiles = {
208+
"gemini": MagicMock(provider="gemini", model="gemini-pro", credentials="creds")
209+
}
210+
llm_factory("gemini")
211+
mock_google.assert_called_once()
212+
180213
@patch("core.llm.settings")
181214
def test_llm_factory_unknown_provider(self, mock_settings):
182215
mock_settings.llm.profiles = {

tests/unit_tests/test_graph.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,13 @@ def test_route_workflow(self):
173173
route_workflow({"next_worker": "common_agent", "messages": [dummy_msg]}), "common_agent"
174174
)
175175

176+
def test_route_workflow_unknown(self):
177+
"""Test routing with an unknown worker."""
178+
dummy_msg = HumanMessage(content="test")
179+
self.assertEqual(
180+
route_workflow({"next_worker": "unknown", "messages": [dummy_msg]}), "unknown"
181+
)
182+
176183
def test_route_after_worker(self):
177184
"""Test routing logic after a worker agent finishes."""
178185
# Tool calls present -> route to tools
@@ -241,6 +248,21 @@ def test_detect_tool_loop(self):
241248
}
242249
self.assertTrue(detect_tool_loop(state_alternating))
243250

251+
def test_detect_tool_loop_different_args(self):
252+
"""Test that same tool with different args is not a loop."""
253+
call_a = {"name": "tool_a", "args": {"x": 1}, "id": "1"}
254+
call_b = {"name": "tool_a", "args": {"x": 2}, "id": "2"}
255+
state = {
256+
"messages": [
257+
AIMessage(content="", tool_calls=[call_a]),
258+
ToolMessage(content="res", tool_call_id="1", name="tool_a"),
259+
AIMessage(content="", tool_calls=[call_b]),
260+
ToolMessage(content="res", tool_call_id="2", name="tool_a"),
261+
AIMessage(content="", tool_calls=[call_a]),
262+
]
263+
}
264+
self.assertFalse(detect_tool_loop(state))
265+
244266
def test_route_after_worker_loop_detection(self):
245267
"""Test that route_after_worker catches loops."""
246268
tool_call = {"name": "tool_a", "args": {"x": 1}, "id": "1"}
@@ -578,6 +600,17 @@ def test_update_step_status(self):
578600
self.assertEqual(steps_status[0], "completed")
579601
self.assertEqual(steps_status[1], "skipped")
580602

603+
def test_update_step_status_invalid_id(self):
604+
"""Test update_step_status with an ID that doesn't exist."""
605+
steps_status = ["pending"]
606+
decision = MagicMock()
607+
decision.current_step_id = 99
608+
decision.current_step_status = "completed"
609+
decision.skipped_step_ids = []
610+
611+
update_step_status(decision, steps_status)
612+
self.assertEqual(steps_status[0], "pending")
613+
581614
def test_check_completion(self):
582615
decision = MagicMock()
583616
decision.next_worker = "supervisor"

tests/unit_tests/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def test_context_formatting(self):
4343
res = StatusGenerator.parse_tool_call("test", {"items": [1, 2, 3]})
4444
self.assertIn(": 3 items", res)
4545

46+
def test_parse_tool_call_no_args(self):
47+
res = StatusGenerator.parse_tool_call("simple_tool", {})
48+
self.assertIn("Processing", res)
49+
self.assertIn("Simple", res)
50+
4651
class TestLoggerUtils(unittest.TestCase):
4752
def test_parse_size(self):
4853
self.assertEqual(_parse_size("10B"), 10)
@@ -59,6 +64,9 @@ def test_parse_size_case_insensitive(self):
5964
def test_parse_size_float(self):
6065
self.assertEqual(_parse_size("1.5 KB"), int(1.5 * 1024))
6166

67+
def test_parse_size_no_space(self):
68+
self.assertEqual(_parse_size("1KB"), 1024)
69+
6270
@patch("src.utils.logger.logging")
6371
@patch("src.utils.logger.RotatingFileHandler")
6472
@patch("src.utils.logger.settings")

0 commit comments

Comments
 (0)