Skip to content

Commit c087d95

Browse files
committed
add stateful sessions in http
Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent 2cd515b commit c087d95

File tree

10 files changed

+953
-5
lines changed

10 files changed

+953
-5
lines changed

.env.example

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,13 +1770,23 @@ PLUGINS_CLI_MARKUP_MODE=rich
17701770
# Enable stateful sessions (stores session state server-side)
17711771
# Options: true, false (default)
17721772
# false: Stateless mode (better for scaling)
1773+
# true: Stateful mode (requires CACHE_TYPE=redis for multi-worker deployments)
17731774
# USE_STATEFUL_SESSIONS=false
17741775

17751776
# Enable JSON response format for streaming HTTP
17761777
# Options: true (default), false
17771778
# true: Return JSON responses, false: Return SSE stream
17781779
# JSON_RESPONSE_ENABLED=true
17791780

1781+
# Event store configuration for stateful sessions
1782+
# Ring buffer size per stream (default: 100)
1783+
# Controls how many events are kept in memory before oldest are evicted
1784+
# STREAMABLE_HTTP_MAX_EVENTS_PER_STREAM=100
1785+
1786+
# Stream TTL in seconds (default: 3600 = 1 hour)
1787+
# How long event streams are kept in Redis before automatic cleanup
1788+
# STREAMABLE_HTTP_EVENT_TTL=3600
1789+
17801790
# Federation Configuration
17811791

17821792
# Timeout for federation requests in seconds

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ services:
236236
- SECURITY_HEADERS_ENABLED=true
237237
- CORS_ALLOW_CREDENTIALS=true
238238
- SECURE_COOKIES=false
239+
- USE_STATEFUL_SESSIONS=true
239240
## Uncomment to enable HTTPS (run `make certs` first)
240241
# - SSL=true
241242
# - CERT_FILE=/app/certs/cert.pem

infra/nginx/nginx.conf

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,42 @@ http {
492492
proxy_read_timeout 1h;
493493
}
494494

