Skip to content

Commit 2285ea2

Browse files
committed
Add comprehensive unit tests for the circuit breaker plugin and remove an unnecessary configuration lookup.
Signed-off-by: Keval Mahajan [email protected]
1 parent b99cb2b commit 2285ea2

File tree

2 files changed

+348
-1
lines changed

2 files changed

+348
-1
lines changed

plugins/circuit_breaker/circuit_breaker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo
164164
"""
165165
tool = payload.name
166166
st = _get_state(tool)
167-
cfg = _cfg_for(self._cfg, tool)
168167
now = _now()
169168

170169
# Check if cooldown has elapsed - transition to half-open state
Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
# -*- coding: utf-8 -*-
2+
"""Tests for Circuit Breaker Plugin.
3+
4+
Verifies all functionality:
5+
1. Closed state - allows requests
6+
2. Opens on consecutive failures
7+
3. Opens on error rate threshold
8+
4. Blocks requests when open
9+
5. Half-open state - allows probe request
10+
6. Closes on successful probe
11+
7. Reopens on failed probe
12+
8. Timeout failures trigger circuit breaker
13+
9. retry_after_seconds calculation
14+
10. Per-tool configuration overrides
15+
"""
16+
17+
import asyncio
18+
import pytest
19+
import time
20+
from unittest.mock import MagicMock, patch
21+
22+
# Import the circuit breaker components
23+
from plugins.circuit_breaker.circuit_breaker import (
24+
CircuitBreakerPlugin,
25+
CircuitBreakerConfig,
26+
_ToolState,
27+
_STATE,
28+
_get_state,
29+
_cfg_for,
30+
_is_error,
31+
)
32+
from mcpgateway.plugins.framework import (
33+
PluginConfig,
34+
PluginContext,
35+
ToolPreInvokePayload,
36+
ToolPostInvokePayload,
37+
GlobalContext,
38+
)
39+
40+
41+
@pytest.fixture(autouse=True)
42+
def clear_state():
43+
"""Clear circuit breaker state before each test."""
44+
_STATE.clear()
45+
yield
46+
_STATE.clear()
47+
48+
49+
@pytest.fixture
50+
def plugin():
51+
"""Create a circuit breaker plugin with test configuration."""
52+
config = PluginConfig(
53+
id="test-cb",
54+
kind="circuit_breaker",
55+
name="Test Circuit Breaker",
56+
enabled=True,
57+
order=0,
58+
config={
59+
"error_rate_threshold": 0.5,
60+
"window_seconds": 60,
61+
"min_calls": 3,
62+
"consecutive_failure_threshold": 3,
63+
"cooldown_seconds": 30,
64+
},
65+
)
66+
return CircuitBreakerPlugin(config)
67+
68+
69+
@pytest.fixture
70+
def context():
71+
"""Create a plugin context for testing."""
72+
global_ctx = GlobalContext(request_id="test-request-123")
73+
return PluginContext(plugin_id="test-cb", global_context=global_ctx)
74+
75+
76+
class TestCircuitBreakerClosedState:
77+
"""Test circuit breaker in closed state."""
78+
79+
@pytest.mark.asyncio
80+
async def test_allows_requests_when_closed(self, plugin, context):
81+
"""Closed circuit should allow requests through."""
82+
payload = ToolPreInvokePayload(name="test_tool", args={})
83+
result = await plugin.tool_pre_invoke(payload, context)
84+
85+
assert result.continue_processing is True
86+
assert result.violation is None
87+
88+
@pytest.mark.asyncio
89+
async def test_records_call_timestamp(self, plugin, context):
90+
"""Pre-invoke should record call timestamp in context."""
91+
payload = ToolPreInvokePayload(name="test_tool", args={})
92+
await plugin.tool_pre_invoke(payload, context)
93+
94+
call_time = context.get_state("cb_call_time")
95+
assert call_time is not None
96+
assert abs(call_time - time.time()) < 1 # Within 1 second
97+
98+
99+
class TestCircuitBreakerOpening:
100+
"""Test circuit breaker opening conditions."""
101+
102+
@pytest.mark.asyncio
103+
async def test_opens_on_consecutive_failures(self, plugin, context):
104+
"""Circuit should open after consecutive_failure_threshold failures."""
105+
tool = "test_tool"
106+
107+
# Simulate 3 consecutive failures
108+
for _ in range(3):
109+
pre_payload = ToolPreInvokePayload(name=tool, args={})
110+
await plugin.tool_pre_invoke(pre_payload, context)
111+
112+
post_payload = ToolPostInvokePayload(name=tool, result={"is_error": True})
113+
result = await plugin.tool_post_invoke(post_payload, context)
114+
115+
# Check circuit is now open
116+
assert result.metadata["circuit_open_until"] > 0
117+
assert result.metadata["circuit_consecutive_failures"] == 3
118+
119+
@pytest.mark.asyncio
120+
async def test_opens_on_error_rate_threshold(self, plugin, context):
121+
"""Circuit should open when error rate exceeds threshold."""
122+
tool = "test_tool"
123+
124+
# Simulate 3 calls (min_calls): 2 failures, 1 success = 66% error rate > 50% threshold
125+
for i in range(3):
126+
pre_payload = ToolPreInvokePayload(name=tool, args={})
127+
await plugin.tool_pre_invoke(pre_payload, context)
128+
129+
is_error = i < 2 # First 2 are failures
130+
post_payload = ToolPostInvokePayload(name=tool, result={"is_error": is_error})
131+
result = await plugin.tool_post_invoke(post_payload, context)
132+
133+
# Check circuit is now open
134+
assert result.metadata["circuit_open_until"] > 0
135+
assert result.metadata["circuit_failure_rate"] >= 0.5
136+
137+
138+
class TestCircuitBreakerOpenState:
139+
"""Test circuit breaker in open state."""
140+
141+
@pytest.mark.asyncio
142+
async def test_blocks_requests_when_open(self, plugin, context):
143+
"""Open circuit should block requests."""
144+
tool = "test_tool"
145+
st = _get_state(tool)
146+
st.open_until = time.time() + 30 # Open for 30 seconds
147+
148+
payload = ToolPreInvokePayload(name=tool, args={})
149+
result = await plugin.tool_pre_invoke(payload, context)
150+
151+
assert result.continue_processing is False
152+
assert result.violation is not None
153+
assert result.violation.code == "CIRCUIT_OPEN"
154+
155+
@pytest.mark.asyncio
156+
async def test_returns_retry_after_seconds(self, plugin, context):
157+
"""Open circuit should return retry_after_seconds in violation details."""
158+
tool = "test_tool"
159+
st = _get_state(tool)
160+
st.open_until = time.time() + 30 # Open for 30 seconds
161+
162+
payload = ToolPreInvokePayload(name=tool, args={})
163+
result = await plugin.tool_pre_invoke(payload, context)
164+
165+
assert result.violation is not None
166+
assert "retry_after_seconds" in result.violation.details
167+
assert result.violation.details["retry_after_seconds"] > 0
168+
assert result.violation.details["retry_after_seconds"] <= 30
169+
170+
171+
class TestCircuitBreakerHalfOpenState:
172+
"""Test circuit breaker half-open state."""
173+
174+
@pytest.mark.asyncio
175+
async def test_transitions_to_half_open_after_cooldown(self, plugin, context):
176+
"""Circuit should transition to half-open after cooldown."""
177+
tool = "test_tool"
178+
st = _get_state(tool)
179+
st.open_until = time.time() - 1 # Cooldown elapsed
180+
st.consecutive_failures = 5
181+
182+
payload = ToolPreInvokePayload(name=tool, args={})
183+
result = await plugin.tool_pre_invoke(payload, context)
184+
185+
# Should allow request through (half-open)
186+
assert result.continue_processing is True
187+
assert st.half_open is True
188+
assert context.get_state("cb_half_open_test") is True
189+
190+
@pytest.mark.asyncio
191+
async def test_closes_on_successful_probe(self, plugin, context):
192+
"""Half-open circuit should close on successful probe."""
193+
tool = "test_tool"
194+
st = _get_state(tool)
195+
st.half_open = True
196+
st.consecutive_failures = 5
197+
198+
# Set context for half-open test
199+
context.set_state("cb_half_open_test", True)
200+
context.set_state("cb_call_time", time.time())
201+
202+
# Successful probe
203+
post_payload = ToolPostInvokePayload(name=tool, result={"is_error": False})
204+
result = await plugin.tool_post_invoke(post_payload, context)
205+
206+
# Circuit should be fully closed
207+
assert st.half_open is False
208+
assert st.consecutive_failures == 0
209+
assert result.metadata["circuit_open_until"] == 0.0
210+
211+
@pytest.mark.asyncio
212+
async def test_reopens_on_failed_probe(self, plugin, context):
213+
"""Half-open circuit should reopen immediately on failed probe."""
214+
tool = "test_tool"
215+
st = _get_state(tool)
216+
st.half_open = True
217+
st.consecutive_failures = 5
218+
219+
# Set context for half-open test
220+
context.set_state("cb_half_open_test", True)
221+
context.set_state("cb_call_time", time.time())
222+
223+
# Failed probe
224+
post_payload = ToolPostInvokePayload(name=tool, result={"is_error": True})
225+
with patch("mcpgateway.services.metrics.circuit_breaker_open_counter") as mock_counter:
226+
mock_counter.labels.return_value.inc = MagicMock()
227+
result = await plugin.tool_post_invoke(post_payload, context)
228+
229+
# Circuit should be reopened
230+
assert st.half_open is False
231+
assert result.metadata["circuit_open_until"] > time.time()
232+
233+
234+
class TestTimeoutIntegration:
235+
"""Test timeout integration with circuit breaker."""
236+
237+
@pytest.mark.asyncio
238+
async def test_timeout_counted_as_failure(self, plugin, context):
239+
"""Timeout flag should be counted as failure."""
240+
tool = "test_tool"
241+
242+
# Set timeout flag (as tool_service would do)
243+
context.set_state("cb_timeout_failure", True)
244+
context.set_state("cb_call_time", time.time())
245+
246+
# Post-invoke with a technically successful result but timeout flag set
247+
post_payload = ToolPostInvokePayload(name=tool, result={"is_error": False})
248+
result = await plugin.tool_post_invoke(post_payload, context)
249+
250+
# Should count as failure
251+
assert result.metadata["circuit_failures_in_window"] == 1
252+
assert result.metadata["circuit_consecutive_failures"] == 1
253+
254+
255+
class TestPerToolOverrides:
256+
"""Test per-tool configuration overrides."""
257+
258+
@pytest.mark.asyncio
259+
async def test_tool_override_applied(self, context):
260+
"""Per-tool overrides should be applied correctly."""
261+
config = PluginConfig(
262+
id="test-cb",
263+
kind="circuit_breaker",
264+
name="Test Circuit Breaker",
265+
enabled=True,
266+
order=0,
267+
config={
268+
"consecutive_failure_threshold": 5,
269+
"tool_overrides": {
270+
"critical_tool": {"consecutive_failure_threshold": 10}
271+
},
272+
},
273+
)
274+
plugin = CircuitBreakerPlugin(config)
275+
276+
# Simulate 5 failures on critical_tool (should NOT open - needs 10)
277+
for _ in range(5):
278+
pre_payload = ToolPreInvokePayload(name="critical_tool", args={})
279+
await plugin.tool_pre_invoke(pre_payload, context)
280+
281+
post_payload = ToolPostInvokePayload(name="critical_tool", result={"is_error": True})
282+
result = await plugin.tool_post_invoke(post_payload, context)
283+
284+
# Circuit should still be closed (needs 10 failures)
285+
assert result.metadata["circuit_open_until"] == 0.0
286+
assert result.metadata["circuit_consecutive_failures"] == 5
287+
288+
289+
class TestHelperFunctions:
290+
"""Test helper functions."""
291+
292+
def test_is_error_with_dict(self):
293+
"""_is_error should detect error in dict result."""
294+
assert _is_error({"is_error": True}) is True
295+
assert _is_error({"is_error": False}) is False
296+
assert _is_error({"success": True}) is False
297+
298+
def test_is_error_with_object(self):
299+
"""_is_error should detect error in object result."""
300+
class MockResult:
301+
is_error = True
302+
303+
assert _is_error(MockResult()) is True
304+
MockResult.is_error = False
305+
assert _is_error(MockResult()) is False
306+
307+
def test_cfg_for_with_override(self):
308+
"""_cfg_for should merge tool overrides."""
309+
base_cfg = CircuitBreakerConfig(
310+
consecutive_failure_threshold=5,
311+
tool_overrides={"special_tool": {"consecutive_failure_threshold": 10}},
312+
)
313+
314+
merged = _cfg_for(base_cfg, "special_tool")
315+
assert merged.consecutive_failure_threshold == 10
316+
317+
default = _cfg_for(base_cfg, "regular_tool")
318+
assert default.consecutive_failure_threshold == 5
319+
320+
321+
class TestWindowEviction:
322+
"""Test time window eviction logic."""
323+
324+
@pytest.mark.asyncio
325+
async def test_old_entries_evicted(self, plugin, context):
326+
"""Old call/failure entries should be evicted after window expires."""
327+
tool = "test_tool"
328+
st = _get_state(tool)
329+
330+
# Add old entries (outside window)
331+
old_time = time.time() - 120 # 2 minutes ago
332+
st.calls.append(old_time)
333+
st.failures.append(old_time)
334+
335+
# Make a new call
336+
pre_payload = ToolPreInvokePayload(name=tool, args={})
337+
await plugin.tool_pre_invoke(pre_payload, context)
338+
339+
post_payload = ToolPostInvokePayload(name=tool, result={"is_error": False})
340+
result = await plugin.tool_post_invoke(post_payload, context)
341+
342+
# Old entries should be evicted
343+
assert result.metadata["circuit_calls_in_window"] == 1
344+
assert result.metadata["circuit_failures_in_window"] == 0
345+
346+
347+
if __name__ == "__main__":
348+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)