diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 1aa7e42e2..829d28e69 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -28,6 +28,7 @@ - [Test Failures in CI](./operations/testfailures.md) - [Configs](./operations/configs.md) - [Elasticsearch](./operations/elasticsearch.md) + - [Curated Recommendations](./operations/curated-recommendations/corpus-cache.md) - [Jobs](./operations/jobs.md) - [Navigational Suggestions](./operations/jobs/navigational_suggestions.md) - [Dynamic Wikipedia Indexer](./operations/jobs/dynamic-wiki-indexer.md) diff --git a/docs/operations/curated-recommendations/corpus-cache.md b/docs/operations/curated-recommendations/corpus-cache.md new file mode 100644 index 000000000..eb553f9a3 --- /dev/null +++ b/docs/operations/curated-recommendations/corpus-cache.md @@ -0,0 +1,114 @@ +# Corpus Cache (Redis L2) + +Shared Redis cache between the per-pod in-memory cache and the Corpus GraphQL API. + +## Why + +Merino pods each independently fetch from the Corpus API on a short interval. This puts unnecessary load on Apollo/Client-API and creates risk as we expand internationally or scale pod count. + +## How it works + +```mermaid +flowchart TB + req["Firefox NewTab Request"] + + subgraph L1 ["L1 — Per-Pod In-Memory SWR"] + check_l1{{"Check in-memory cache"}} + end + + respond_fresh["Respond with fresh data"] + respond_stale["Respond with stale data"] + + subgraph bg ["Background Revalidation Task"] + direction TB + + subgraph L2 ["L2 — Shared Redis"] + check_l2{{"Check Redis cache"}} + acquire_lock{{"Try distributed lock"}} + end + + api["Fetch from Corpus GraphQL API"] + write["Write to Redis + release lock + update L1"] + end + + req --> check_l1 + + check_l1 -- "FRESH HIT" --> respond_fresh + check_l1 -- "STALE" --> respond_stale + check_l1 -. "MISS (cold start, blocks)" .-> check_l2 + + respond_stale -. "spawns task" .-> check_l2 + + check_l2 -- "FRESH HIT" --> done_l2["Update L1 cache"] + check_l2 -. "STALE" .-> acquire_lock + check_l2 -. 
"MISS" .-> acquire_lock + + acquire_lock -- "LOCK ACQUIRED" --> api + acquire_lock -. "LOCK HELD + stale exists" .-> serve_stale["Return stale data"] + acquire_lock -. "LOCK HELD + no data" .-> retry["Wait, retry Redis, or raise"] + + api --> write --> done_api["Update L1 cache"] + + style req fill:#2c3e50,stroke:#1a252f,color:#ecf0f1,stroke-width:2px + style check_l1 fill:#2980b9,stroke:#1f6da0,color:#fff,stroke-width:2px + style check_l2 fill:#d35400,stroke:#a04000,color:#fff,stroke-width:2px + style acquire_lock fill:#e67e22,stroke:#bf6516,color:#fff,stroke-width:2px + style api fill:#1e8449,stroke:#145a32,color:#fff,stroke-width:2px + style write fill:#1e8449,stroke:#145a32,color:#fff,stroke-width:2px + style respond_fresh fill:#27ae60,stroke:#1e8449,color:#fff,stroke-width:2px + style respond_stale fill:#27ae60,stroke:#1e8449,color:#fff,stroke-width:2px + style serve_stale fill:#f4d03f,stroke:#d4ac0f,color:#333 + style retry fill:#e74c3c,stroke:#c0392b,color:#fff + style done_l2 fill:#27ae60,stroke:#1e8449,color:#fff + style done_api fill:#27ae60,stroke:#1e8449,color:#fff + style L1 fill:#eaf2f8,stroke:#2980b9,stroke-width:2px,color:#2c3e50 + style L2 fill:#fef5e7,stroke:#d35400,stroke-width:2px,color:#2c3e50 + style bg fill:#f4f6f7,stroke:#95a5a6,stroke-width:2px,stroke-dasharray: 8 4,color:#2c3e50 +``` + +Two layers of caching sit in front of the Corpus GraphQL API: + +- **L1 (in-memory SWR)** — per-pod. Serves requests immediately. On stale, spawns a background task to revalidate. +- **L2 (Redis)** — shared across all pods. The background task checks Redis before hitting the API. + +When L2 is stale, one pod acquires a distributed lock, fetches from the API, and writes to Redis. Other pods serve stale data until the winner finishes. + +On cold start (no L1 or L2 data), the request blocks until data is fetched. All pods may hit the API simultaneously in this case — same as today without the cache. 
+ +## Configuration + +Config section: `[default.curated_recommendations.corpus_cache]` in `merino/configs/default.toml`. + +Key settings: +- `cache` — `"redis"` to enable, `"none"` to disable (default: disabled) +- `soft_ttl_sec` — when a cached entry is considered stale and triggers revalidation +- `hard_ttl_sec` — when Redis evicts the key entirely (safety net) +- `lock_ttl_sec` — auto-release timeout if the lock holder crashes +- `key_prefix` — bump the version on schema changes to avoid deserialization errors + +Env var override pattern: `MERINO__CURATED_RECOMMENDATIONS__CORPUS_CACHE__CACHE=redis` + +Uses the shared Redis cluster (`[default.redis]`). No separate instance needed. + +## Design decisions + +| Decision | Choice | Why | +|---|---|---| +| Cache layer | Redis L2 behind existing in-memory L1 | Keeps per-pod latency low, Redis only consulted on L1 miss | +| Write pattern | Distributed stale-while-revalidate | One pod revalidates, others serve stale. Avoids thundering herd | +| Lock mechanism | `SET NX EX` with TTL | Simple, self-expiring. Worst case on timeout: one extra API call | +| Cache format | Pydantic model dicts via orjson | Saves CPU across pods vs re-parsing raw GraphQL | +| Failure mode | All Redis errors fall through to API | Redis is an optimization, never a requirement | + +## Rollout + +1. Deploy with cache disabled (no behavior change) +2. Enable in staging +3. Monitor metrics, validate API call reduction +4. 
Enable in production + +## Key files + +- `merino/curated_recommendations/corpus_backends/redis_cache.py` — cache logic +- `merino/curated_recommendations/__init__.py` — wiring (`_init_corpus_cache`) +- `merino/configs/default.toml` — config section with defaults and documentation diff --git a/merino/cache/none.py b/merino/cache/none.py index 0690672ed..8943899d5 100644 --- a/merino/cache/none.py +++ b/merino/cache/none.py @@ -35,3 +35,9 @@ async def sismember(self, key: str, value: str) -> bool: # noqa: D102 async def scard(self, key: str) -> int: # noqa: D102 return 0 + + async def set_nx(self, key: str, ttl_sec: int) -> bool: # noqa: D102 + return True + + async def delete(self, key: str) -> None: # noqa: D102 + pass diff --git a/merino/cache/protocol.py b/merino/cache/protocol.py index 61ab15eb7..02cb513e3 100644 --- a/merino/cache/protocol.py +++ b/merino/cache/protocol.py @@ -78,3 +78,22 @@ async def sismember(self, key: str, value: str) -> bool: async def scard(self, key: str) -> int: """Get the number of members in a Redis set.""" ... + + async def set_nx(self, key: str, ttl_sec: int) -> bool: # pragma: no cover + """Set the key only if it does not already exist, with a TTL in seconds. + + Returns: + True if the key was set, False if it already existed. + + Raises: + - `CacheAdapterError` for cache backend errors. + """ + ... + + async def delete(self, key: str) -> None: # pragma: no cover + """Delete a key from the cache. + + Raises: + - `CacheAdapterError` for cache backend errors. + """ + ... diff --git a/merino/cache/redis.py b/merino/cache/redis.py index 858ee9e78..6071ed75d 100644 --- a/merino/cache/redis.py +++ b/merino/cache/redis.py @@ -127,6 +127,31 @@ async def scard(self, key: str) -> int: except RedisError as exc: raise CacheAdapterError(f"Failed to SCARD {key} with error: {exc}") from exc + async def set_nx(self, key: str, ttl_sec: int) -> bool: + """Set the key only if it does not exist, with a TTL in seconds. 
+ + Returns: + True if the key was set, False if it already existed. + + Raises: + - `CacheAdapterError` if Redis returns an error. + """ + try: + return bool(await self.primary.set(key, b"1", nx=True, ex=ttl_sec)) + except RedisError as exc: + raise CacheAdapterError(f"Failed to SETNX `{repr(key)}` with error: `{exc}`") from exc + + async def delete(self, key: str) -> None: + """Delete a key from Redis. + + Raises: + - `CacheAdapterError` if Redis returns an error. + """ + try: + await self.primary.delete(key) + except RedisError as exc: + raise CacheAdapterError(f"Failed to DELETE `{repr(key)}` with error: `{exc}`") from exc + async def close(self) -> None: """Close the Redis connection.""" if self.primary is self.replica: diff --git a/merino/configs/default.toml b/merino/configs/default.toml index 73281ad8a..392090cd8 100644 --- a/merino/configs/default.toml +++ b/merino/configs/default.toml @@ -1061,6 +1061,28 @@ blob_name = "contextual_ts/cohort_model_v2.safetensors" cron_interval_seconds = 600 +[default.curated_recommendations.corpus_cache] +# MERINO__CURATED_RECOMMENDATIONS__CORPUS_CACHE__CACHE +# Shared corpus cache backend. "redis" enables Redis as an L2 cache +# between in-memory SWR (L1) and the Corpus GraphQL API. "none" disables it. +cache = "none" + +# MERINO__CURATED_RECOMMENDATIONS__CORPUS_CACHE__SOFT_TTL_SEC +# Soft TTL in seconds. After this, one pod revalidates while others serve stale data. +soft_ttl_sec = 120 + +# MERINO__CURATED_RECOMMENDATIONS__CORPUS_CACHE__HARD_TTL_SEC +# Hard TTL in seconds. Redis evicts the key after this. Should be much longer than soft TTL. +hard_ttl_sec = 600 + +# MERINO__CURATED_RECOMMENDATIONS__CORPUS_CACHE__LOCK_TTL_SEC +# Distributed lock TTL in seconds. Auto-releases if the lock holder crashes. +lock_ttl_sec = 30 + +# MERINO__CURATED_RECOMMENDATIONS__CORPUS_CACHE__KEY_PREFIX +# Prefix for all Redis keys. Bump the version on schema changes. 
+key_prefix = "curated:v1" + [default.curated_recommendations.corpus_api] # MERINO__CURATED_RECOMMENDATIONS__CORPUS_API__RETRY_COUNT diff --git a/merino/curated_recommendations/__init__.py b/merino/curated_recommendations/__init__.py index 8b2954521..ae58fb03f 100644 --- a/merino/curated_recommendations/__init__.py +++ b/merino/curated_recommendations/__init__.py @@ -6,7 +6,17 @@ import logging import random +from merino.cache.redis import RedisAdapter, create_redis_clients from merino.configs import settings +from merino.curated_recommendations.corpus_backends.protocol import ( + ScheduledSurfaceProtocol, + SectionsProtocol, +) +from merino.curated_recommendations.corpus_backends.redis_cache import ( + CorpusCacheConfig, + RedisCachedScheduledSurface, + RedisCachedSections, +) from merino.curated_recommendations.corpus_backends.scheduled_surface_backend import ( ScheduledSurfaceBackend, CorpusApiGraphConfig, @@ -169,6 +179,46 @@ def init_ml_cohort_model_backend() -> CohortModelBackend: return EmptyCohortModel() +def _init_corpus_cache( + scheduled_surface_backend: ScheduledSurfaceProtocol, + sections_backend: SectionsProtocol, +) -> tuple[ScheduledSurfaceProtocol, SectionsProtocol, RedisAdapter | None]: + """Optionally wrap corpus backends with a Redis L2 cache layer. + + Returns the backends (possibly wrapped) and the Redis adapter (if created). + The caller owns the adapter and is responsible for closing it on shutdown. 
+ """ + cache_settings = settings.curated_recommendations.corpus_cache + if cache_settings.cache != "redis": + return scheduled_surface_backend, sections_backend, None + + try: + logger.info("Initializing Redis L2 cache for corpus backends") + adapter = RedisAdapter( + *create_redis_clients( + primary=settings.redis.server, + replica=settings.redis.replica, + max_connections=settings.redis.max_connections, + socket_connect_timeout=settings.redis.socket_connect_timeout_sec, + socket_timeout=settings.redis.socket_timeout_sec, + ) + ) + config = CorpusCacheConfig( + soft_ttl_sec=cache_settings.soft_ttl_sec, + hard_ttl_sec=cache_settings.hard_ttl_sec, + lock_ttl_sec=cache_settings.lock_ttl_sec, + key_prefix=cache_settings.key_prefix, + ) + return ( + RedisCachedScheduledSurface(scheduled_surface_backend, adapter, config), + RedisCachedSections(sections_backend, adapter, config), + adapter, + ) + except Exception as e: + logger.error("Failed to initialize Redis corpus cache, proceeding without it: %s", e) + return scheduled_surface_backend, sections_backend, None + + def init_provider() -> None: """Initialize the curated recommendations' provider.""" global _provider @@ -179,20 +229,24 @@ def init_provider() -> None: ml_recommendations_backend = init_ml_recommendations_backend() cohort_model_backend = init_ml_cohort_model_backend() - scheduled_surface_backend = ScheduledSurfaceBackend( + scheduled_surface_backend: ScheduledSurfaceProtocol = ScheduledSurfaceBackend( http_client=create_http_client(base_url=""), graph_config=CorpusApiGraphConfig(), metrics_client=get_metrics_client(), manifest_provider=get_manifest_provider(), ) - sections_backend = SectionsBackend( + sections_backend: SectionsProtocol = SectionsBackend( http_client=create_http_client(base_url=""), graph_config=CorpusApiGraphConfig(), metrics_client=get_metrics_client(), manifest_provider=get_manifest_provider(), ) + scheduled_surface_backend, sections_backend, cache_adapter = _init_corpus_cache( + 
scheduled_surface_backend, sections_backend + ) + _provider = CuratedRecommendationsProvider( scheduled_surface_backend=scheduled_surface_backend, engagement_backend=engagement_backend, @@ -201,6 +255,7 @@ def init_provider() -> None: local_model_backend=local_model_backend, ml_recommendations_backend=ml_recommendations_backend, cohort_model_backend=cohort_model_backend, + cache_adapter=cache_adapter, ) _legacy_provider = LegacyCuratedRecommendationsProvider() @@ -215,3 +270,11 @@ def get_legacy_provider() -> LegacyCuratedRecommendationsProvider: """Return the legacy curated recommendations provider""" global _legacy_provider return _legacy_provider + + +async def shutdown() -> None: + """Clean up resources used by curated recommendations.""" + try: + await _provider.shutdown() + except NameError: + pass diff --git a/merino/curated_recommendations/corpus_backends/redis_cache.py b/merino/curated_recommendations/corpus_backends/redis_cache.py new file mode 100644 index 000000000..bbcb18fae --- /dev/null +++ b/merino/curated_recommendations/corpus_backends/redis_cache.py @@ -0,0 +1,311 @@ +"""Redis L2 cache for corpus backends with distributed stale-while-revalidate.""" + +import asyncio +import logging +import time +from dataclasses import dataclass +from datetime import timedelta +from typing import Awaitable, Callable, TypeVar + +import orjson + +from merino.cache.protocol import CacheAdapter +from merino.curated_recommendations.corpus_backends.protocol import ( + CorpusItem, + CorpusSection, + ScheduledSurfaceProtocol, + SectionsProtocol, + SurfaceId, +) +from merino.exceptions import BackendError, CacheAdapterError + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass(frozen=True) +class CorpusCacheConfig: + """Configuration for the Redis corpus cache layer. + + Attributes: + soft_ttl_sec: Seconds before a cached entry is considered stale. One pod revalidates + while others continue to serve the stale value. 
+ hard_ttl_sec: Seconds before Redis evicts the key entirely. Safety net. + lock_ttl_sec: Seconds before a distributed revalidation lock auto-expires. + key_prefix: Prefix for all Redis keys. Bump the version on schema changes. + """ + + soft_ttl_sec: int + hard_ttl_sec: int + lock_ttl_sec: int + key_prefix: str + + def __post_init__(self) -> None: + """Validate that TTL values are consistent and positive.""" + if self.soft_ttl_sec <= 0: + raise ValueError(f"soft_ttl_sec ({self.soft_ttl_sec}) must be positive") + if self.hard_ttl_sec <= 0: + raise ValueError(f"hard_ttl_sec ({self.hard_ttl_sec}) must be positive") + if self.lock_ttl_sec <= 0: + raise ValueError(f"lock_ttl_sec ({self.lock_ttl_sec}) must be positive") + if self.hard_ttl_sec <= self.soft_ttl_sec: + raise ValueError( + f"hard_ttl_sec ({self.hard_ttl_sec}) must be greater than " + f"soft_ttl_sec ({self.soft_ttl_sec})" + ) + if self.hard_ttl_sec <= self.lock_ttl_sec: + raise ValueError( + f"hard_ttl_sec ({self.hard_ttl_sec}) must be greater than " + f"lock_ttl_sec ({self.lock_ttl_sec})" + ) + + +def _build_data_key( + config: CorpusCacheConfig, backend_type: str, surface_id: str, *extra: str +) -> str: + """Build the Redis key for cached corpus data.""" + parts = [config.key_prefix, backend_type, surface_id, *extra] + return ":".join(parts) + + +def _build_lock_key( + config: CorpusCacheConfig, backend_type: str, surface_id: str, *extra: str +) -> str: + """Build the Redis key for the distributed revalidation lock.""" + parts = [config.key_prefix, "lock", backend_type, surface_id, *extra] + return ":".join(parts) + + +def _serialize_envelope(data: list[dict], soft_ttl_sec: int) -> bytes: + """Serialize data with an expiration timestamp into a cache envelope.""" + envelope = { + "expires_at": time.time() + soft_ttl_sec, + "data": data, + } + return orjson.dumps(envelope) + + +def _deserialize_envelope(raw: bytes) -> tuple[float, list[dict]]: + """Deserialize a cache envelope, returning (expires_at, 
data).""" + envelope = orjson.loads(raw) + return envelope["expires_at"], envelope["data"] + + +class _RedisCorpusCache: + """Shared Redis cache logic for corpus backends. + + Implements distributed stale-while-revalidate: when the cached value is stale, + one pod acquires a lock and revalidates while others serve stale data. + """ + + def __init__(self, cache: CacheAdapter, config: CorpusCacheConfig) -> None: + self._cache = cache + self._config = config + + async def get_or_fetch( + self, + backend_type: str, + surface_id: str, + *extra: str, + fetch_fn: Callable[[], Awaitable[list[T]]], + serialize_fn: Callable[[list[T]], list[dict]], + deserialize_fn: Callable[[list[dict]], list[T]], + ) -> list[T]: + """Check Redis, returning cached data or fetching from the backend. + + Args: + backend_type: Type identifier for key namespacing (e.g. "scheduled", "sections"). + surface_id: Surface ID value for the cache key. + *extra: Additional key segments (e.g. days_offset). + fetch_fn: Async callable that fetches fresh data from the backend. + serialize_fn: Converts typed models to dicts for Redis storage. + deserialize_fn: Converts dicts from Redis back to typed models. 
+ """ + data_key = _build_data_key(self._config, backend_type, surface_id, *extra) + lock_key = _build_lock_key(self._config, backend_type, surface_id, *extra) + # Try reading from Redis + cached = await self._redis_get(data_key) + if cached is not None: + try: + expires_at, items_data = cached + is_fresh = time.time() < expires_at + except TypeError: + # expires_at is not numeric (corrupted envelope) + logger.warning( + "Invalid expires_at in corpus cache key %s", data_key, exc_info=True + ) + is_fresh = False + items_data = None + + if is_fresh and items_data is not None: + try: + return deserialize_fn(items_data) + except Exception: + logger.warning( + "Deserialization failed for corpus cache key %s", + data_key, + exc_info=True, + ) + # Fall through to revalidation/fetch below + elif items_data is not None: + # Stale — try to revalidate + if await self._try_acquire_lock(lock_key): + return await self._revalidate(data_key, lock_key, fetch_fn, serialize_fn) + try: + return deserialize_fn(items_data) + except Exception: + logger.warning( + "Deserialization of stale data failed for corpus cache key %s", + data_key, + exc_info=True, + ) + + # Cache miss — try to acquire lock and fetch + if await self._try_acquire_lock(lock_key): + return await self._revalidate(data_key, lock_key, fetch_fn, serialize_fn) + # Another pod is populating; wait briefly then retry Redis + await asyncio.sleep(0.1) + cached = await self._redis_get(data_key) + if cached is not None: + _, items_data = cached + if items_data is not None: + try: + return deserialize_fn(items_data) + except Exception: + logger.warning( + "Deserialization failed on retry for corpus cache key %s", + data_key, + exc_info=True, + ) + raise BackendError(f"Cache miss and lock held for {data_key}") + + async def _revalidate( + self, + data_key: str, + lock_key: str, + fetch_fn: Callable[[], Awaitable[list[T]]], + serialize_fn: Callable[[list[T]], list[dict]], + ) -> list[T]: + """Fetch from the backend, write to Redis, 
and release the lock. + + Uses try/finally to ensure the lock is released even on cancellation + (asyncio.CancelledError is a BaseException, not caught by except Exception). + """ + try: + items = await fetch_fn() + # Cache write is best-effort: don't lose fetched items on serialize or write failure. + try: + serialized = serialize_fn(items) + except Exception: + logger.warning( + "Serialization failed for corpus cache key %s", data_key, exc_info=True + ) + else: + await self._redis_set(data_key, serialized) + return items + finally: + await self._release_lock(lock_key) + + async def _redis_get(self, key: str) -> tuple[float, list[dict]] | None: + """Read and deserialize from Redis. Returns None on any error.""" + try: + raw = await self._cache.get(key) + if raw is None: + return None + return _deserialize_envelope(raw) + except CacheAdapterError: + logger.warning("Redis read error for corpus cache key %s", key, exc_info=True) + return None + except (orjson.JSONDecodeError, KeyError, TypeError): + logger.warning( + "Redis deserialization error for corpus cache key %s", + key, + exc_info=True, + ) + return None + + async def _redis_set(self, key: str, data: list[dict]) -> None: + """Serialize and write to Redis. Logs on error without raising.""" + try: + value = _serialize_envelope(data, self._config.soft_ttl_sec) + await self._cache.set(key, value, ttl=timedelta(seconds=self._config.hard_ttl_sec)) + except Exception: + logger.warning("Redis write error for corpus cache key %s", key, exc_info=True) + + async def _try_acquire_lock(self, lock_key: str) -> bool: + """Attempt to acquire a distributed lock via SET NX EX.""" + try: + return await self._cache.set_nx(lock_key, self._config.lock_ttl_sec) + except CacheAdapterError: + logger.warning("Redis lock acquire error for %s", lock_key, exc_info=True) + return False + + async def _release_lock(self, lock_key: str) -> None: + """Release the distributed lock by deleting the key. 
+ + Note: This uses unconditional DELETE rather than owner-aware release + (compare-and-delete via Lua script). If revalidation exceeds lock_ttl_sec + (30s default), another pod's lock could be deleted. The consequence is at + most one extra redundant API call, not a stampede, because the SWR pattern + ensures other pods serve stale/cached data regardless of lock state. + """ + try: + await self._cache.delete(lock_key) + except CacheAdapterError: + logger.warning("Redis lock release error for %s", lock_key, exc_info=True) + + +class RedisCachedScheduledSurface(ScheduledSurfaceProtocol): + """Redis L2 cache wrapper for ScheduledSurfaceBackend. + + Checks Redis before hitting the Corpus API. Uses distributed SWR: + when the cached value is stale, one pod acquires a lock and revalidates + while others continue to serve stale data. + """ + + def __init__( + self, + backend: ScheduledSurfaceProtocol, + cache: CacheAdapter, + config: CorpusCacheConfig, + ) -> None: + self._backend = backend + self._redis_cache = _RedisCorpusCache(cache, config) + + async def fetch(self, surface_id: SurfaceId, days_offset: int = 0) -> list[CorpusItem]: + """Fetch corpus items, checking Redis L2 cache first.""" + return await self._redis_cache.get_or_fetch( + "scheduled", + surface_id.value, + str(days_offset), + fetch_fn=lambda: self._backend.fetch(surface_id, days_offset), + serialize_fn=lambda items: [item.model_dump(mode="json") for item in items], + deserialize_fn=lambda data: [CorpusItem.model_validate(d) for d in data], + ) + + +class RedisCachedSections(SectionsProtocol): + """Redis L2 cache wrapper for SectionsBackend. + + Same distributed SWR pattern as RedisCachedScheduledSurface. 
+ """ + + def __init__( + self, + backend: SectionsProtocol, + cache: CacheAdapter, + config: CorpusCacheConfig, + ) -> None: + self._backend = backend + self._redis_cache = _RedisCorpusCache(cache, config) + + async def fetch(self, surface_id: SurfaceId) -> list[CorpusSection]: + """Fetch corpus sections, checking Redis L2 cache first.""" + return await self._redis_cache.get_or_fetch( + "sections", + surface_id.value, + fetch_fn=lambda: self._backend.fetch(surface_id), + serialize_fn=lambda sections: [s.model_dump(mode="json") for s in sections], + deserialize_fn=lambda data: [CorpusSection.model_validate(d) for d in data], + ) diff --git a/merino/curated_recommendations/provider.py b/merino/curated_recommendations/provider.py index 8dee13102..b864c88d2 100644 --- a/merino/curated_recommendations/provider.py +++ b/merino/curated_recommendations/provider.py @@ -4,6 +4,7 @@ from typing import cast +from merino.cache.protocol import CacheAdapter from merino.curated_recommendations import LocalModelBackend, MLRecsBackend from merino.curated_recommendations.ml_backends.protocol import ( LOCAL_MODEL_MODEL_ID_KEY, @@ -65,6 +66,7 @@ def __init__( local_model_backend: LocalModelBackend, ml_recommendations_backend: MLRecsBackend, cohort_model_backend: CohortModelBackend, + cache_adapter: CacheAdapter | None = None, ) -> None: self.scheduled_surface_backend = scheduled_surface_backend self.engagement_backend = engagement_backend @@ -73,6 +75,12 @@ def __init__( self.local_model_backend = local_model_backend self.ml_recommendations_backend = ml_recommendations_backend self.cohort_model_backend = cohort_model_backend + self._cache_adapter = cache_adapter + + async def shutdown(self) -> None: + """Close resources owned by this provider.""" + if self._cache_adapter is not None: + await self._cache_adapter.close() @staticmethod def is_sections_experiment( diff --git a/merino/main.py b/merino/main.py index 558f7b548..f177a9429 100644 --- a/merino/main.py +++ b/merino/main.py @@ 
-51,6 +51,7 @@ async def lifespan(app: FastAPI): governance.shutdown() # Shut down providers and clean up. await suggest.shutdown_providers() + await curated_recommendations.shutdown() await get_metrics_client().close() diff --git a/tests/unit/curated_recommendations/corpus_backends/test_redis_cache.py b/tests/unit/curated_recommendations/corpus_backends/test_redis_cache.py new file mode 100644 index 000000000..80fa43a93 --- /dev/null +++ b/tests/unit/curated_recommendations/corpus_backends/test_redis_cache.py @@ -0,0 +1,610 @@ +"""Unit tests for the Redis L2 corpus cache layer.""" + +import asyncio +import time +from unittest.mock import AsyncMock + +import orjson +import pytest + +from merino.curated_recommendations.corpus_backends.protocol import ( + CorpusSection, + CreateSource, + SurfaceId, +) +from merino.curated_recommendations.corpus_backends.redis_cache import ( + CorpusCacheConfig, + RedisCachedScheduledSurface, + RedisCachedSections, + _RedisCorpusCache, + _build_data_key, + _build_lock_key, + _deserialize_envelope, + _serialize_envelope, +) +from merino.exceptions import BackendError, CacheAdapterError +from tests.unit.curated_recommendations.test_sections import generate_corpus_item + +SURFACE_ID = SurfaceId.NEW_TAB_EN_US + +CONFIG = CorpusCacheConfig( + soft_ttl_sec=120, + hard_ttl_sec=600, + lock_ttl_sec=30, + key_prefix="curated:v1", +) + + +class TestCorpusCacheConfig: + """Tests for CorpusCacheConfig validation.""" + + def test_valid_config(self) -> None: + """Accept valid TTL ordering.""" + config = CorpusCacheConfig( + soft_ttl_sec=120, hard_ttl_sec=600, lock_ttl_sec=30, key_prefix="test" + ) + assert config.soft_ttl_sec == 120 + + def test_hard_ttl_must_exceed_soft_ttl(self) -> None: + """Reject hard_ttl_sec <= soft_ttl_sec.""" + with pytest.raises(ValueError, match="hard_ttl_sec.*must be greater than.*soft_ttl_sec"): + CorpusCacheConfig( + soft_ttl_sec=600, hard_ttl_sec=120, lock_ttl_sec=30, key_prefix="test" + ) + + def 
test_hard_ttl_must_exceed_lock_ttl(self) -> None: + """Reject hard_ttl_sec <= lock_ttl_sec.""" + with pytest.raises(ValueError, match="hard_ttl_sec.*must be greater than.*lock_ttl_sec"): + CorpusCacheConfig(soft_ttl_sec=10, hard_ttl_sec=20, lock_ttl_sec=30, key_prefix="test") + + @pytest.mark.parametrize( + "soft,hard,lock", + [ + (0, 600, 30), + (120, 0, 30), + (120, 600, 0), + (-1, 600, 30), + ], + ids=["zero_soft", "zero_hard", "zero_lock", "negative_soft"], + ) + def test_ttl_values_must_be_positive(self, soft: int, hard: int, lock: int) -> None: + """Reject zero or negative TTL values.""" + with pytest.raises(ValueError, match="must be positive"): + CorpusCacheConfig( + soft_ttl_sec=soft, hard_ttl_sec=hard, lock_ttl_sec=lock, key_prefix="test" + ) + + +def _make_corpus_section() -> CorpusSection: + """Create a CorpusSection with sensible defaults for testing.""" + return CorpusSection( + sectionItems=[generate_corpus_item()], + title="Test Section", + externalId="test-section", + createSource=CreateSource.ML, + ) + + +def _make_fresh_envelope(items_data: list[dict], soft_ttl_sec: int = 120) -> bytes: + """Create a serialized cache envelope that is still fresh.""" + return _serialize_envelope(items_data, soft_ttl_sec) + + +def _make_stale_envelope(items_data: list[dict]) -> bytes: + """Create a serialized cache envelope that has already expired.""" + envelope = { + "expires_at": time.time() - 10, + "data": items_data, + } + return orjson.dumps(envelope) + + +class TestKeyBuilders: + """Tests for Redis key construction functions.""" + + def test_build_data_key_scheduled(self) -> None: + """Build a data key for a scheduled surface with days_offset.""" + key = _build_data_key(CONFIG, "scheduled", "NEW_TAB_EN_US", "0") + assert key == "curated:v1:scheduled:NEW_TAB_EN_US:0" + + def test_build_data_key_sections(self) -> None: + """Build a data key for sections.""" + key = _build_data_key(CONFIG, "sections", "NEW_TAB_EN_US") + assert key == 
"curated:v1:sections:NEW_TAB_EN_US" + + def test_build_lock_key(self) -> None: + """Build a lock key with 'lock' segment inserted.""" + key = _build_lock_key(CONFIG, "scheduled", "NEW_TAB_EN_US", "0") + assert key == "curated:v1:lock:scheduled:NEW_TAB_EN_US:0" + + +class TestEnvelope: + """Tests for envelope serialization and deserialization.""" + + def test_roundtrip(self) -> None: + """Serialize and deserialize an envelope.""" + data = [{"corpusItemId": "abc", "title": "Hello"}] + raw = _serialize_envelope(data, soft_ttl_sec=120) + expires_at, deserialized = _deserialize_envelope(raw) + assert deserialized == data + assert expires_at > time.time() + + def test_deserialize_invalid_json(self) -> None: + """Raise on invalid JSON bytes.""" + with pytest.raises(orjson.JSONDecodeError): + _deserialize_envelope(b"not json") + + +class TestRedisCorpusCache: + """Tests for the shared _RedisCorpusCache logic.""" + + def setup_method(self) -> None: + """Set up mock cache adapter and helper functions.""" + self.mock_cache = AsyncMock() + self.redis_cache = _RedisCorpusCache(self.mock_cache, CONFIG) + self.fetch_fn = AsyncMock(return_value=["item1", "item2"]) + self.serialize_fn = lambda items: [{"v": i} for i in items] + self.deserialize_fn = lambda data: [d["v"] for d in data] + + @pytest.mark.asyncio + async def test_fresh_hit_returns_cached_data(self) -> None: + """Return deserialized data on a fresh Redis hit without calling fetch_fn.""" + items_data = [{"v": "item1"}, {"v": "item2"}] + self.mock_cache.get.return_value = _make_fresh_envelope(items_data) + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_not_called() + self.mock_cache.set_nx.assert_not_called() + + @pytest.mark.asyncio + async def test_stale_hit_lock_winner_revalidates(self) -> None: + """Revalidate when stale and lock is 
acquired.""" + items_data = [{"v": "old"}] + self.mock_cache.get.return_value = _make_stale_envelope(items_data) + self.mock_cache.set_nx.return_value = True + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + self.mock_cache.set.assert_called_once() + self.mock_cache.delete.assert_called_once_with("curated:v1:lock:test:surface") + + @pytest.mark.asyncio + async def test_stale_hit_lock_loser_returns_stale(self) -> None: + """Return stale data when another pod holds the lock.""" + items_data = [{"v": "stale"}] + self.mock_cache.get.return_value = _make_stale_envelope(items_data) + self.mock_cache.set_nx.return_value = False + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["stale"] + self.fetch_fn.assert_not_called() + + @pytest.mark.asyncio + async def test_cache_miss_lock_winner_fetches(self) -> None: + """Fetch from backend on cache miss when lock is acquired.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_cache_miss_lock_loser_retries_and_succeeds(self) -> None: + """Wait and retry Redis when cache misses and lock is held by another pod.""" + items_data = [{"v": "item1"}, {"v": "item2"}] + # First get returns None (miss), second get returns data (written by lock winner) + self.mock_cache.get.side_effect = [None, _make_fresh_envelope(items_data)] + self.mock_cache.set_nx.return_value = 
False + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_not_called() + assert self.mock_cache.get.call_count == 2 + + @pytest.mark.asyncio + async def test_cache_miss_lock_loser_retries_and_raises(self) -> None: + """Raise BackendError when retry still finds no data after waiting.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = False + + with pytest.raises(BackendError, match="Cache miss and lock held"): + await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + self.fetch_fn.assert_not_called() + + @pytest.mark.asyncio + async def test_cache_miss_lock_loser_retry_deserialize_error_raises(self) -> None: + """Raise BackendError when retry data exists but deserialization fails.""" + items_data = [{"v": "item1"}] + self.mock_cache.get.side_effect = [None, _make_fresh_envelope(items_data)] + self.mock_cache.set_nx.return_value = False + + def bad_deserialize(data: list[dict]) -> list: + raise ValueError("schema changed") + + with pytest.raises(BackendError, match="Cache miss and lock held"): + await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=bad_deserialize, + ) + + self.fetch_fn.assert_not_called() + + @pytest.mark.asyncio + async def test_redis_read_error_falls_through(self) -> None: + """Fall through to backend when Redis read fails.""" + self.mock_cache.get.side_effect = CacheAdapterError("connection refused") + self.mock_cache.set_nx.return_value = True + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == 
["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_redis_write_error_does_not_raise(self) -> None: + """Continue normally when Redis write fails after fetching.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + self.mock_cache.set.side_effect = CacheAdapterError("write failed") + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + # Cache write was attempted (and failed), but lock should still be released + self.mock_cache.set.assert_called_once() + self.mock_cache.delete.assert_called_once_with("curated:v1:lock:test:surface") + + @pytest.mark.asyncio + async def test_deserialization_error_falls_through(self) -> None: + """Treat corrupted Redis data as a cache miss.""" + self.mock_cache.get.return_value = b"not valid json" + self.mock_cache.set_nx.return_value = True + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_lock_acquire_error_on_miss_retries_then_raises(self) -> None: + """Raise BackendError when lock acquisition fails and retry finds no data.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.side_effect = CacheAdapterError("lock error") + + with pytest.raises(BackendError, match="Cache miss and lock held"): + await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + self.fetch_fn.assert_not_called() + + @pytest.mark.asyncio + async def test_lock_acquire_error_returns_stale_on_stale_hit(self) -> None: + """Return stale data when lock 
acquisition fails on stale hit.""" + items_data = [{"v": "stale"}] + self.mock_cache.get.return_value = _make_stale_envelope(items_data) + self.mock_cache.set_nx.side_effect = CacheAdapterError("lock error") + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["stale"] + self.fetch_fn.assert_not_called() + + @pytest.mark.asyncio + async def test_backend_error_releases_lock(self) -> None: + """Release the lock when the backend raises an error.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + self.fetch_fn.side_effect = Exception("API down") + + with pytest.raises(Exception, match="API down"): + await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + self.mock_cache.delete.assert_called_once_with("curated:v1:lock:test:surface") + + @pytest.mark.asyncio + async def test_serialize_error_returns_items_and_releases_lock(self) -> None: + """Return fetched items even when serialize_fn fails (best-effort cache write).""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + + def bad_serialize(items: list) -> list[dict]: + raise TypeError("cannot serialize") + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=bad_serialize, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.mock_cache.delete.assert_called_once_with("curated:v1:lock:test:surface") + + @pytest.mark.asyncio + async def test_deserialize_fn_error_on_fresh_hit_falls_through(self) -> None: + """Fall through to backend when deserialize_fn raises on fresh cached data.""" + items_data = [{"v": "item1"}] + self.mock_cache.get.return_value = _make_fresh_envelope(items_data) + 
self.mock_cache.set_nx.return_value = True + + def bad_deserialize(data: list[dict]) -> list: + raise ValueError("validation error") + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=bad_deserialize, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_stale_hit_lock_loser_deserialize_error_falls_through(self) -> None: + """Fall through to backend when stale data can't be deserialized and lock is held.""" + items_data = [{"v": "stale"}] + self.mock_cache.get.return_value = _make_stale_envelope(items_data) + # First set_nx returns False (lock loser), second returns True (lock released) + self.mock_cache.set_nx.side_effect = [False, True] + + def bad_deserialize(data: list[dict]) -> list: + raise ValueError("schema changed") + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=bad_deserialize, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_non_numeric_expires_at_treated_as_miss(self) -> None: + """Treat an envelope with non-numeric expires_at as a cache miss.""" + self.mock_cache.get.return_value = orjson.dumps( + {"expires_at": "not-a-number", "data": [{"v": "corrupted"}]} + ) + self.mock_cache.set_nx.return_value = True + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_cancelled_error_releases_lock(self) -> None: + """Release the lock even when the task is cancelled (BaseException).""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + + async def 
cancelled_fetch() -> list: + raise asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=cancelled_fetch, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + self.mock_cache.delete.assert_called_once_with("curated:v1:lock:test:surface") + + @pytest.mark.asyncio + async def test_malformed_envelope_missing_key_falls_through(self) -> None: + """Treat a JSON blob missing required keys as a cache miss.""" + self.mock_cache.get.return_value = orjson.dumps({"wrong_key": 123}) + self.mock_cache.set_nx.return_value = True + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == ["item1", "item2"] + self.fetch_fn.assert_called_once() + + @pytest.mark.asyncio + async def test_empty_list_is_cached(self) -> None: + """Cache and return an empty list from the backend.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + self.fetch_fn.return_value = [] + + result = await self.redis_cache.get_or_fetch( + "test", + "surface", + fetch_fn=self.fetch_fn, + serialize_fn=self.serialize_fn, + deserialize_fn=self.deserialize_fn, + ) + + assert result == [] + self.mock_cache.set.assert_called_once() + + +class TestRedisCachedScheduledSurface: + """Tests for the RedisCachedScheduledSurface wrapper.""" + + def setup_method(self) -> None: + """Set up mock backend and cache.""" + self.mock_backend = AsyncMock() + self.mock_cache = AsyncMock() + self.wrapper = RedisCachedScheduledSurface(self.mock_backend, self.mock_cache, CONFIG) + + @pytest.mark.asyncio + async def test_fresh_hit_returns_deserialized_items(self) -> None: + """Return CorpusItem list from a fresh Redis cache hit.""" + item = generate_corpus_item() + items_data = [item.model_dump(mode="json")] + self.mock_cache.get.return_value = 
_make_fresh_envelope(items_data) + + result = await self.wrapper.fetch(SURFACE_ID, days_offset=0) + + assert len(result) == 1 + assert result[0].corpusItemId == "id" + self.mock_backend.fetch.assert_not_called() + + @pytest.mark.asyncio + async def test_cache_miss_delegates_to_backend(self) -> None: + """Delegate to the wrapped backend on cache miss.""" + item = generate_corpus_item() + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + self.mock_backend.fetch.return_value = [item] + + result = await self.wrapper.fetch(SURFACE_ID) + + assert len(result) == 1 + assert result[0].corpusItemId == "id" + self.mock_backend.fetch.assert_called_once_with(SURFACE_ID, 0) + + @pytest.mark.asyncio + async def test_days_offset_included_in_key(self) -> None: + """Include days_offset in the Redis key to differentiate cache entries.""" + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + self.mock_backend.fetch.return_value = [generate_corpus_item()] + + await self.wrapper.fetch(SURFACE_ID, days_offset=-1) + + # Verify the key includes the days_offset + set_call = self.mock_cache.set.call_args + key = set_call[0][0] + assert ":-1" in key + + +class TestRedisCachedSections: + """Tests for the RedisCachedSections wrapper.""" + + def setup_method(self) -> None: + """Set up mock backend and cache.""" + self.mock_backend = AsyncMock() + self.mock_cache = AsyncMock() + self.wrapper = RedisCachedSections(self.mock_backend, self.mock_cache, CONFIG) + + @pytest.mark.asyncio + async def test_fresh_hit_returns_deserialized_sections(self) -> None: + """Return CorpusSection list from a fresh Redis cache hit.""" + section = _make_corpus_section() + sections_data = [section.model_dump(mode="json")] + self.mock_cache.get.return_value = _make_fresh_envelope(sections_data) + + result = await self.wrapper.fetch(SURFACE_ID) + + assert len(result) == 1 + assert result[0].title == "Test Section" + 
self.mock_backend.fetch.assert_not_called() + + @pytest.mark.asyncio + async def test_cache_miss_delegates_to_backend(self) -> None: + """Delegate to the wrapped backend on cache miss.""" + section = _make_corpus_section() + self.mock_cache.get.return_value = None + self.mock_cache.set_nx.return_value = True + self.mock_backend.fetch.return_value = [section] + + result = await self.wrapper.fetch(SURFACE_ID) + + assert len(result) == 1 + assert result[0].title == "Test Section" + self.mock_backend.fetch.assert_called_once_with(SURFACE_ID) diff --git a/tests/unit/utils/test_cache_redis.py b/tests/unit/utils/test_cache_redis.py index f1a99ca52..9282f5a26 100644 --- a/tests/unit/utils/test_cache_redis.py +++ b/tests/unit/utils/test_cache_redis.py @@ -2,14 +2,16 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. -"""Unit tests for the cron.py module.""" +"""Unit tests for the Redis cache adapter.""" import pytest -from pytest_mock import MockerFixture +from unittest.mock import AsyncMock -from redis.asyncio import Redis +from redis.asyncio import Redis, RedisError +from pytest_mock import MockerFixture from merino.cache.redis import create_redis_clients, RedisAdapter +from merino.exceptions import CacheAdapterError @pytest.mark.asyncio @@ -32,3 +34,65 @@ async def test_adapter_in_standalone_mode(mocker: MockerFixture) -> None: await adapter.close() spy.assert_called_once() + + +class TestSetNx: + """Tests for RedisAdapter.set_nx.""" + + @pytest.mark.asyncio + async def test_returns_true_when_key_set(self) -> None: + """Return True when the key was newly created.""" + mock_primary = AsyncMock() + mock_primary.set.return_value = True + adapter = RedisAdapter(mock_primary) + + result = await adapter.set_nx("lock:key", 30) + + assert result is True + mock_primary.set.assert_called_once_with("lock:key", b"1", nx=True, ex=30) + + @pytest.mark.asyncio + async def 
class TestSetNx:
    """RedisAdapter.set_nx: atomic SET NX EX used as a distributed lock."""

    @pytest.mark.asyncio
    async def test_returns_true_when_key_set(self) -> None:
        """A newly created key yields True."""
        primary = AsyncMock()
        primary.set.return_value = True

        assert await RedisAdapter(primary).set_nx("lock:key", 30) is True
        primary.set.assert_called_once_with("lock:key", b"1", nx=True, ex=30)

    @pytest.mark.asyncio
    async def test_returns_false_when_key_exists(self) -> None:
        """An already-existing key (Redis replies None) yields False."""
        primary = AsyncMock()
        primary.set.return_value = None

        assert await RedisAdapter(primary).set_nx("lock:key", 30) is False
        primary.set.assert_called_once_with("lock:key", b"1", nx=True, ex=30)

    @pytest.mark.asyncio
    async def test_raises_cache_adapter_error_on_redis_error(self) -> None:
        """A RedisError from the client surfaces as CacheAdapterError."""
        primary = AsyncMock()
        primary.set.side_effect = RedisError("connection lost")

        with pytest.raises(CacheAdapterError, match="SETNX"):
            await RedisAdapter(primary).set_nx("lock:key", 30)


class TestDelete:
    """RedisAdapter.delete."""

    @pytest.mark.asyncio
    async def test_deletes_key(self) -> None:
        """The key is deleted on the primary client."""
        primary = AsyncMock()

        await RedisAdapter(primary).delete("lock:key")

        primary.delete.assert_called_once_with("lock:key")

    @pytest.mark.asyncio
    async def test_raises_cache_adapter_error_on_redis_error(self) -> None:
        """A RedisError from the client surfaces as CacheAdapterError."""
        primary = AsyncMock()
        primary.delete.side_effect = RedisError("connection lost")

        with pytest.raises(CacheAdapterError, match="DELETE"):
            await RedisAdapter(primary).delete("lock:key")