495+
# Plain /sse endpoint (no server prefix): long-lived Server-Sent Events stream
location = /sse {
    proxy_pass http://gateway_backend;

    # Standard proxy headers
    proxy_set_header Host $http_host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;

    # SSE needs an unbuffered HTTP/1.1 upstream connection with no
    # Connection header and no response caching/chunk rewriting
    proxy_http_version 1.1;
    proxy_set_header Connection '';
    chunked_transfer_encoding off;
    proxy_buffering off;
    proxy_cache off;

    # Keep idle event streams open for up to an hour
    proxy_connect_timeout 1h;
    proxy_send_timeout 1h;
    proxy_read_timeout 1h;
}
517+
518+
# Companion POST endpoint used by SSE clients to send messages
location = /message {
    proxy_pass http://gateway_backend;

    # Standard proxy headers
    proxy_set_header Host $http_host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;

    # Pass responses through untouched
    proxy_buffering off;
    proxy_cache off;
}
530+
495531
location ~ ^/servers/.*/ws$ {
496532
proxy_pass http://gateway_backend;
497533

mcpgateway/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,6 +1501,8 @@ def parse_issuers(cls, v: Any) -> list[str]:
15011501
# streamable http transport
15021502
use_stateful_sessions: bool = False # Set to False to use stateless sessions without event store
15031503
json_response_enabled: bool = True # Enable JSON responses instead of SSE streams
1504+
streamable_http_max_events_per_stream: int = 100 # Ring buffer capacity per stream
1505+
streamable_http_event_ttl: int = 3600 # Event stream TTL in seconds (1 hour)
15041506

15051507
# Core plugin settings
15061508
plugins_enabled: bool = Field(default=False, description="Enable the plugin framework")

mcpgateway/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5321,7 +5321,7 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
53215321
from mcpgateway.services.mcp_session_pool import WORKER_ID # pylint: disable=import-outside-toplevel
53225322

53235323
session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
5324-
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | RPC request received, checking affinity")
5324+
print(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | RPC request received, checking affinity")
53255325
try:
53265326
# First-Party
53275327
from mcpgateway.services.mcp_session_pool import get_mcp_session_pool # pylint: disable=import-outside-toplevel
@@ -5333,7 +5333,7 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
53335333
)
53345334
if forwarded_response is not None:
53355335
# Request was handled by another worker
5336-
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded response received")
5336+
print(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Forwarded response received")
53375337
if "error" in forwarded_response:
53385338
raise JSONRPCError(
53395339
forwarded_response["error"].get("code", -32603),
@@ -5349,7 +5349,7 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user=Depen
53495349
from mcpgateway.services.mcp_session_pool import WORKER_ID # pylint: disable=import-outside-toplevel
53505350

53515351
session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
5352-
logger.info(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Internally forwarded request, executing locally")
5352+
print(f"[AFFINITY] Worker {WORKER_ID} | Session {session_short}... | Method: {method} | Internally forwarded request, executing locally")
53535353

53545354
if method == "initialize":
53555355
# Extract session_id from params or query string (for capability tracking)

mcpgateway/services/tool_service.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,7 @@ async def invoke_tool(
26262626
gateway_ca_cert_sig = gateway_payload.get("ca_certificate_sig") if has_gateway else None
26272627
gateway_passthrough = gateway_payload.get("passthrough_headers") if has_gateway else None
26282628
gateway_id_str = gateway_payload.get("id") if has_gateway else None
2629+
gateway_transport = gateway_payload.get("transport") if has_gateway else None
26292630

26302631
# Decrypt and apply query param auth to URL if applicable
26312632
gateway_auth_query_params_decrypted: Optional[Dict[str, str]] = None
@@ -2811,7 +2812,9 @@ async def invoke_tool(
28112812
mcp_session_id = request_headers_lower.get("mcp-session-id")
28122813
if mcp_session_id:
28132814
headers["x-mcp-session-id"] = mcp_session_id
2815+
# Standard
28142816
import os # pylint: disable=import-outside-toplevel
2817+
28152818
worker_id = str(os.getpid())
28162819
session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
28172820
logger.info(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {tool_name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity")
@@ -2950,7 +2953,9 @@ async def invoke_tool(
29502953
mcp_session_id = request_headers_lower.get("mcp-session-id")
29512954
if mcp_session_id:
29522955
headers["x-mcp-session-id"] = mcp_session_id
2956+
# Standard
29532957
import os # pylint: disable=import-outside-toplevel
2958+
29542959
worker_id = str(os.getpid())
29552960
session_short = mcp_session_id[:8] if len(mcp_session_id) >= 8 else mcp_session_id
29562961
logger.info(f"[AFFINITY] Worker {worker_id} | Session {session_short}... | Tool: {tool_name} | Normalized MCP-Session-Id → x-mcp-session-id for pool affinity (MCP transport)")
@@ -3163,10 +3168,12 @@ async def connect_to_streamablehttp_server(server_url: str, headers: dict = head
31633168

31643169
if use_pool and pool is not None:
31653170
# Pooled path: do NOT add per-request headers (they would be pinned)
3171+
# Determine transport type based on current transport setting
3172+
pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP
31663173
async with pool.session(
31673174
url=server_url,
31683175
headers=headers,
3169-
transport_type=TransportType.STREAMABLE_HTTP,
3176+
transport_type=pool_transport_type,
31703177
httpx_client_factory=get_httpx_client_factory,
31713178
user_identity=app_user_email,
31723179
gateway_id=gateway_id_str,
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
"""
2+
Redis-backed event store for Streamable HTTP stateful sessions.
3+
4+
Provides distributed event storage across multiple gateway workers using Redis,
5+
enabling stateful MCP sessions to work correctly behind load balancers.
6+
7+
Architecture:
8+
- Uses Redis Sorted Sets for ordered event storage with ring buffer semantics
9+
- Events indexed by event_id for O(1) lookup during replay
10+
- Automatic eviction when exceeding max_events_per_stream
11+
- TTL-based cleanup for expired streams
12+
"""
13+
14+
# Standard
import logging
from typing import TYPE_CHECKING
import uuid

# Third-Party
from mcp.server.streamable_http import EventCallback, EventMessage, EventStore
from mcp.types import JSONRPCMessage
import orjson

# First-Party
from mcpgateway.utils.redis_client import get_redis_client

if TYPE_CHECKING:
    # Third-Party
    from redis.asyncio import Redis
30+
31+
logger = logging.getLogger(__name__)
32+
33+
34+
class RedisEventStore(EventStore):
    """
    Redis-backed event store for multi-worker deployments.

    Data Model:
        Per Stream:
        - Hash: mcpgw:eventstore:{stream_id}:meta
            - start_seq: Oldest sequence number (for eviction detection)
            - next_seq: Next sequence number to assign
            - count: Current event count

        - Sorted Set: mcpgw:eventstore:{stream_id}:events
            - Score: sequence number
            - Value: JSON {event_id, message, seq_num}

        Global:
        - String: mcpgw:eventstore:index:{event_id}
            - Value: JSON {stream_id, seq_num}
            - Stored as individual keys (not one shared hash) so every index
              entry carries its own TTL and cannot leak after its stream's
              keys have expired.

    Examples:
        >>> # Create event store with custom settings
        >>> store = RedisEventStore(max_events_per_stream=200, ttl=7200)

        >>> # Store an event
        >>> event_id = await store.store_event("stream-123", message)

        >>> # Replay events after a specific event_id
        >>> async def callback(event_message):
        ...     print(f"Replayed: {event_message.message}")
        >>> stream_id = await store.replay_events_after(event_id, callback)
    """

    def __init__(self, max_events_per_stream: int = 100, ttl: int = 3600):
        """
        Initialize Redis event store.

        Args:
            max_events_per_stream: Maximum events per stream (ring buffer size)
            ttl: Stream TTL in seconds (default 1 hour)
        """
        self.max_events = max_events_per_stream
        self.ttl = ttl
        logger.info("RedisEventStore initialized: max_events=%s, ttl=%ss", max_events_per_stream, ttl)

    def _get_stream_meta_key(self, stream_id: str) -> str:
        """Get Redis key for stream metadata."""
        return f"mcpgw:eventstore:{stream_id}:meta"

    def _get_stream_events_key(self, stream_id: str) -> str:
        """Get Redis key for stream events sorted set."""
        return f"mcpgw:eventstore:{stream_id}:events"

    def _get_event_index_key(self, event_id: str) -> str:
        """Get Redis key for a single event's index entry."""
        return f"mcpgw:eventstore:index:{event_id}"

    async def store_event(self, stream_id: str, message: JSONRPCMessage | None) -> str:
        """
        Store an event in Redis.

        Args:
            stream_id: Unique stream identifier
            message: JSON-RPC message to store (None for priming events)

        Returns:
            Unique event_id for this event

        Examples:
            >>> event_id = await store.store_event("stream-123", {"jsonrpc": "2.0", "method": "test"})
            >>> isinstance(event_id, str)
            True
        """
        redis: Redis = await get_redis_client()
        event_id = str(uuid.uuid4())

        logger.debug("[REDIS_EVENTSTORE] Storing event | stream_id=%s | event_id=%s", stream_id, event_id)

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)

        # Atomically assign the next sequence number (HINCRBY creates the hash
        # on first use, so new streams need no explicit initialization)
        seq_num = await redis.hincrby(meta_key, "next_seq", 1)

        # Convert message to dict for serialization (Pydantic model -> dict)
        message_dict = None if message is None else (message.model_dump() if hasattr(message, "model_dump") else dict(message))

        # Serialize and store the event in the sorted set (score = seq_num)
        event_data = orjson.dumps({"event_id": event_id, "message": message_dict, "seq_num": seq_num})
        await redis.zadd(events_key, {event_data: seq_num})

        # Index event_id -> (stream_id, seq_num) for O(1) replay lookup.
        # Per-event keys with their own TTL, so index entries expire with
        # their stream instead of accumulating forever in a shared hash.
        index_data = orjson.dumps({"stream_id": stream_id, "seq_num": seq_num})
        await redis.set(self._get_event_index_key(event_id), index_data, ex=self.ttl)

        # Track current event count for ring-buffer eviction
        count = await redis.hincrby(meta_key, "count", 1)

        # Ring buffer semantics: evict oldest events beyond max_events
        if count > self.max_events:
            to_evict = count - self.max_events

            # Fetch then remove the oldest entries (lowest ranks)
            evicted = await redis.zrange(events_key, 0, to_evict - 1)
            await redis.zremrangebyrank(events_key, 0, to_evict - 1)

            # Drop the evicted events' index entries
            if evicted:
                await redis.delete(*(self._get_event_index_key(orjson.loads(raw)["event_id"]) for raw in evicted))

            # start_seq marks the oldest surviving event so replay can detect
            # requests for already-evicted events
            remaining = await redis.zrange(events_key, 0, 0, withscores=True)
            if remaining:
                _, start_seq = remaining[0]
                await redis.hset(meta_key, "start_seq", int(start_seq))

            await redis.hset(meta_key, "count", self.max_events)

        # Refresh TTL on stream keys with every write
        await redis.expire(meta_key, self.ttl)
        await redis.expire(events_key, self.ttl)

        logger.debug("Stored event %s in stream %s (seq=%s)", event_id, stream_id, seq_num)
        return event_id

    async def replay_events_after(self, last_event_id: str, send_callback: EventCallback) -> str | None:
        """
        Replay events after a specific event_id.

        Args:
            last_event_id: Event ID to replay from
            send_callback: Async callback receiving EventMessage objects

        Returns:
            stream_id if found, None if event not found or evicted

        Examples:
            >>> replayed = []
            >>> async def callback(event_message):
            ...     replayed.append(event_message)
            >>> stream_id = await store.replay_events_after(event_id, callback)
            >>> len(replayed) > 0
            True
        """
        redis: Redis = await get_redis_client()

        logger.debug("[REDIS_EVENTSTORE] Replaying events | last_event_id=%s", last_event_id)

        # Lookup event in index
        index_data = await redis.get(self._get_event_index_key(last_event_id))
        if not index_data:
            logger.warning("[REDIS_EVENTSTORE] Event not found in index | last_event_id=%s", last_event_id)
            return None

        event_info = orjson.loads(index_data)
        stream_id = event_info["stream_id"]
        last_seq = event_info["seq_num"]

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)

        # Check if event still in buffer (not evicted); start_seq is only set
        # once eviction has happened at least once
        start_seq_bytes = await redis.hget(meta_key, "start_seq")
        if start_seq_bytes:
            start_seq = int(start_seq_bytes)
            if last_seq < start_seq:
                logger.warning("Event %s evicted from stream %s (seq %s < start %s)", last_event_id, stream_id, last_seq, start_seq)
                return None

        # Get all events after last_seq
        events = await redis.zrangebyscore(events_key, last_seq + 1, "+inf")

        # Replay events. The streamable HTTP transport expects EventMessage
        # objects (message + event_id) so it can resume the SSE stream with
        # correct per-event ids — a raw dict would break resumption.
        replayed = 0
        for event_bytes in events:
            event_data = orjson.loads(event_bytes)
            message_dict = event_data["message"]
            if message_dict is None:
                # Priming events carry no payload; nothing to replay
                continue
            message = JSONRPCMessage.model_validate(message_dict)
            await send_callback(EventMessage(message=message, event_id=event_data["event_id"]))
            replayed += 1

        logger.debug("[REDIS_EVENTSTORE] Replayed events | stream_id=%s | last_event_id=%s | count=%s", stream_id, last_event_id, replayed)
        return stream_id

0 commit comments

Comments
 (0)