"""
Redis-backed event store for Streamable HTTP stateful sessions.

Provides distributed event storage across multiple gateway workers using Redis,
enabling stateful MCP sessions to work correctly behind load balancers.

Architecture:
- Uses Redis Sorted Sets for ordered event storage with ring buffer semantics
- Events indexed by event_id for O(1) lookup during replay
- Automatic eviction when exceeding max_events_per_stream
- TTL-based cleanup for expired streams
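
Usage (illustrative sketch; assumes the MCP SDK's StreamableHTTPSessionManager
accepts an event_store argument, and "mcp_server" stands in for your low-level
Server instance):

    from mcp.server.streamable_http_manager import StreamableHTTPSessionManager

    manager = StreamableHTTPSessionManager(app=mcp_server, event_store=RedisEventStore())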
"""

# Standard
import logging
from typing import TYPE_CHECKING
import uuid

# Third-Party
from mcp.server.streamable_http import EventCallback, EventMessage, EventStore
from mcp.types import JSONRPCMessage
import orjson

# First-Party
from mcpgateway.utils.redis_client import get_redis_client

if TYPE_CHECKING:
    # Third-Party
    from redis.asyncio import Redis

logger = logging.getLogger(__name__)


class RedisEventStore(EventStore):
    """
    Redis-backed event store for multi-worker deployments.

    Data Model:
        Per Stream:
        - Hash: mcpgw:eventstore:{stream_id}:meta
            - start_seq: Oldest retained sequence number (set once eviction first occurs)
            - next_seq: Most recently assigned sequence number (the next event gets next_seq + 1)
            - count: Current event count

        - Sorted Set: mcpgw:eventstore:{stream_id}:events
            - Score: sequence number
            - Value: JSON {event_id, message, seq_num}

        Global:
        - Hash: mcpgw:eventstore:event_index
            - Key: event_id -> Value: JSON {stream_id, seq_num}

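    Illustrative layout (hypothetical stream "abc" after two stored events,
    values abbreviated):

        mcpgw:eventstore:abc:meta    -> {next_seq: 2, count: 2}
        mcpgw:eventstore:abc:events  -> score=1: {"event_id": "...", "message": {...}, "seq_num": 1}
                                        score=2: {"event_id": "...", "message": {...}, "seq_num": 2}
        mcpgw:eventstore:event_index -> {"<event_id>": {"stream_id": "abc", "seq_num": 1}, ...}
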
    Examples:
        >>> # Create event store with custom settings
        >>> store = RedisEventStore(max_events_per_stream=200, ttl=7200)

        >>> # Store an event
        >>> event_id = await store.store_event("stream-123", message)

        >>> # Replay events after a specific event_id
        >>> async def callback(msg):
        ...     print(f"Replayed: {msg}")
        >>> stream_id = await store.replay_events_after(event_id, callback)
    """

    def __init__(self, max_events_per_stream: int = 100, ttl: int = 3600):
        """
        Initialize Redis event store.

        Args:
            max_events_per_stream: Maximum events per stream (ring buffer size)
            ttl: Stream TTL in seconds (default 1 hour)
        """
        self.max_events = max_events_per_stream
        self.ttl = ttl
        logger.info(f"RedisEventStore initialized: max_events={max_events_per_stream}, ttl={ttl}s")

    def _get_stream_meta_key(self, stream_id: str) -> str:
        """Get Redis key for stream metadata."""
        return f"mcpgw:eventstore:{stream_id}:meta"

    def _get_stream_events_key(self, stream_id: str) -> str:
        """Get Redis key for stream events sorted set."""
        return f"mcpgw:eventstore:{stream_id}:events"

    def _get_event_index_key(self) -> str:
        """Get Redis key for global event index."""
        return "mcpgw:eventstore:event_index"

    async def store_event(self, stream_id: str, message: JSONRPCMessage | None) -> str:
        """
        Store an event in Redis.

        Args:
            stream_id: Unique stream identifier
            message: JSON-RPC message to store (None for priming events)

        Returns:
            Unique event_id for this event

        Examples:
            >>> event_id = await store.store_event("stream-123", {"jsonrpc": "2.0", "method": "test"})
            >>> isinstance(event_id, str)
            True
        """
        redis: Redis = await get_redis_client()
        event_id = str(uuid.uuid4())

        logger.info(f"[REDIS_EVENTSTORE] Storing event | stream_id={stream_id} | event_id={event_id} | message_type={type(message).__name__ if message else 'None'}")

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)
        index_key = self._get_event_index_key()

        # Atomically increment sequence number
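        # (HINCRBY treats a missing hash/field as 0, so the first event in a
        # stream is assigned seq_num 1 with no explicit initialization needed)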
        seq_num = await redis.hincrby(meta_key, "next_seq", 1)

        # Convert message to dict for serialization (Pydantic model -> dict)
        message_dict = None if message is None else (message.model_dump() if hasattr(message, "model_dump") else dict(message))

        # Serialize event data
        event_data = orjson.dumps({"event_id": event_id, "message": message_dict, "seq_num": seq_num})

        # Store event in sorted set (score = seq_num)
        await redis.zadd(events_key, {event_data: seq_num})

        # Index event_id for lookup
        index_data = orjson.dumps({"stream_id": stream_id, "seq_num": seq_num})
        await redis.hset(index_key, event_id, index_data)

        # Increment count
        count = await redis.hincrby(meta_key, "count", 1)

        # Handle eviction if exceeding max_events
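        # (Note: this read-modify-write eviction is not atomic across workers;
        # concurrent stores can transiently overshoot max_events. A Lua script
        # or MULTI/EXEC pipeline would be needed for strict enforcement.)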
        if count > self.max_events:
            # Calculate how many to evict
            to_evict = count - self.max_events

            # Get events to evict (oldest by rank)
            evicted = await redis.zrange(events_key, 0, to_evict - 1)

            # Remove from sorted set
            await redis.zremrangebyrank(events_key, 0, to_evict - 1)

            # Remove from event index and update start_seq
            for event_bytes in evicted:
                evicted_event = orjson.loads(event_bytes)
                await redis.hdel(index_key, evicted_event["event_id"])

            # Update start_seq to first remaining event
            remaining = await redis.zrange(events_key, 0, 0, withscores=True)
            if remaining:
                _, start_seq = remaining[0]
                await redis.hset(meta_key, "start_seq", int(start_seq))

            # Update count
            await redis.hset(meta_key, "count", self.max_events)

        # Set TTL on stream keys
        await redis.expire(meta_key, self.ttl)
        await redis.expire(events_key, self.ttl)

        logger.debug(f"Stored event {event_id} in stream {stream_id} (seq={seq_num})")
        return event_id

    async def replay_events_after(self, last_event_id: str, send_callback: EventCallback) -> str | None:
        """
        Replay events after a specific event_id.

        Args:
            last_event_id: Event ID to replay from
            send_callback: Async callback to receive replayed messages

        Returns:
            stream_id if found, None if event not found or evicted

        Examples:
            >>> messages = []
            >>> async def callback(msg):
            ...     messages.append(msg)
            >>> stream_id = await store.replay_events_after(event_id, callback)
            >>> len(messages) > 0
            True
        """
        redis: Redis = await get_redis_client()
        index_key = self._get_event_index_key()

        logger.info(f"[REDIS_EVENTSTORE] Replaying events | last_event_id={last_event_id}")

        # Lookup event in index
        index_data = await redis.hget(index_key, last_event_id)
        if not index_data:
            logger.warning(f"[REDIS_EVENTSTORE] Event not found in index | last_event_id={last_event_id}")
            return None

        event_info = orjson.loads(index_data)
        stream_id = event_info["stream_id"]
        last_seq = event_info["seq_num"]

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)

        # Check if event still in buffer (not evicted)
        start_seq_bytes = await redis.hget(meta_key, "start_seq")
        if start_seq_bytes:
            start_seq = int(start_seq_bytes)
            if last_seq < start_seq:
                logger.warning(f"Event {last_event_id} evicted from stream {stream_id} (seq {last_seq} < start {start_seq})")
                return None

        # Get all events after last_seq (last_seq + 1 keeps the range exclusive)
        events = await redis.zrangebyscore(events_key, last_seq + 1, "+inf")

        # Replay events, rebuilding the EventMessage wrapper that the SDK's
        # EventCallback expects (the message plus its event_id for SSE resumption)
        replayed = 0
        for event_bytes in events:
            event_data = orjson.loads(event_bytes)
            message_dict = event_data["message"]
            if message_dict is None:
                # Priming events carry no JSON-RPC payload; skip them on replay
                continue
            message = JSONRPCMessage.model_validate(message_dict)
            await send_callback(EventMessage(message, event_data["event_id"]))
            replayed += 1

        logger.info(f"[REDIS_EVENTSTORE] Replayed events | stream_id={stream_id} | last_event_id={last_event_id} | count={replayed}")
        return stream_id
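

# Minimal manual smoke test (illustrative sketch, not part of the gateway API).
# Assumes a reachable Redis instance behind get_redis_client; the stream name
# "demo-stream" and the notification payloads are arbitrary examples.
if __name__ == "__main__":  # pragma: no cover
    # Standard
    import asyncio

    async def _demo() -> None:
        store = RedisEventStore(max_events_per_stream=5, ttl=60)

        # Store two JSON-RPC notifications on the same stream
        first = await store.store_event(
            "demo-stream",
            JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "notifications/initialized"}),
        )
        await store.store_event(
            "demo-stream",
            JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "demo/second"}),
        )

        async def _on_replay(event: EventMessage) -> None:
            print(f"replayed event_id={event.event_id}: {event.message}")

        # Replaying after the first event should deliver only the second one
        stream_id = await store.replay_events_after(first, _on_replay)
        print(f"stream_id={stream_id}")

    asyncio.run(_demo())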