diff --git a/mcpgateway/services/metrics.py b/mcpgateway/services/metrics.py index 0b4d51cba1..3b35ddaf82 100644 --- a/mcpgateway/services/metrics.py +++ b/mcpgateway/services/metrics.py @@ -44,12 +44,26 @@ # Third-Party from fastapi import Response, status -from prometheus_client import Gauge, REGISTRY +from prometheus_client import Counter, Gauge, REGISTRY from prometheus_fastapi_instrumentator import Instrumentator # First-Party from mcpgateway.config import settings +# Global Metrics +# Exposed for import by services/plugins to increment counters +tool_timeout_counter = Counter( + "tool_timeout_total", + "Total number of tool invocation timeouts", + ["tool_name"], +) + +circuit_breaker_open_counter = Counter( + "circuit_breaker_open_total", + "Total number of times circuit breaker opened", + ["tool_name"], +) + def setup_metrics(app): """ diff --git a/mcpgateway/services/tool_service.py b/mcpgateway/services/tool_service.py index 85a3dd1b16..2d63842325 100644 --- a/mcpgateway/services/tool_service.py +++ b/mcpgateway/services/tool_service.py @@ -15,6 +15,7 @@ """ # Standard +import asyncio import base64 import binascii from datetime import datetime, timezone @@ -387,6 +388,15 @@ class ToolInvocationError(ToolError): """ +class ToolTimeoutError(ToolInvocationError): + """Raised when tool invocation times out. + + This subclass is used to distinguish timeout errors from other invocation errors. + Timeout handlers call tool_post_invoke before raising this, so the generic exception + handler should skip calling post_invoke again to avoid double-counting failures. + """ + + class ToolService: """Service for managing and invoking tools. @@ -2497,6 +2507,7 @@ async def invoke_tool( Raises: ToolNotFoundError: If tool not found or access denied. ToolInvocationError: If invocation fails. + ToolTimeoutError: If tool invocation times out. PluginViolationError: If plugin blocks tool invocation. 
PluginError: If encounters issue with plugin @@ -2613,6 +2624,11 @@ async def invoke_tool( tool_oauth_config = tool_payload.get("oauth_config") tool_gateway_id = tool_payload.get("gateway_id") + # Get effective timeout: per-tool timeout_ms (converted to seconds) or global fallback + # timeout_ms is stored in milliseconds, convert to seconds + tool_timeout_ms = tool_payload.get("timeout_ms") + effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else settings.tool_timeout + # Save gateway existence as local boolean BEFORE db.close() # to avoid checking ORM object truthiness after session is closed has_gateway = gateway_payload is not None @@ -2851,10 +2867,53 @@ async def invoke_tool( # Use the tool's request_type rather than defaulting to POST (using local variable) method = tool_request_type.upper() if tool_request_type else "POST" - if method == "GET": - response = await self._http_client.get(final_url, params=payload, headers=headers) - else: - response = await self._http_client.request(method, final_url, json=payload, headers=headers) + rest_start_time = time.time() + try: + if method == "GET": + response = await asyncio.wait_for(self._http_client.get(final_url, params=payload, headers=headers), timeout=effective_timeout) + else: + response = await asyncio.wait_for(self._http_client.request(method, final_url, json=payload, headers=headers), timeout=effective_timeout) + except (asyncio.TimeoutError, httpx.TimeoutException): + rest_elapsed_ms = (time.time() - rest_start_time) * 1000 + structured_logger.log( + level="WARNING", + message=f"REST tool invocation timed out: {tool_name_computed}", + component="tool_service", + correlation_id=get_correlation_id(), + duration_ms=rest_elapsed_ms, + metadata={"event": "tool_timeout", "tool_name": tool_name_computed, "timeout_seconds": effective_timeout}, + ) + + # Manually trigger circuit breaker (or other plugins) on timeout + try: + # First-Party + from mcpgateway.services.metrics import tool_timeout_counter # pylint: 
disable=import-outside-toplevel + + tool_timeout_counter.labels(tool_name=name).inc() + except Exception as exc: + logger.debug( + "Failed to increment tool_timeout_counter for %s: %s", + name, + exc, + exc_info=True, + ) + + if self._plugin_manager: + if context_table: + for ctx in context_table.values(): + ctx.set_state("cb_timeout_failure", True) + + if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): + timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True) + await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=False, + ) + + raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") response.raise_for_status() # Handle 204 No Content responses that have no body @@ -2979,11 +3038,16 @@ def get_httpx_client_factory( ctx = create_ssl_context(gateway_ca_cert) else: ctx = None + + # Use effective_timeout for read operations if not explicitly overridden by caller + # This ensures the underlying client waits at least as long as the tool configuration requires + factory_timeout = timeout if timeout else get_http_timeout(read_timeout=effective_timeout) + return httpx.AsyncClient( verify=ctx if ctx else get_default_verify(), follow_redirects=True, headers=headers, - timeout=timeout if timeout else get_http_timeout(), + timeout=factory_timeout, auth=auth, limits=httpx.Limits( max_connections=settings.httpx_max_connections, @@ -3003,7 +3067,10 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): ToolResult: Result of tool call Raises: + ToolInvocationError: If the tool invocation fails during execution. + ToolTimeoutError: If the tool invocation times out. 
BaseException: On connection or communication errors + """ # Get correlation ID for distributed tracing correlation_id = get_correlation_id() @@ -3048,7 +3115,7 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): user_identity=app_user_email, gateway_id=gateway_id_str, ) as pooled: - tool_call_result = await pooled.session.call_tool(tool_name_original, arguments, meta=meta_data) + tool_call_result = await asyncio.wait_for(pooled.session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout) else: # Non-pooled path: safe to add per-request headers if correlation_id and headers: @@ -3057,7 +3124,7 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): async with sse_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as streams: async with ClientSession(*streams) as session: await session.initialize() - tool_call_result = await session.call_tool(tool_name_original, arguments, meta=meta_data) + tool_call_result = await asyncio.wait_for(session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout) # Log successful MCP call mcp_duration_ms = (time.time() - mcp_start_time) * 1000 @@ -3071,6 +3138,48 @@ async def connect_to_sse_server(server_url: str, headers: dict = headers): ) return tool_call_result + except (asyncio.TimeoutError, httpx.TimeoutException): + # Handle timeout specifically - log and raise ToolTimeoutError + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="WARNING", + message=f"MCP SSE tool invocation timed out: {tool_name_original}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "sse", "timeout_seconds": effective_timeout}, + ) + + # Manually trigger circuit breaker (or other plugins) on timeout + try: + # First-Party + 
from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel + + tool_timeout_counter.labels(tool_name=name).inc() + except Exception as exc: + logger.debug( + "Failed to increment tool_timeout_counter for %s: %s", + name, + exc, + exc_info=True, + ) + + if self._plugin_manager: + if context_table: + for ctx in context_table.values(): + ctx.set_state("cb_timeout_failure", True) + + if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): + timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True) + await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=False, + ) + + raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") except BaseException as e: # Extract root cause from ExceptionGroup (Python 3.11+) # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup @@ -3104,6 +3213,8 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head ToolResult: Result of tool call Raises: + ToolInvocationError: If the tool invocation fails during execution. + ToolTimeoutError: If the tool invocation times out. 
BaseException: On connection or communication errors """ # Get correlation ID for distributed tracing @@ -3149,7 +3260,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head user_identity=app_user_email, gateway_id=gateway_id_str, ) as pooled: - tool_call_result = await pooled.session.call_tool(tool_name_original, arguments, meta=meta_data) + tool_call_result = await asyncio.wait_for(pooled.session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout) else: # Non-pooled path: safe to add per-request headers if correlation_id and headers: @@ -3158,7 +3269,7 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head async with streamablehttp_client(url=server_url, headers=headers, httpx_client_factory=get_httpx_client_factory) as (read_stream, write_stream, _get_session_id): async with ClientSession(read_stream, write_stream) as session: await session.initialize() - tool_call_result = await session.call_tool(tool_name_original, arguments, meta=meta_data) + tool_call_result = await asyncio.wait_for(session.call_tool(tool_name_original, arguments, meta=meta_data), timeout=effective_timeout) # Log successful MCP call mcp_duration_ms = (time.time() - mcp_start_time) * 1000 @@ -3172,6 +3283,48 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head ) return tool_call_result + except (asyncio.TimeoutError, httpx.TimeoutException): + # Handle timeout specifically - log and raise ToolTimeoutError + mcp_duration_ms = (time.time() - mcp_start_time) * 1000 + structured_logger.log( + level="WARNING", + message=f"MCP StreamableHTTP tool invocation timed out: {tool_name_original}", + component="tool_service", + correlation_id=correlation_id, + duration_ms=mcp_duration_ms, + metadata={"event": "tool_timeout", "tool_name": tool_name_original, "tool_id": tool_id, "transport": "streamablehttp", "timeout_seconds": effective_timeout}, + ) + + # Manually trigger circuit 
breaker (or other plugins) on timeout + try: + # First-Party + from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel + + tool_timeout_counter.labels(tool_name=name).inc() + except Exception as exc: + logger.debug( + "Failed to increment tool_timeout_counter for %s: %s", + name, + exc, + exc_info=True, + ) + + if self._plugin_manager: + if context_table: + for ctx in context_table.values(): + ctx.set_state("cb_timeout_failure", True) + + if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): + timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True) + await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=False, + ) + + raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") except BaseException as e: # Extract root cause from ExceptionGroup (Python 3.11+) # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup @@ -3292,9 +3445,48 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head if auth_query_params_decrypted: endpoint_url = apply_query_param_auth(endpoint_url, auth_query_params_decrypted) - # Make HTTP request + # Make HTTP request with timeout enforcement logger.info(f"Calling A2A agent '{a2a_agent_name}' at {endpoint_url}") - http_response = await self._http_client.post(endpoint_url, json=request_data, headers=headers) + a2a_start_time = time.time() + try: + http_response = await asyncio.wait_for(self._http_client.post(endpoint_url, json=request_data, headers=headers), timeout=effective_timeout) + except (asyncio.TimeoutError, httpx.TimeoutException): + a2a_elapsed_ms = (time.time() - a2a_start_time) * 1000 + structured_logger.log( + 
level="WARNING", + message=f"A2A tool invocation timed out: {name}", + component="tool_service", + correlation_id=get_correlation_id(), + duration_ms=a2a_elapsed_ms, + metadata={"event": "tool_timeout", "tool_name": name, "a2a_agent": a2a_agent_name, "timeout_seconds": effective_timeout}, + ) + + # Increment timeout counter + try: + # First-Party + from mcpgateway.services.metrics import tool_timeout_counter # pylint: disable=import-outside-toplevel + + tool_timeout_counter.labels(tool_name=name).inc() + except Exception as exc: + logger.debug("Failed to increment tool_timeout_counter for %s: %s", name, exc, exc_info=True) + + # Trigger circuit breaker on timeout + if self._plugin_manager: + if context_table: + for ctx in context_table.values(): + ctx.set_state("cb_timeout_failure", True) + + if self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): + timeout_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation timed out after {effective_timeout}s")], is_error=True) + await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload=ToolPostInvokePayload(name=name, result=timeout_error_result.model_dump(by_alias=True)), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=False, + ) + + raise ToolTimeoutError(f"Tool invocation timed out after {effective_timeout}s") if http_response.status_code == 200: response_data = http_response.json() @@ -3337,6 +3529,15 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head return tool_result except (PluginError, PluginViolationError): raise + except ToolTimeoutError as e: + # ToolTimeoutError is raised by timeout handlers which already called tool_post_invoke + # Re-raise without calling post_invoke again to avoid double-counting failures + # But DO set error_message and span attributes for observability + error_message = str(e) + if span: + span.set_attribute("error", True) + 
span.set_attribute("error.message", error_message) + raise except BaseException as e: # Extract root cause from ExceptionGroup (Python 3.11+) # MCP SDK uses TaskGroup which wraps exceptions in ExceptionGroup @@ -3349,6 +3550,22 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head if span: span.set_attribute("error", True) span.set_attribute("error.message", error_message) + + # Notify plugins of the failure so circuit breaker can track it + # This ensures HTTP 4xx/5xx errors and MCP failures are counted + if self._plugin_manager and self._plugin_manager.has_hooks_for(ToolHookType.TOOL_POST_INVOKE): + try: + exception_error_result = ToolResult(content=[TextContent(type="text", text=f"Tool invocation failed: {error_message}")], is_error=True) + await self._plugin_manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, + payload=ToolPostInvokePayload(name=name, result=exception_error_result.model_dump(by_alias=True)), + global_context=global_context, + local_contexts=context_table, + violations_as_exceptions=False, # Don't let plugin errors mask the original exception + ) + except Exception as plugin_exc: + logger.debug("Failed to invoke post-invoke plugins on exception: %s", plugin_exc) + raise ToolInvocationError(f"Tool invocation failed: {error_message}") finally: # Calculate duration diff --git a/plugins/circuit_breaker/README.md b/plugins/circuit_breaker/README.md index d2c7eda57c..70eb772532 100644 --- a/plugins/circuit_breaker/README.md +++ b/plugins/circuit_breaker/README.md @@ -1,12 +1,12 @@ # Circuit Breaker Plugin -Trips a per-tool breaker on high error rates or consecutive failures. Blocks calls during a cooldown period. +Trips a per-tool breaker on high error rates or consecutive failures. Blocks calls during a cooldown period and implements half-open state for recovery testing. 
-Hooks -- tool_pre_invoke -- tool_post_invoke +## Hooks +- `tool_pre_invoke` - Checks if circuit is open, blocks request or allows through +- `tool_post_invoke` - Records success/failure, evaluates thresholds, updates state -Configuration (example) +## Configuration ```yaml - name: "CircuitBreaker" kind: "plugins.circuit_breaker.circuit_breaker.CircuitBreakerPlugin" @@ -14,14 +14,66 @@ Configuration (example) mode: "enforce_ignore_error" priority: 70 config: - error_rate_threshold: 0.5 - window_seconds: 60 - min_calls: 10 - consecutive_failure_threshold: 5 - cooldown_seconds: 60 - tool_overrides: {} + error_rate_threshold: 0.5 # Fraction of failures to trip breaker (0-1) + window_seconds: 60 # Time window for error rate calculation + min_calls: 10 # Minimum calls before evaluating error rate + consecutive_failure_threshold: 5 # Consecutive failures to trip breaker + cooldown_seconds: 60 # Duration circuit stays open + tool_overrides: {} # Per-tool config overrides ``` -Notes -- Error detection uses ToolResult.is_error when available, or a dict key "is_error". -- Exposes metadata: failure rate, counts, open_until. +## Features + +### Half-Open State +After cooldown expires, the circuit transitions to half-open state: +1. A single test request is allowed through +2. If the test succeeds, the circuit fully closes +3. If the test fails, the circuit immediately reopens for another cooldown + +### Timeout Integration +Tool timeouts are counted as failures when `tool_service` sets the context flag `cb_timeout_failure`. 
+ +### Metadata Exposed +- `circuit_calls_in_window`: Total calls in sliding window +- `circuit_failures_in_window`: Failed calls in window +- `circuit_failure_rate`: Calculated failure rate (0-1) +- `circuit_consecutive_failures`: Current consecutive failure count +- `circuit_open_until`: Unix timestamp when circuit will close (0 if closed) +- `circuit_half_open`: True if in half-open testing state +- `circuit_retry_after_seconds`: Seconds until circuit closes (for retry headers) + +## Notes +- Error detection uses `ToolResult.is_error` or dict keys `is_error`/`isError` (supports both snake_case and camelCase serialization) +- Violation response includes `retry_after_seconds` for rate limiting headers + +## Example Scenario: Unstable Payment Gateway + +**Goal**: Prevent cascading failures when the `payment_api` tool becomes unstable, slow, or times out repeatedly. + +**Configuration**: +```yaml +config: + # 1. Base Strategy (General Tools) + window_seconds: 60 + min_calls: 10 + error_rate_threshold: 0.5 + consecutive_failure_threshold: 5 + cooldown_seconds: 30 + + # 2. Specific Strategy (Critical Payment Tool) + tool_overrides: + payment_api: + consecutive_failure_threshold: 2 + cooldown_seconds: 120 + min_calls: 3 +``` + +**Configuration Breakdown & Reasons:** + +| Parameter | Value | Reason | +|-----------|-------|--------| +| **`window_seconds`** | `60` | **Sliding Window**: We only care about errors in the last minute. Failures from an hour ago shouldn't affect current availability. | +| **`min_calls`** | `10` / `3` | **Sample Size**: Prevents the error-rate check from tripping on the very first call of the day. The error rate is only evaluated after 10 attempts (generic) or 3 (payment), so it has statistical confidence; consecutive failures can still trip the breaker sooner. | +| **`error_rate_threshold`** | `0.5` | **Threshold**: If 50% of calls fail (e.g., 5 out of 10), the service is likely overloaded. Stop sending requests to give it breathing room. 
| +| **`consecutive_failure_threshold`** | `5` / `2` | **Fast Fail**: Even if the error rate is low, 5 hard failures in a row (e.g., 500 Internal Server Error) means it's down. **Override**: For `payment_api`, we stop after just 2 failures to avoid risking duplicate transactions or bad user experience. | +| **`cooldown_seconds`** | `30` / `120` | **Recovery Time**: Wait 30 seconds before trying again (Half-Open state). **Override**: Payment systems often restart slowly; we give `payment_api` a full 2 minutes (120s) to recover to prevent "flapping" (rapidly opening/closing). | diff --git a/plugins/circuit_breaker/circuit_breaker.py b/plugins/circuit_breaker/circuit_breaker.py index 57d748d414..95cd7f837f 100644 --- a/plugins/circuit_breaker/circuit_breaker.py +++ b/plugins/circuit_breaker/circuit_breaker.py @@ -46,12 +46,18 @@ class _ToolState: calls: Deque of call timestamps within the window. consecutive_failures: Count of consecutive failures. open_until: Unix timestamp when breaker closes; 0 if closed. + half_open: True if breaker is in half-open state (testing recovery). + half_open_in_flight: True if a probe request is currently in progress. + half_open_started: Timestamp when probe started (for stale probe detection). 
""" failures: Deque[float] calls: Deque[float] consecutive_failures: int open_until: float # epoch when breaker closes; 0 if closed + half_open: bool = False # half-open state for recovery testing + half_open_in_flight: bool = False # True when a probe request is in progress + half_open_started: float = 0.0 # timestamp when probe started class CircuitBreakerConfig(BaseModel): @@ -97,7 +103,7 @@ def _get_state(tool: str) -> _ToolState: """ st = _STATE.get(tool) if not st: - st = _ToolState(failures=deque(), calls=deque(), consecutive_failures=0, open_until=0.0) + st = _ToolState(failures=deque(), calls=deque(), consecutive_failures=0, open_until=0.0, half_open=False, half_open_in_flight=False, half_open_started=0.0) _STATE[tool] = st return st @@ -127,12 +133,16 @@ def _is_error(result: Any) -> bool: Returns: True if result indicates an error, False otherwise. """ - # ToolResult has is_error; otherwise look for common patterns + # ToolResult has is_error; when serialized with by_alias=True it becomes isError try: if hasattr(result, "is_error"): return bool(getattr(result, "is_error")) - if isinstance(result, dict) and "is_error" in result: - return bool(result.get("is_error")) + if isinstance(result, dict): + # Check both snake_case (direct) and camelCase (serialized with by_alias=True) + if "is_error" in result: + return bool(result.get("is_error")) + if "isError" in result: + return bool(result.get("isError")) except Exception: pass return False @@ -162,23 +172,63 @@ async def tool_pre_invoke(self, payload: ToolPreInvokePayload, context: PluginCo """ tool = payload.name st = _get_state(tool) + cfg = _cfg_for(self._cfg, tool) now = _now() - # Close breaker if cooldown elapsed + + # Check if a probe request is already in flight during half-open state + # Only block if we're actually in half-open state (st.half_open is True) + # This prevents wedging if a later plugin blocks after we set half_open_in_flight + if st.half_open and st.half_open_in_flight: + # Check for 
stale probe (probe started but never completed) + # If probe has been in flight longer than cooldown, reset and allow new probe + probe_timeout = max(1, int(cfg.cooldown_seconds)) + if st.half_open_started and (now - st.half_open_started) > probe_timeout: + # Stale probe detected - reset half-open state and reopen circuit + st.half_open = False + st.half_open_in_flight = False + st.half_open_started = 0.0 + st.open_until = now + probe_timeout # Reopen circuit + else: + # Another probe is already testing the circuit - block this request + return ToolPreInvokeResult( + continue_processing=False, + violation=PluginViolation( + reason="Circuit half-open", + description=f"Probe request in progress for tool '{tool}'", + code="CIRCUIT_HALF_OPEN_PROBE_IN_FLIGHT", + details={"retry_after_seconds": 1.0}, + ), + ) + + # Check if cooldown has elapsed - transition to half-open state if st.open_until and now >= st.open_until: - st.open_until = 0.0 - st.consecutive_failures = 0 + # Transition to half-open state (allow one test request) + st.half_open = True + st.half_open_in_flight = True # Mark probe as in-flight + st.half_open_started = now # Record when probe started for stale detection + st.open_until = 0.0 # Reset open_until so we allow this request through + # Note: consecutive_failures is NOT reset yet - that happens on successful call + + # If still in open state (cooldown not elapsed), block the request if st.open_until and now < st.open_until: + retry_after_seconds = max(0.0, st.open_until - now) return ToolPreInvokeResult( continue_processing=False, violation=PluginViolation( reason="Circuit open", description=f"Breaker open for tool '{tool}' until {int(st.open_until)}", code="CIRCUIT_OPEN", - details={"open_until": st.open_until}, + details={ + "open_until": st.open_until, + "retry_after_seconds": round(retry_after_seconds, 1), + }, ), ) + # Record call timestamp for rate calculations in post hook context context.set_state("cb_call_time", now) + # Track if this is a 
half-open test request + context.set_state("cb_half_open_test", st.half_open) return ToolPreInvokeResult(continue_processing=True) async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: PluginContext) -> ToolPostInvokeResult: @@ -207,24 +257,66 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin # Record this call start_time = context.get_state("cb_call_time", now) st.calls.append(start_time) + + # Determine if this is an error: + # 1. Check is_error on the result + # 2. Check for timeout flag in context (set by tool_service on timeout) error = _is_error(payload.result) + timeout_occurred = context.get_state("cb_timeout_failure", False) + if timeout_occurred: + error = True + + # Check if this was a half-open test request + was_half_open_test = context.get_state("cb_half_open_test", False) + if error: st.failures.append(start_time) st.consecutive_failures += 1 + + # If this was a half-open test request that failed, immediately reopen the circuit + if was_half_open_test: + st.half_open = False + st.half_open_in_flight = False # Clear probe in-flight flag + st.half_open_started = 0.0 # Clear probe start time + st.open_until = now + max(1, int(cfg.cooldown_seconds)) + try: + from mcpgateway.services.metrics import circuit_breaker_open_counter + circuit_breaker_open_counter.labels(tool_name=tool).inc() + except Exception: + pass else: + # Success - reset consecutive failures st.consecutive_failures = 0 - # Evaluate breaker + # If this was a half-open test request that succeeded, fully close the circuit + if was_half_open_test: + st.half_open = False + st.half_open_in_flight = False # Clear probe in-flight flag + st.half_open_started = 0.0 # Clear probe start time + # Don't reset the window - keep tracking for ongoing health + + # Evaluate breaker (only if not already open from half-open failure) calls = len(st.calls) failure_rate = (len(st.failures) / calls) if calls > 0 else 0.0 should_open = False - if calls >= 
max(1, int(cfg.min_calls)) and failure_rate >= cfg.error_rate_threshold: - should_open = True - if st.consecutive_failures >= max(1, int(cfg.consecutive_failure_threshold)): - should_open = True - if should_open and not st.open_until: - st.open_until = now + max(1, int(cfg.cooldown_seconds)) + if not st.open_until: # Only evaluate if not already open + if calls >= max(1, int(cfg.min_calls)) and failure_rate >= cfg.error_rate_threshold: + should_open = True + if st.consecutive_failures >= max(1, int(cfg.consecutive_failure_threshold)): + should_open = True + + if should_open: + st.open_until = now + max(1, int(cfg.cooldown_seconds)) + try: + from mcpgateway.services.metrics import circuit_breaker_open_counter + circuit_breaker_open_counter.labels(tool_name=tool).inc() + except Exception: + pass + + # Compute retry_after_seconds for metadata + retry_after_seconds = max(0.0, st.open_until - now) if st.open_until else 0.0 + return ToolPostInvokeResult( metadata={ "circuit_calls_in_window": calls, @@ -232,5 +324,7 @@ async def tool_post_invoke(self, payload: ToolPostInvokePayload, context: Plugin "circuit_failure_rate": round(failure_rate, 3), "circuit_consecutive_failures": st.consecutive_failures, "circuit_open_until": st.open_until or 0.0, + "circuit_half_open": st.half_open, + "circuit_retry_after_seconds": round(retry_after_seconds, 1), } ) diff --git a/tests/unit/mcpgateway/services/test_tool_service.py b/tests/unit/mcpgateway/services/test_tool_service.py index fbbaf24dd1..f71da4217a 100644 --- a/tests/unit/mcpgateway/services/test_tool_service.py +++ b/tests/unit/mcpgateway/services/test_tool_service.py @@ -3612,3 +3612,691 @@ def test_mixed_content_with_anyurl_serialization(self): assert dump["content"][1]["type"] == "resource_link" assert isinstance(dump["content"][1]["uri"], str) assert dump["content"][1]["uri"] == "file:///path/to/document.pdf" + + +# ============================================================================= +# Tool Invocation Timeouts 
and Circuit Breaker Tests +# ============================================================================= + +class TestToolTimeoutsAndRetries: + """Comprehensive tests for Tool Invocation Timeouts and Circuit Breaker.""" + + def setup_method(self): + """Clear circuit breaker state before each test.""" + from plugins.circuit_breaker.circuit_breaker import _STATE + _STATE.clear() + + @pytest.mark.asyncio + async def test_per_tool_timeout_ms_takes_priority(self): + """Verify per-tool timeout_ms takes priority over global setting.""" + tool_timeout_ms = 5000 # 5 seconds + global_timeout = 60 # 60 seconds + + effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else global_timeout + + assert effective_timeout == 5.0, "Per-tool timeout should take priority" + + @pytest.mark.asyncio + async def test_per_tool_timeout_zero_uses_global(self): + """Verify that timeout_ms=0 falls back to global timeout.""" + tool_timeout_ms = 0 + global_timeout = 60 + + # 0 is falsy, so should fall back to global + effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else global_timeout + + assert effective_timeout == 60, "Zero timeout should fall back to global" + + @pytest.mark.asyncio + async def test_per_tool_timeout_none_uses_global(self): + """Verify that timeout_ms=None falls back to global timeout.""" + tool_timeout_ms = None + global_timeout = 60 + + effective_timeout = (tool_timeout_ms / 1000) if tool_timeout_ms else global_timeout + + assert effective_timeout == 60, "None timeout should fall back to global" + + @pytest.mark.asyncio + async def test_timeout_conversion_from_ms_to_seconds(self): + """Verify correct conversion from milliseconds to seconds.""" + test_cases = [ + (1000, 1.0), + (5000, 5.0), + (30000, 30.0), + (100, 0.1), + (60000, 60.0), + ] + + for timeout_ms, expected_seconds in test_cases: + effective_timeout = timeout_ms / 1000 + assert effective_timeout == expected_seconds, f"Failed for {timeout_ms}ms" + + @pytest.mark.asyncio + async def 
test_timeout_error_message_includes_duration(self): + """Verify timeout error message includes the timeout duration.""" + for timeout in [5.0, 10.0, 30.0, 60.0]: + error = ToolInvocationError(f"Tool invocation timed out after {timeout}s") + assert str(timeout) in str(error) + assert "timed out" in str(error) + + @pytest.mark.asyncio + async def test_asyncio_timeout_error_behavior(self): + """Test asyncio.TimeoutError is raised correctly after timeout.""" + async def slow_operation(): + await asyncio.sleep(10) + return "completed" + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(slow_operation(), timeout=0.01) + + def test_initial_state_is_closed(self): + """Verify circuit breaker starts in closed state.""" + from plugins.circuit_breaker.circuit_breaker import _get_state + + state = _get_state("test_tool") + + assert state.open_until == 0.0 + assert state.half_open is False + assert state.consecutive_failures == 0 + + def test_state_tracks_calls_in_window(self): + """Verify state tracks call timestamps.""" + import time + from plugins.circuit_breaker.circuit_breaker import _get_state + + state = _get_state("test_tool") + state.calls.append(time.time()) + state.calls.append(time.time()) + + assert len(state.calls) == 2 + + def test_state_tracks_failures_in_window(self): + """Verify state tracks failure timestamps.""" + import time + from plugins.circuit_breaker.circuit_breaker import _get_state + + state = _get_state("test_tool") + state.failures.append(time.time()) + state.failures.append(time.time()) + + assert len(state.failures) == 2 + + def test_consecutive_failures_increment(self): + """Verify consecutive failures counter increments.""" + from plugins.circuit_breaker.circuit_breaker import _get_state + + state = _get_state("test_tool") + state.consecutive_failures += 1 + state.consecutive_failures += 1 + state.consecutive_failures += 1 + + assert state.consecutive_failures == 3 + + def test_consecutive_failures_reset_on_success(self): + 
"""Verify consecutive failures reset to 0 on success.""" + from plugins.circuit_breaker.circuit_breaker import _get_state + + state = _get_state("test_tool") + state.consecutive_failures = 5 + # Simulate success + state.consecutive_failures = 0 + + assert state.consecutive_failures == 0 + + @pytest.mark.asyncio + async def test_half_open_transition_after_cooldown(self): + """Verify transition to half-open state after cooldown expires.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPreInvokePayload + + # Create plugin + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1, + config={"cooldown_seconds": 1} + ) + plugin = CircuitBreakerPlugin(config) + + # Open the circuit + state = _get_state("test_tool") + state.open_until = time.time() - 1 # Cooldown expired + state.half_open = False + + # Create payload + payload = ToolPreInvokePayload(name="test_tool", arguments={}) + + # Create mock context + context = MagicMock() + context.set_state = MagicMock() + + # Process pre_invoke + result = await plugin.tool_pre_invoke(payload, context) + + # Should allow request through (transition to half-open) + assert result.continue_processing is True + # Verify half-open state was set in context + context.set_state.assert_any_call("cb_half_open_test", True) + + @pytest.mark.asyncio + async def test_half_open_failure_reopens_circuit(self): + """Verify that failure during half-open immediately reopens circuit.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin with short cooldown + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1, + config={"cooldown_seconds": 60} + ) + plugin = CircuitBreakerPlugin(config) + + # 
Set up half-open state + state = _get_state("test_tool") + state.half_open = True + + # Create mock error result + mock_result = MagicMock() + mock_result.is_error = True + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + # Create mock context indicating half-open test + context = MagicMock() + context.get_state = MagicMock(side_effect=lambda k, d=None: { + "cb_call_time": time.time(), + "cb_half_open_test": True, + "cb_timeout_failure": False, + }.get(k, d)) + + # Process post_invoke + result = await plugin.tool_post_invoke(payload, context) + + # Circuit should be reopened + assert state.open_until > time.time() + assert state.half_open is False + + @pytest.mark.asyncio + async def test_half_open_success_closes_circuit(self): + """Verify that success during half-open fully closes circuit.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1 + ) + plugin = CircuitBreakerPlugin(config) + + # Set up half-open state + state = _get_state("test_tool") + state.half_open = True + state.consecutive_failures = 4 + + # Create mock success result + mock_result = MagicMock() + mock_result.is_error = False + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + # Create mock context indicating half-open test + context = MagicMock() + context.get_state = MagicMock(side_effect=lambda k, d=None: { + "cb_call_time": time.time(), + "cb_half_open_test": True, + "cb_timeout_failure": False, + }.get(k, d)) + + # Process post_invoke + result = await plugin.tool_post_invoke(payload, context) + + # Circuit should be fully closed + assert state.half_open is False + assert state.consecutive_failures == 0 + + @pytest.mark.asyncio + async def 
test_consecutive_failure_threshold_trips_breaker(self): + """Verify consecutive failures trip the circuit breaker.""" + import time + from collections import deque + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin with low consecutive failure threshold + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1, + config={"consecutive_failure_threshold": 3, "cooldown_seconds": 60} + ) + plugin = CircuitBreakerPlugin(config) + + # Pre-set consecutive failures to threshold - 1 + state = _get_state("test_tool") + state.consecutive_failures = 2 + state.calls = deque([time.time()]) + state.failures = deque([time.time()]) + + # Create mock error result + mock_result = MagicMock() + mock_result.is_error = True + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + context = MagicMock() + context.get_state = MagicMock(return_value=None) + + # Process post_invoke - should trip breaker + result = await plugin.tool_post_invoke(payload, context) + + # Circuit should be open + assert state.open_until > time.time() + + @pytest.mark.asyncio + async def test_error_rate_threshold_trips_breaker(self): + """Verify error rate threshold trips the circuit breaker.""" + import time + from collections import deque + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin with specific error rate settings + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1, + config={ + "error_rate_threshold": 0.5, # 50% error rate trips + "min_calls": 2, # After 2 calls + "cooldown_seconds": 60, + } + ) + plugin = CircuitBreakerPlugin(config) + + # Pre-set calls and failures for 50% error rate + now = 
time.time() + state = _get_state("test_tool") + state.calls = deque([now - 1]) # 1 previous call + state.failures = deque([now - 1]) # 1 failure (this will be the 2nd) + state.consecutive_failures = 1 + + # Create mock error result + mock_result = MagicMock() + mock_result.is_error = True + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + context = MagicMock() + context.get_state = MagicMock(side_effect=lambda k, d=None: now if k == "cb_call_time" else d) + + # Process post_invoke - should trip breaker (2/2 = 100% > 50%) + result = await plugin.tool_post_invoke(payload, context) + + # Circuit should be open + assert state.open_until > time.time() + + @pytest.mark.asyncio + async def test_retry_after_seconds_in_violation(self): + """Verify retry_after_seconds is included in violation details.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPreInvokePayload + + # Create plugin + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1 + ) + plugin = CircuitBreakerPlugin(config) + + # Open the circuit with future close time + state = _get_state("test_tool") + state.open_until = time.time() + 30 # 30 seconds from now + + # Create payload + payload = ToolPreInvokePayload(name="test_tool", arguments={}) + + # Create mock context + context = MagicMock() + + # Process pre_invoke - should block + result = await plugin.tool_pre_invoke(payload, context) + + # Should block with retry_after_seconds + assert result.continue_processing is False + assert result.violation is not None + assert "retry_after_seconds" in result.violation.details + assert 25 <= result.violation.details["retry_after_seconds"] <= 35 + + @pytest.mark.asyncio + async def test_metadata_includes_all_fields(self): + """Verify post_invoke metadata includes all required fields.""" + from 
plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1 + ) + plugin = CircuitBreakerPlugin(config) + + # Create mock success result + mock_result = MagicMock() + mock_result.is_error = False + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + context = MagicMock() + context.get_state = MagicMock(return_value=None) + + # Process post_invoke + result = await plugin.tool_post_invoke(payload, context) + + # Verify all metadata fields are present + required_fields = [ + "circuit_calls_in_window", + "circuit_failures_in_window", + "circuit_failure_rate", + "circuit_consecutive_failures", + "circuit_open_until", + "circuit_half_open", + "circuit_retry_after_seconds", + ] + + for field in required_fields: + assert field in result.metadata, f"Missing field: {field}" + + @pytest.mark.asyncio + async def test_timeout_flag_counts_as_failure(self): + """Verify cb_timeout_failure flag counts as circuit breaker failure.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1 + ) + plugin = CircuitBreakerPlugin(config) + + # Create mock result that looks successful + mock_result = MagicMock() + mock_result.is_error = False # Result doesn't show error + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + # Create context with timeout flag set + context = MagicMock() + context.get_state = MagicMock(side_effect=lambda k, d=None: { + "cb_call_time": time.time(), + "cb_half_open_test": False, + "cb_timeout_failure": True, # 
TIMEOUT OCCURRED! + }.get(k, d)) + + # Process post_invoke + result = await plugin.tool_post_invoke(payload, context) + + # Should count as failure + assert result.metadata["circuit_failures_in_window"] == 1 + assert result.metadata["circuit_consecutive_failures"] == 1 + + @pytest.mark.asyncio + async def test_timeout_flag_can_trip_breaker(self): + """Verify enough timeout failures can trip the circuit breaker.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin with low threshold + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1, + config={"consecutive_failure_threshold": 3, "cooldown_seconds": 60} + ) + plugin = CircuitBreakerPlugin(config) + + state = _get_state("test_tool") + state.consecutive_failures = 2 # Already at threshold - 1 + + # Create mock result that looks successful + mock_result = MagicMock() + mock_result.is_error = False + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + # Create context with timeout flag set + context = MagicMock() + context.get_state = MagicMock(side_effect=lambda k, d=None: { + "cb_call_time": time.time(), + "cb_half_open_test": False, + "cb_timeout_failure": True, # 3rd consecutive failure via timeout + }.get(k, d)) + + # Process post_invoke + result = await plugin.tool_post_invoke(payload, context) + + # Should trip the breaker + assert state.open_until > time.time() + + def test_tool_overrides_apply_correctly(self): + """Verify per-tool overrides are applied.""" + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerConfig, _cfg_for + ) + + # Create base config with tool overrides + base_config = CircuitBreakerConfig( + error_rate_threshold=0.5, + window_seconds=60, + consecutive_failure_threshold=5, + cooldown_seconds=60, + tool_overrides={ + 
"critical_tool": { + "consecutive_failure_threshold": 2, # More sensitive + "cooldown_seconds": 120, # Longer cooldown + } + } + ) + + # Get effective config for regular tool + regular_config = _cfg_for(base_config, "regular_tool") + assert regular_config.consecutive_failure_threshold == 5 + assert regular_config.cooldown_seconds == 60 + + # Get effective config for critical tool + critical_config = _cfg_for(base_config, "critical_tool") + assert critical_config.consecutive_failure_threshold == 2 + assert critical_config.cooldown_seconds == 120 + + @pytest.mark.asyncio + async def test_old_entries_evicted_from_window(self): + """Verify old call/failure entries are evicted from sliding window.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state, _STATE + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPostInvokePayload + + # Create plugin with 1-second window + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1, + config={"window_seconds": 1} + ) + plugin = CircuitBreakerPlugin(config) + + # Add old entries + state = _get_state("test_tool") + old_time = time.time() - 10 # 10 seconds ago + state.calls.append(old_time) + state.failures.append(old_time) + + # Create mock result + mock_result = MagicMock() + mock_result.is_error = False + + payload = ToolPostInvokePayload( + name="test_tool", + arguments={}, + result=mock_result + ) + + context = MagicMock() + context.get_state = MagicMock(return_value=None) + + # Process post_invoke - should evict old entries + result = await plugin.tool_post_invoke(payload, context) + + # Old entries should be evicted, new call should be recorded + assert result.metadata["circuit_calls_in_window"] == 1 + assert result.metadata["circuit_failures_in_window"] == 0 + + def test_is_error_with_tool_result_attribute(self): + """Verify is_error detection using ToolResult.is_error attribute.""" + from 
plugins.circuit_breaker.circuit_breaker import _is_error + + mock_result = MagicMock() + mock_result.is_error = True + + assert _is_error(mock_result) is True + + mock_result.is_error = False + assert _is_error(mock_result) is False + + def test_is_error_with_dict(self): + """Verify is_error detection using dict key.""" + from plugins.circuit_breaker.circuit_breaker import _is_error + + error_dict = {"is_error": True, "content": "error message"} + assert _is_error(error_dict) is True + + success_dict = {"is_error": False, "content": "success"} + assert _is_error(success_dict) is False + + def test_is_error_with_missing_field_returns_false(self): + """Verify is_error returns False when field is missing.""" + from plugins.circuit_breaker.circuit_breaker import _is_error + + # Object without is_error + mock_result = MagicMock(spec=[]) # No attributes + del mock_result.is_error # Remove any auto-mock + + # Dict without is_error key + no_error_dict = {"content": "some content"} + assert _is_error(no_error_dict) is False + + @pytest.mark.asyncio + async def test_plugin_initialization(self): + """Verify plugin initializes correctly with config.""" + from plugins.circuit_breaker.circuit_breaker import CircuitBreakerPlugin + from mcpgateway.plugins.framework import PluginConfig + + config = PluginConfig( + name="CircuitBreaker", + kind="plugins.circuit_breaker.circuit_breaker.CircuitBreakerPlugin", + hooks=["tool_pre_invoke", "tool_post_invoke"], + mode="enforce_ignore_error", + priority=70, + config={ + "error_rate_threshold": 0.3, + "window_seconds": 120, + "min_calls": 5, + "consecutive_failure_threshold": 3, + "cooldown_seconds": 30, + } + ) + + plugin = CircuitBreakerPlugin(config) + + assert plugin._cfg.error_rate_threshold == 0.3 + assert plugin._cfg.window_seconds == 120 + assert plugin._cfg.min_calls == 5 + assert plugin._cfg.consecutive_failure_threshold == 3 + assert plugin._cfg.cooldown_seconds == 30 + + @pytest.mark.asyncio + async def 
test_plugin_allows_requests_when_closed(self): + """Verify plugin allows requests when circuit is closed.""" + from plugins.circuit_breaker.circuit_breaker import CircuitBreakerPlugin + from mcpgateway.plugins.framework import PluginConfig, ToolPreInvokePayload + + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1 + ) + plugin = CircuitBreakerPlugin(config) + + payload = ToolPreInvokePayload(name="test_tool", arguments={}) + context = MagicMock() + context.set_state = MagicMock() + + result = await plugin.tool_pre_invoke(payload, context) + + assert result.continue_processing is True + assert result.violation is None + + @pytest.mark.asyncio + async def test_plugin_blocks_requests_when_open(self): + """Verify plugin blocks requests when circuit is open.""" + import time + from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, _get_state + ) + from mcpgateway.plugins.framework import PluginConfig, ToolPreInvokePayload + + config = PluginConfig( + name="test", kind="test", hooks=[], mode="enforce", priority=1 + ) + plugin = CircuitBreakerPlugin(config) + + # Open the circuit + state = _get_state("test_tool") + state.open_until = time.time() + 60 # Open for next 60 seconds + + payload = ToolPreInvokePayload(name="test_tool", arguments={}) + context = MagicMock() + + result = await plugin.tool_pre_invoke(payload, context) + + assert result.continue_processing is False + assert result.violation is not None diff --git a/tests/unit/plugins/test_circuit_breaker.py b/tests/unit/plugins/test_circuit_breaker.py new file mode 100644 index 0000000000..df7674467a --- /dev/null +++ b/tests/unit/plugins/test_circuit_breaker.py @@ -0,0 +1,430 @@ +# -*- coding: utf-8 -*- +"""Tests for Circuit Breaker Plugin. + +Verifies all functionality: +1. Closed state - allows requests +2. Opens on consecutive failures +3. Opens on error rate threshold +4. Blocks requests when open +5. Half-open state - allows probe request +6. 
Closes on successful probe +7. Reopens on failed probe +8. Timeout failures trigger circuit breaker +9. retry_after_seconds calculation +10. Per-tool configuration overrides +""" + +import asyncio +import pytest +import time +from unittest.mock import MagicMock, patch + +# Import the circuit breaker components +from plugins.circuit_breaker.circuit_breaker import ( + CircuitBreakerPlugin, + CircuitBreakerConfig, + _ToolState, + _STATE, + _get_state, + _cfg_for, + _is_error, +) +from mcpgateway.plugins.framework import ( + PluginConfig, + PluginContext, + ToolPreInvokePayload, + ToolPostInvokePayload, + GlobalContext, +) + + +@pytest.fixture(autouse=True) +def clear_state(): + """Clear circuit breaker state before each test.""" + _STATE.clear() + yield + _STATE.clear() + + +@pytest.fixture +def plugin(): + """Create a circuit breaker plugin with test configuration.""" + config = PluginConfig( + id="test-cb", + kind="circuit_breaker", + name="Test Circuit Breaker", + enabled=True, + order=0, + config={ + "error_rate_threshold": 0.5, + "window_seconds": 60, + "min_calls": 3, + "consecutive_failure_threshold": 3, + "cooldown_seconds": 30, + }, + ) + return CircuitBreakerPlugin(config) + + +@pytest.fixture +def context(): + """Create a plugin context for testing.""" + global_ctx = GlobalContext(request_id="test-request-123") + return PluginContext(plugin_id="test-cb", global_context=global_ctx) + + +class TestCircuitBreakerClosedState: + """Test circuit breaker in closed state.""" + + @pytest.mark.asyncio + async def test_allows_requests_when_closed(self, plugin, context): + """Closed circuit should allow requests through.""" + payload = ToolPreInvokePayload(name="test_tool", args={}) + result = await plugin.tool_pre_invoke(payload, context) + + assert result.continue_processing is True + assert result.violation is None + + @pytest.mark.asyncio + async def test_records_call_timestamp(self, plugin, context): + """Pre-invoke should record call timestamp in context.""" + 
payload = ToolPreInvokePayload(name="test_tool", args={}) + await plugin.tool_pre_invoke(payload, context) + + call_time = context.get_state("cb_call_time") + assert call_time is not None + assert abs(call_time - time.time()) < 1 # Within 1 second + + +class TestCircuitBreakerOpening: + """Test circuit breaker opening conditions.""" + + @pytest.mark.asyncio + async def test_opens_on_consecutive_failures(self, plugin, context): + """Circuit should open after consecutive_failure_threshold failures.""" + tool = "test_tool" + + # Simulate 3 consecutive failures + for _ in range(3): + pre_payload = ToolPreInvokePayload(name=tool, args={}) + await plugin.tool_pre_invoke(pre_payload, context) + + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": True}) + result = await plugin.tool_post_invoke(post_payload, context) + + # Check circuit is now open + assert result.metadata["circuit_open_until"] > 0 + assert result.metadata["circuit_consecutive_failures"] == 3 + + @pytest.mark.asyncio + async def test_opens_on_error_rate_threshold(self, plugin, context): + """Circuit should open when error rate exceeds threshold.""" + tool = "test_tool" + + # Simulate 3 calls (min_calls): 2 failures, 1 success = 66% error rate > 50% threshold + for i in range(3): + pre_payload = ToolPreInvokePayload(name=tool, args={}) + await plugin.tool_pre_invoke(pre_payload, context) + + is_error = i < 2 # First 2 are failures + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": is_error}) + result = await plugin.tool_post_invoke(post_payload, context) + + # Check circuit is now open + assert result.metadata["circuit_open_until"] > 0 + assert result.metadata["circuit_failure_rate"] >= 0.5 + + +class TestCircuitBreakerOpenState: + """Test circuit breaker in open state.""" + + @pytest.mark.asyncio + async def test_blocks_requests_when_open(self, plugin, context): + """Open circuit should block requests.""" + tool = "test_tool" + st = _get_state(tool) + st.open_until = 
time.time() + 30 # Open for 30 seconds + + payload = ToolPreInvokePayload(name=tool, args={}) + result = await plugin.tool_pre_invoke(payload, context) + + assert result.continue_processing is False + assert result.violation is not None + assert result.violation.code == "CIRCUIT_OPEN" + + @pytest.mark.asyncio + async def test_returns_retry_after_seconds(self, plugin, context): + """Open circuit should return retry_after_seconds in violation details.""" + tool = "test_tool" + st = _get_state(tool) + st.open_until = time.time() + 30 # Open for 30 seconds + + payload = ToolPreInvokePayload(name=tool, args={}) + result = await plugin.tool_pre_invoke(payload, context) + + assert result.violation is not None + assert "retry_after_seconds" in result.violation.details + assert result.violation.details["retry_after_seconds"] > 0 + assert result.violation.details["retry_after_seconds"] <= 30 + + +class TestCircuitBreakerHalfOpenState: + """Test circuit breaker half-open state.""" + + @pytest.mark.asyncio + async def test_transitions_to_half_open_after_cooldown(self, plugin, context): + """Circuit should transition to half-open after cooldown.""" + tool = "test_tool" + st = _get_state(tool) + st.open_until = time.time() - 1 # Cooldown elapsed + st.consecutive_failures = 5 + + payload = ToolPreInvokePayload(name=tool, args={}) + result = await plugin.tool_pre_invoke(payload, context) + + # Should allow request through (half-open) + assert result.continue_processing is True + assert st.half_open is True + assert context.get_state("cb_half_open_test") is True + + @pytest.mark.asyncio + async def test_closes_on_successful_probe(self, plugin, context): + """Half-open circuit should close on successful probe.""" + tool = "test_tool" + st = _get_state(tool) + st.half_open = True + st.consecutive_failures = 5 + + # Set context for half-open test + context.set_state("cb_half_open_test", True) + context.set_state("cb_call_time", time.time()) + + # Successful probe + post_payload = 
ToolPostInvokePayload(name=tool, result={"is_error": False}) + result = await plugin.tool_post_invoke(post_payload, context) + + # Circuit should be fully closed + assert st.half_open is False + assert st.consecutive_failures == 0 + assert result.metadata["circuit_open_until"] == 0.0 + + @pytest.mark.asyncio + async def test_reopens_on_failed_probe(self, plugin, context): + """Half-open circuit should reopen immediately on failed probe.""" + tool = "test_tool" + st = _get_state(tool) + st.half_open = True + st.consecutive_failures = 5 + + # Set context for half-open test + context.set_state("cb_half_open_test", True) + context.set_state("cb_call_time", time.time()) + + # Failed probe + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": True}) + with patch("mcpgateway.services.metrics.circuit_breaker_open_counter") as mock_counter: + mock_counter.labels.return_value.inc = MagicMock() + result = await plugin.tool_post_invoke(post_payload, context) + + # Circuit should be reopened + assert st.half_open is False + assert result.metadata["circuit_open_until"] > time.time() + + @pytest.mark.asyncio + async def test_blocks_concurrent_probes_during_half_open(self, plugin, context): + """Only one probe should be allowed during half-open state.""" + tool = "test_tool" + st = _get_state(tool) + st.half_open = True # Must be in half-open state + st.half_open_in_flight = True # Simulate a probe already in progress + st.half_open_started = time.time() # Probe started recently + st.consecutive_failures = 5 + + payload = ToolPreInvokePayload(name=tool, args={}) + result = await plugin.tool_pre_invoke(payload, context) + + # Should block - another probe is in flight + assert result.continue_processing is False + assert result.violation is not None + assert result.violation.code == "CIRCUIT_HALF_OPEN_PROBE_IN_FLIGHT" + + @pytest.mark.asyncio + async def test_stale_probe_detection_resets_half_open(self, plugin, context): + """Stale probe (longer than cooldown) should 
reset and allow new probe.""" + tool = "test_tool" + st = _get_state(tool) + st.half_open = True + st.half_open_in_flight = True + st.half_open_started = time.time() - 120 # Probe started 2 minutes ago (stale) + st.consecutive_failures = 5 + + payload = ToolPreInvokePayload(name=tool, args={}) + result = await plugin.tool_pre_invoke(payload, context) + + # Stale probe should be reset, circuit reopened + assert st.half_open is False + assert st.half_open_in_flight is False + assert st.open_until > 0 # Circuit should be reopened + + @pytest.mark.asyncio + async def test_clears_in_flight_flag_on_probe_success(self, plugin, context): + """Successful probe should clear the in-flight flag.""" + tool = "test_tool" + st = _get_state(tool) + st.half_open = True + st.half_open_in_flight = True + st.consecutive_failures = 5 + + context.set_state("cb_half_open_test", True) + context.set_state("cb_call_time", time.time()) + + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": False}) + await plugin.tool_post_invoke(post_payload, context) + + # In-flight flag should be cleared + assert st.half_open_in_flight is False + + @pytest.mark.asyncio + async def test_clears_in_flight_flag_on_probe_failure(self, plugin, context): + """Failed probe should clear the in-flight flag.""" + tool = "test_tool" + st = _get_state(tool) + st.half_open = True + st.half_open_in_flight = True + st.consecutive_failures = 5 + + context.set_state("cb_half_open_test", True) + context.set_state("cb_call_time", time.time()) + + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": True}) + with patch("mcpgateway.services.metrics.circuit_breaker_open_counter") as mock_counter: + mock_counter.labels.return_value.inc = MagicMock() + await plugin.tool_post_invoke(post_payload, context) + + # In-flight flag should be cleared + assert st.half_open_in_flight is False + + +class TestTimeoutIntegration: + """Test timeout integration with circuit breaker.""" + + @pytest.mark.asyncio + 
async def test_timeout_counted_as_failure(self, plugin, context): + """Timeout flag should be counted as failure.""" + tool = "test_tool" + + # Set timeout flag (as tool_service would do) + context.set_state("cb_timeout_failure", True) + context.set_state("cb_call_time", time.time()) + + # Post-invoke with a technically successful result but timeout flag set + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": False}) + result = await plugin.tool_post_invoke(post_payload, context) + + # Should count as failure + assert result.metadata["circuit_failures_in_window"] == 1 + assert result.metadata["circuit_consecutive_failures"] == 1 + + +class TestPerToolOverrides: + """Test per-tool configuration overrides.""" + + @pytest.mark.asyncio + async def test_tool_override_applied(self, context): + """Per-tool overrides should be applied correctly.""" + config = PluginConfig( + id="test-cb", + kind="circuit_breaker", + name="Test Circuit Breaker", + enabled=True, + order=0, + config={ + "consecutive_failure_threshold": 5, + "tool_overrides": { + "critical_tool": {"consecutive_failure_threshold": 10} + }, + }, + ) + plugin = CircuitBreakerPlugin(config) + + # Simulate 5 failures on critical_tool (should NOT open - needs 10) + for _ in range(5): + pre_payload = ToolPreInvokePayload(name="critical_tool", args={}) + await plugin.tool_pre_invoke(pre_payload, context) + + post_payload = ToolPostInvokePayload(name="critical_tool", result={"is_error": True}) + result = await plugin.tool_post_invoke(post_payload, context) + + # Circuit should still be closed (needs 10 failures) + assert result.metadata["circuit_open_until"] == 0.0 + assert result.metadata["circuit_consecutive_failures"] == 5 + + +class TestHelperFunctions: + """Test helper functions.""" + + def test_is_error_with_dict(self): + """_is_error should detect error in dict result.""" + assert _is_error({"is_error": True}) is True + assert _is_error({"is_error": False}) is False + assert 
_is_error({"success": True}) is False + + def test_is_error_with_camel_case(self): + """_is_error should detect error in camelCase (serialized via by_alias=True).""" + # When ToolResult.model_dump(by_alias=True) is used, is_error becomes isError + assert _is_error({"isError": True}) is True + assert _is_error({"isError": False}) is False + # snake_case takes precedence if both present + assert _is_error({"is_error": True, "isError": False}) is True + + def test_is_error_with_object(self): + """_is_error should detect error in object result.""" + class MockResult: + is_error = True + + assert _is_error(MockResult()) is True + MockResult.is_error = False + assert _is_error(MockResult()) is False + + def test_cfg_for_with_override(self): + """_cfg_for should merge tool overrides.""" + base_cfg = CircuitBreakerConfig( + consecutive_failure_threshold=5, + tool_overrides={"special_tool": {"consecutive_failure_threshold": 10}}, + ) + + merged = _cfg_for(base_cfg, "special_tool") + assert merged.consecutive_failure_threshold == 10 + + default = _cfg_for(base_cfg, "regular_tool") + assert default.consecutive_failure_threshold == 5 + + +class TestWindowEviction: + """Test time window eviction logic.""" + + @pytest.mark.asyncio + async def test_old_entries_evicted(self, plugin, context): + """Old call/failure entries should be evicted after window expires.""" + tool = "test_tool" + st = _get_state(tool) + + # Add old entries (outside window) + old_time = time.time() - 120 # 2 minutes ago + st.calls.append(old_time) + st.failures.append(old_time) + + # Make a new call + pre_payload = ToolPreInvokePayload(name=tool, args={}) + await plugin.tool_pre_invoke(pre_payload, context) + + post_payload = ToolPostInvokePayload(name=tool, result={"is_error": False}) + result = await plugin.tool_post_invoke(post_payload, context) + + # Old entries should be evicted + assert result.metadata["circuit_calls_in_window"] == 1 + assert result.metadata["circuit_failures_in_window"] == 0 + + +if 
__name__ == "__main__": + pytest.main([__file__, "-v"])