"""Add usage_requests table for LiteLLM traffic ingestion.

Revision ID: 359fff0c443a
Revises: b356c861829f
Create Date: 2025-01-21 21:00:00.000000+00:00

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "359fff0c443a"
down_revision: Union[str, None] = "b356c861829f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# (index name, column list) pairs, in creation order.  Shared by upgrade()
# and downgrade() — downgrade drops them in reverse, so the two functions
# cannot drift apart when an index is added or renamed.
_INDEXES = [
    # Core single-column lookup indexes
    ("ix_usage_requests_started_at", ["started_at"]),
    ("ix_usage_requests_key_alias", ["key_alias"]),
    ("ix_usage_requests_litellm_key_id", ["litellm_key_id"]),
    ("ix_usage_requests_proxy_key_id", ["proxy_key_id"]),
    ("ix_usage_requests_benchmark_session_id", ["benchmark_session_id"]),
    ("ix_usage_requests_requested_model", ["requested_model"]),
    ("ix_usage_requests_resolved_model", ["resolved_model"]),
    ("ix_usage_requests_provider", ["provider"]),
    ("ix_usage_requests_status", ["status"]),
    ("ix_usage_requests_error_code", ["error_code"]),
    # Composite: time-window + key attribution queries
    ("ix_usage_requests_key_alias_started_at", ["key_alias", "started_at"]),
    # Composite: provider + model queries
    ("ix_usage_requests_provider_resolved_model", ["provider", "resolved_model"]),
    # Composite: session + time queries
    ("ix_usage_requests_session_started_at", ["benchmark_session_id", "started_at"]),
]


def upgrade() -> None:
    """Create usage_requests table with indexes and idempotency constraint."""
    op.create_table(
        "usage_requests",
        # Identity / idempotency
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("litellm_call_id", sa.String(255), nullable=False),
        sa.Column("request_id", sa.String(255), nullable=True),
        # Key attribution (denormalized alias plus optional FK linkage)
        sa.Column("key_alias", sa.String(255), nullable=True),
        sa.Column("litellm_key_id", sa.String(255), nullable=True),
        sa.Column("proxy_key_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("benchmark_session_id", postgresql.UUID(as_uuid=True), nullable=True),
        # Routing
        sa.Column("provider", sa.String(255), nullable=True),
        sa.Column("provider_route", sa.String(255), nullable=True),
        sa.Column("requested_model", sa.String(255), nullable=True),
        sa.Column("resolved_model", sa.String(255), nullable=True),
        sa.Column("route", sa.String(255), nullable=True),
        # Timing
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("latency_ms", sa.Float(), nullable=True),
        sa.Column("ttft_ms", sa.Float(), nullable=True),
        # Tokens / cost
        sa.Column("input_tokens", sa.Integer(), nullable=True),
        sa.Column("output_tokens", sa.Integer(), nullable=True),
        sa.Column("cached_input_tokens", sa.Integer(), nullable=True),
        sa.Column("cache_write_tokens", sa.Integer(), nullable=True),
        sa.Column("cost_usd", sa.Float(), nullable=True),
        # Outcome
        sa.Column("status", sa.String(50), nullable=True),
        sa.Column("error_code", sa.String(50), nullable=True),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column("cache_hit", sa.Boolean(), nullable=True),
        sa.Column("request_metadata", sa.JSON(), nullable=False, server_default="{}"),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("NOW()"),
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("litellm_call_id", name="uq_usage_requests_litellm_call_id"),
        sa.ForeignKeyConstraint(
            ["proxy_key_id"],
            ["proxy_keys.id"],
            ondelete="SET NULL",
        ),
        sa.ForeignKeyConstraint(
            ["benchmark_session_id"],
            ["sessions.id"],
            ondelete="SET NULL",
        ),
    )

    for index_name, columns in _INDEXES:
        op.create_index(index_name, "usage_requests", columns)


def downgrade() -> None:
    """Drop usage_requests table and indexes."""
    for index_name, _ in reversed(_INDEXES):
        op.drop_index(index_name, table_name="usage_requests")
    op.drop_table("usage_requests")
# --- ORM model (src/benchmark_core/db/models.py) ---------------------------


class UsageRequest(Base):
    """One normalized LiteLLM usage record for traffic ingestion.

    Stores request timing, routing, token counts, cost, and error metadata.
    No prompt or response content is stored by default.
    Designed for sessionless usage tracking with optional benchmark linkage.
    """

    __tablename__ = "usage_requests"

    # Identity: surrogate PK plus the LiteLLM call ID enforcing idempotency.
    id: Mapped[uuid.UUID] = mapped_column(
        Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    litellm_call_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
    request_id: Mapped[str | None] = mapped_column(String(255), nullable=True)

    # Key attribution: denormalized alias + optional FK links that survive
    # deletion of the referenced row (ON DELETE SET NULL).
    key_alias: Mapped[str | None] = mapped_column(String(255), index=True, nullable=True)
    litellm_key_id: Mapped[str | None] = mapped_column(String(255), index=True, nullable=True)
    proxy_key_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("proxy_keys.id", ondelete="SET NULL"), index=True, nullable=True
    )
    benchmark_session_id: Mapped[uuid.UUID | None] = mapped_column(
        ForeignKey("sessions.id", ondelete="SET NULL"), index=True, nullable=True
    )

    # Routing
    provider: Mapped[str | None] = mapped_column(String(255), index=True, nullable=True)
    provider_route: Mapped[str | None] = mapped_column(String(255), nullable=True)
    requested_model: Mapped[str | None] = mapped_column(String(255), index=True, nullable=True)
    resolved_model: Mapped[str | None] = mapped_column(String(255), index=True, nullable=True)
    route: Mapped[str | None] = mapped_column(String(255), nullable=True)

    # Timing
    started_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), index=True, nullable=True
    )
    finished_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    latency_ms: Mapped[float | None] = mapped_column(Float, nullable=True)
    ttft_ms: Mapped[float | None] = mapped_column(Float, nullable=True)

    # Tokens / cost
    input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    output_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cached_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cache_write_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cost_usd: Mapped[float | None] = mapped_column(Float, nullable=True)

    # Outcome
    status: Mapped[str | None] = mapped_column(String(50), index=True, nullable=True)
    error_code: Mapped[str | None] = mapped_column(String(50), index=True, nullable=True)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    cache_hit: Mapped[bool | None] = mapped_column(Boolean, nullable=True)

    # Safe metadata only — no prompt/response content columns by design.
    request_metadata: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(UTC)
    )


# --- Pydantic model (src/benchmark_core/models.py) -------------------------


class UsageRequest(BaseModel):
    """One normalized LiteLLM usage record for traffic ingestion.

    No prompt or response content is stored by default.
    Designed for sessionless usage tracking with optional benchmark linkage.
    """

    # Identity / idempotency
    usage_request_id: UUID = Field(default_factory=uuid4)
    litellm_call_id: str = Field(..., description="Primary LiteLLM call/request ID for idempotency")
    request_id: str | None = Field(default=None, description="Alternate request ID")

    # Key attribution
    key_alias: str | None = Field(default=None, description="Denormalized key alias")
    litellm_key_id: str | None = Field(default=None, description="LiteLLM internal key ID")
    proxy_key_id: UUID | None = Field(default=None, description="Optional FK to proxy_keys")
    benchmark_session_id: UUID | None = Field(
        default=None, description="Optional FK to benchmark session"
    )

    # Routing
    provider: str | None = Field(default=None, description="Provider slug")
    provider_route: str | None = Field(default=None, description="Full provider route string")
    requested_model: str | None = Field(default=None, description="Client-requested model alias")
    resolved_model: str | None = Field(default=None, description="Resolved upstream model name")
    route: str | None = Field(default=None, description="API route/endpoint")

    # Timing
    started_at: datetime | None = Field(default=None, description="Request start time (UTC)")
    finished_at: datetime | None = Field(default=None, description="Request end time (UTC)")
    latency_ms: float | None = Field(default=None, description="Total latency in milliseconds")
    ttft_ms: float | None = Field(default=None, description="Time to first token in milliseconds")

    # Tokens / cost
    input_tokens: int | None = Field(default=None, description="Input/prompt token count")
    output_tokens: int | None = Field(default=None, description="Output/completion token count")
    cached_input_tokens: int | None = Field(default=None, description="Cached input token count")
    cache_write_tokens: int | None = Field(default=None, description="Tokens written to cache")
    cost_usd: float | None = Field(default=None, description="Spend in USD")

    # Outcome
    status: str | None = Field(default=None, description="Request status (success/failure/pending)")
    error_code: str | None = Field(default=None, description="Error code (e.g. HTTP 429)")
    error_message: str | None = Field(default=None, description="Error message")
    cache_hit: bool | None = Field(default=None, description="Cache hit flag")

    # Safe metadata only — no prompt/response content is carried here.
    request_metadata: dict[str, Any] = Field(
        default_factory=dict, description="Safe metadata (no content)"
    )
    created_at: datetime = Field(default_factory=_utc_now)
"""Repository for UsageRequest entities."""

from abc import ABC, abstractmethod
from datetime import UTC, datetime
from typing import cast
from uuid import UUID, uuid4

from sqlalchemy import delete, func, inspect, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session as SQLAlchemySession

from benchmark_core.db.models import UsageRequest as UsageRequestORM
from benchmark_core.models import UsageRequest
from benchmark_core.repositories.base import SQLAlchemyRepository


class SQLUsageRequestRepository(SQLAlchemyRepository[UsageRequestORM]):
    """SQLAlchemy repository for UsageRequest entities.

    Provides idempotent bulk creation and lookup helpers for usage records.
    No prompt or response content is stored by default.
    """

    def __init__(self, db_session: SQLAlchemySession) -> None:
        """Initialize the repository.

        Args:
            db_session: SQLAlchemy session for database operations.
        """
        super().__init__(db_session, UsageRequestORM)

    async def get_by_litellm_call_id(self, litellm_call_id: str) -> UsageRequestORM | None:
        """Retrieve a usage request by its LiteLLM call ID.

        Args:
            litellm_call_id: The LiteLLM call ID to search for.

        Returns:
            The usage request if found, None otherwise.
        """
        stmt = select(UsageRequestORM).where(UsageRequestORM.litellm_call_id == litellm_call_id)
        return self._session.execute(stmt).scalars().one_or_none()

    async def get_by_request_id(self, request_id: str) -> UsageRequestORM | None:
        """Retrieve a usage request by its alternate request ID.

        Args:
            request_id: The alternate request ID to search for.

        Returns:
            The usage request if found, None otherwise.
        """
        stmt = select(UsageRequestORM).where(UsageRequestORM.request_id == request_id)
        return self._session.execute(stmt).scalars().one_or_none()

    async def create_many(
        self, requests: list[UsageRequestORM]
    ) -> tuple[list[UsageRequestORM], int]:
        """Create multiple usage request records (idempotent).

        If a usage request with the same litellm_call_id already exists,
        it is skipped via ON CONFLICT DO NOTHING. This method is designed
        for bulk ingestion from collectors.

        Args:
            requests: List of usage request entities to create.

        Returns:
            Tuple of (created requests, skipped count). Duplicates are omitted.
        """
        if not requests:
            return [], 0

        # Use PostgreSQL INSERT ... ON CONFLICT DO NOTHING for idempotency.
        # SQLite does not support this natively, so fall back to check-and-insert.
        dialect_name = self._session.bind.dialect.name if self._session.bind else ""

        if dialect_name == "postgresql":
            return await self._create_many_postgres(requests)
        else:
            return await self._create_many_generic(requests)

    async def _create_many_postgres(
        self, requests: list[UsageRequestORM]
    ) -> tuple[list[UsageRequestORM], int]:
        """Bulk insert with ON CONFLICT DO NOTHING (PostgreSQL only).

        Uses SQLAlchemy column inspection to build the value dict dynamically,
        eliminating fragility from a hard-coded 27-column mapping.
        Returns only the successfully inserted records to match the
        generic check-and-insert semantics.
        """
        call_ids = [r.litellm_call_id for r in requests]

        # Snapshot pre-existing IDs so we can exclude them from the
        # post-insert query-back result (ON CONFLICT DO NOTHING may skip
        # some rows and we must only return newly inserted ones).
        pre_existing_stmt = select(UsageRequestORM.litellm_call_id).where(
            UsageRequestORM.litellm_call_id.in_(call_ids)
        )
        pre_existing = set(self._session.execute(pre_existing_stmt).scalars().all())

        # BUG FIX: a Core-level INSERT applies client-side column defaults
        # (id=uuid4, created_at=now, request_metadata=dict) only when the
        # parameter is *absent*. These ORM objects are unflushed, so those
        # attributes are still None; passing them through getattr() would
        # send explicit NULLs and violate the NOT NULL constraints on
        # PostgreSQL. Materialize the defaults before building the dicts.
        now = datetime.now(UTC)
        for r in requests:
            if r.id is None:
                r.id = uuid4()
            if r.created_at is None:
                r.created_at = now
            if r.request_metadata is None:
                r.request_metadata = {}

        mapper = inspect(UsageRequestORM)
        insert_stmt = pg_insert(UsageRequestORM).values(
            [{col.name: getattr(r, col.name) for col in mapper.columns} for r in requests]
        )
        insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["litellm_call_id"])

        result = self._session.execute(insert_stmt)
        skipped = len(requests) - result.rowcount  # type: ignore[attr-defined]

        # Query back inserted rows, filtering out any pre-existing ones.
        created: list[UsageRequestORM] = []
        if result.rowcount:  # type: ignore[attr-defined]
            stmt = select(UsageRequestORM).where(UsageRequestORM.litellm_call_id.in_(call_ids))
            created = [
                r
                for r in self._session.execute(stmt).scalars().all()
                if r.litellm_call_id not in pre_existing
            ]

        return created, skipped

    async def _create_many_generic(
        self, requests: list[UsageRequestORM]
    ) -> tuple[list[UsageRequestORM], int]:
        """Fallback check-and-insert for SQLite and other dialects.

        Uses a per-item savepoint (begin_nested) to catch IntegrityError
        without losing the outer transaction state, avoiding TOCTOU races.
        """
        created: list[UsageRequestORM] = []
        skipped = 0

        for request in requests:
            nested = self._session.begin_nested()
            try:
                self._session.add(request)
                self._session.flush()
                created.append(request)
                nested.commit()
            except IntegrityError:
                nested.rollback()
                skipped += 1

        return created, skipped

    async def list_by_key_alias(
        self,
        key_alias: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests by key alias, ordered by start time desc.

        Args:
            key_alias: The key alias to filter by.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.key_alias == key_alias)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_litellm_key_id(
        self,
        litellm_key_id: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests by LiteLLM key ID.

        Args:
            litellm_key_id: The LiteLLM key ID to filter by.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.litellm_key_id == litellm_key_id)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_model(
        self,
        model: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests by resolved model.

        Args:
            model: The resolved model to filter by.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.resolved_model == model)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_provider(
        self,
        provider: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests by provider.

        Args:
            provider: The provider slug to filter by.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.provider == provider)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_benchmark_session(
        self,
        benchmark_session_id: UUID,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests linked to a benchmark session.

        Args:
            benchmark_session_id: The benchmark session UUID.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.benchmark_session_id == benchmark_session_id)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_time_range(
        self,
        start: datetime,
        end: datetime,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests within a time range.

        Args:
            start: Start of time range (datetime).
            end: End of time range (datetime).
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of usage requests in the time range.
        """
        stmt = (
            select(UsageRequestORM)
            .where(
                UsageRequestORM.started_at >= start,
                UsageRequestORM.started_at <= end,
            )
            .order_by(UsageRequestORM.started_at.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_status(
        self,
        status: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests by status.

        Args:
            status: The status to filter by.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.status == status)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def list_by_error_code(
        self,
        error_code: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequestORM]:
        """List usage requests by error code.

        Args:
            error_code: The error code to filter by.
            limit: Maximum number of results.
            offset: Number of results to skip.

        Returns:
            List of matching usage requests.
        """
        stmt = (
            select(UsageRequestORM)
            .where(UsageRequestORM.error_code == error_code)
            .order_by(UsageRequestORM.started_at.desc().nulls_last())
            .limit(limit)
            .offset(offset)
        )
        return list(self._session.execute(stmt).scalars().all())

    async def count_by_key_alias(self, key_alias: str) -> int:
        """Count total usage requests for a key alias.

        Args:
            key_alias: The key alias.

        Returns:
            Number of matching usage requests.
        """
        stmt = (
            select(func.count())
            .select_from(UsageRequestORM)
            .where(UsageRequestORM.key_alias == key_alias)
        )
        return self._session.execute(stmt).scalar() or 0

    async def count_by_model(self, model: str) -> int:
        """Count total usage requests for a model.

        Args:
            model: The resolved model.

        Returns:
            Number of matching usage requests.
        """
        stmt = (
            select(func.count())
            .select_from(UsageRequestORM)
            .where(UsageRequestORM.resolved_model == model)
        )
        return self._session.execute(stmt).scalar() or 0

    async def delete_by_benchmark_session(self, benchmark_session_id: UUID) -> int:
        """Delete all usage requests linked to a benchmark session.

        Args:
            benchmark_session_id: The benchmark session UUID.

        Returns:
            Number of usage requests deleted.
        """
        stmt = delete(UsageRequestORM).where(
            UsageRequestORM.benchmark_session_id == benchmark_session_id
        )
        result = cast(CursorResult, self._session.execute(stmt))
        self._session.flush()
        return result.rowcount


class UsageRequestRepository(ABC):
    """Abstract repository for usage request persistence."""

    @abstractmethod
    async def get_by_litellm_call_id(self, litellm_call_id: str) -> UsageRequest | None:
        """Retrieve a usage request by its LiteLLM call ID."""
        ...

    @abstractmethod
    async def get_by_request_id(self, request_id: str) -> UsageRequest | None:
        """Retrieve a usage request by its alternate request ID."""
        ...

    @abstractmethod
    async def create_many(self, requests: list[UsageRequest]) -> tuple[list[UsageRequest], int]:
        """Create multiple usage request records (idempotent)."""
        ...

    @abstractmethod
    async def list_by_key_alias(
        self,
        key_alias: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests by key alias."""
        ...

    @abstractmethod
    async def list_by_litellm_key_id(
        self,
        litellm_key_id: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests by LiteLLM key ID."""
        ...

    @abstractmethod
    async def list_by_model(
        self,
        model: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests by resolved model."""
        ...

    @abstractmethod
    async def list_by_provider(
        self,
        provider: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests by provider."""
        ...

    @abstractmethod
    async def list_by_benchmark_session(
        self,
        benchmark_session_id: UUID,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests linked to a benchmark session."""
        ...

    @abstractmethod
    async def list_by_time_range(
        self,
        start: datetime,
        end: datetime,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests within a time range."""
        ...

    @abstractmethod
    async def list_by_status(
        self,
        status: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests by status."""
        ...

    @abstractmethod
    async def list_by_error_code(
        self,
        error_code: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[UsageRequest]:
        """List usage requests by error code."""
        ...

    @abstractmethod
    async def count_by_key_alias(self, key_alias: str) -> int:
        """Count total usage requests for a key alias."""
        ...

    @abstractmethod
    async def count_by_model(self, model: str) -> int:
        """Count total usage requests for a model."""
        ...

    @abstractmethod
    async def delete(self, id: UUID) -> bool:
        """Delete a usage request by its ID."""
        ...

    @abstractmethod
    async def delete_by_benchmark_session(self, benchmark_session_id: UUID) -> int:
        """Delete all usage requests linked to a benchmark session."""
        ...
diff --git a/tests/unit/test_repositories.py b/tests/unit/test_repositories.py index 4b0ee8b..7e7d1f2 100644 --- a/tests/unit/test_repositories.py +++ b/tests/unit/test_repositories.py @@ -29,6 +29,7 @@ from benchmark_core.repositories.request_repository import SQLRequestRepository from benchmark_core.repositories.session_repository import SQLSessionRepository from benchmark_core.repositories.task_card_repository import SQLTaskCardRepository +from benchmark_core.repositories.usage_request_repository import SQLUsageRequestRepository from benchmark_core.repositories.variant_repository import SQLVariantRepository @@ -1262,3 +1263,586 @@ async def test_delete_nonexistent_artifact(self, db_session, artifact_repo): fake_id = uuid4() deleted = await artifact_repo.delete(fake_id) assert deleted is False + + +@pytest.fixture +def usage_request_repo(db_session): + """Create a usage request repository.""" + return SQLUsageRequestRepository(db_session) + + +class TestUsageRequestRepository: + """Tests for SQLUsageRequestRepository.""" + + @pytest.fixture + async def setup_experiment_variant_task(self, db_session): + """Create prerequisite entities for usage request tests.""" + experiment = Experiment(name="usage-test-exp") + db_session.add(experiment) + db_session.flush() + + variant = Variant( + name="usage-test-variant", + provider="test-provider", + model_alias="gpt-4o", + harness_profile="default", + ) + db_session.add(variant) + db_session.flush() + + task_card = TaskCard( + name="usage-test-task", + goal="Test usage request", + starting_prompt="Start", + stop_condition="Stop", + ) + db_session.add(task_card) + db_session.flush() + + db_session.commit() + return experiment, variant, task_card + + @pytest.fixture + async def setup_benchmark_session(self, db_session, setup_experiment_variant_task): + """Create a benchmark session for linkage tests.""" + experiment, variant, task_card = setup_experiment_variant_task + + session = SessionORM( + experiment_id=experiment.id, + 
variant_id=variant.id, + task_card_id=task_card.id, + harness_profile="default", + repo_path="/tmp/test", + git_branch="main", + git_commit="abc1234", + status="active", + ) + db_session.add(session) + db_session.commit() + return session + + @pytest.fixture + async def setup_proxy_key(self, db_session): + """Create a proxy key for linkage tests.""" + from benchmark_core.db.models import ProxyKey as ProxyKeyORM + + proxy_key = ProxyKeyORM( + key_alias="usage-test-key", + litellm_key_id="litellm-test-123", + status="active", + ) + db_session.add(proxy_key) + db_session.commit() + return proxy_key + + async def test_create_without_session(self, usage_request_repo, db_session): + """Usage rows can be persisted without a benchmark session.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + usage = UsageRequestORM( + litellm_call_id="call-no-session-001", + key_alias="standalone-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + input_tokens=10, + output_tokens=20, + latency_ms=500.0, + ) + created = await usage_request_repo.create(usage) + db_session.commit() + + assert created.id is not None + assert created.litellm_call_id == "call-no-session-001" + assert created.benchmark_session_id is None + assert created.key_alias == "standalone-key" + + async def test_create_with_session( + self, usage_request_repo, db_session, setup_benchmark_session + ): + """Usage rows can optionally link to a benchmark session.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + session = setup_benchmark_session + usage = UsageRequestORM( + litellm_call_id="call-with-session-001", + key_alias="session-key", + benchmark_session_id=session.id, + provider="openai", + resolved_model="gpt-4o", + status="success", + input_tokens=10, + output_tokens=20, + ) + created = await usage_request_repo.create(usage) + db_session.commit() + + assert created.benchmark_session_id == session.id + + async def 
test_duplicate_litellm_call_id(self, usage_request_repo, db_session): + """Duplicate LiteLLM request IDs do not create duplicate usage rows.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + usage1 = UsageRequestORM( + litellm_call_id="duplicate-call-001", + key_alias="key-1", + provider="openai", + resolved_model="gpt-4o", + status="success", + input_tokens=10, + output_tokens=20, + ) + await usage_request_repo.create(usage1) + db_session.commit() + + usage2 = UsageRequestORM( + litellm_call_id="duplicate-call-001", + key_alias="key-2", + provider="anthropic", + resolved_model="claude-3", + status="failure", + input_tokens=5, + output_tokens=0, + ) + from sqlalchemy.exc import IntegrityError + + with pytest.raises(IntegrityError): + await usage_request_repo.create(usage2) + db_session.commit() + + # Rollback the failed transaction so the session is clean + db_session.rollback() + + # Verify only the first row exists + found = await usage_request_repo.get_by_litellm_call_id("duplicate-call-001") + assert found is not None + assert found.provider == "openai" + assert found.resolved_model == "gpt-4o" + + async def test_no_content_fields_stored(self, usage_request_repo, db_session): + """No prompt/response content fields are stored by default.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + usage = UsageRequestORM( + litellm_call_id="content-check-001", + key_alias="key-1", + provider="openai", + resolved_model="gpt-4o", + status="success", + request_metadata={"stream": False}, + ) + created = await usage_request_repo.create(usage) + db_session.commit() + + # Verify the model does not have prompt/response columns + from benchmark_core.db.models import UsageRequest + + columns = {col.name for col in UsageRequest.__table__.columns} + assert "prompt" not in columns + assert "response" not in columns + assert "prompt_text" not in columns + assert "response_text" not in columns + assert "messages" not in columns + 
assert created.request_metadata == {"stream": False} + + async def test_create_many_idempotent(self, usage_request_repo, db_session): + """Idempotent create_many skips duplicates.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + req1 = UsageRequestORM( + litellm_call_id="batch-001", + key_alias="batch-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + req2 = UsageRequestORM( + litellm_call_id="batch-002", + key_alias="batch-key", + provider="openai", + resolved_model="gpt-4o-mini", + status="success", + ) + req3 = UsageRequestORM( + litellm_call_id="batch-001", # duplicate of req1 + key_alias="batch-key", + provider="anthropic", + resolved_model="claude-3", + status="failure", + ) + + created, skipped = await usage_request_repo.create_many([req1, req2, req3]) + db_session.commit() + + assert len(created) == 2 + assert skipped == 1 + + all_by_alias = await usage_request_repo.list_by_key_alias("batch-key") + assert len(all_by_alias) == 2 + call_ids = {r.litellm_call_id for r in all_by_alias} + assert "batch-001" in call_ids + assert "batch-002" in call_ids + + async def test_create_many_preserves_first_record(self, usage_request_repo, db_session): + """End-to-end: inserting a duplicate must not overwrite or return + the pre-existing row, and the caller must receive only the newly + inserted items in the result tuple.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + original = UsageRequestORM( + litellm_call_id="dup-001", + key_alias="dup-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(original) + db_session.commit() + original_id = original.id + + duplicate = UsageRequestORM( + litellm_call_id="dup-001", + key_alias="dup-key", + provider="anthropic", + resolved_model="claude-3", + status="failure", + ) + new = UsageRequestORM( + litellm_call_id="dup-002", + key_alias="dup-key", + provider="openai", + 
resolved_model="gpt-4o", + status="success", + ) + + created, skipped = await usage_request_repo.create_many([duplicate, new]) + db_session.commit() + + assert skipped == 1 + assert len(created) == 1 + assert created[0].litellm_call_id == "dup-002" + + # Original row must be untouched + row = await usage_request_repo.get_by_litellm_call_id("dup-001") + assert row is not None + assert row.id == original_id + assert row.provider == "openai" + assert row.resolved_model == "gpt-4o" + assert row.status == "success" + + async def test_get_by_litellm_call_id(self, usage_request_repo, db_session): + """Retrieve usage request by LiteLLM call ID.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + usage = UsageRequestORM( + litellm_call_id="find-me-001", + key_alias="find-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + found = await usage_request_repo.get_by_litellm_call_id("find-me-001") + assert found is not None + assert found.litellm_call_id == "find-me-001" + + not_found = await usage_request_repo.get_by_litellm_call_id("does-not-exist") + assert not_found is None + + async def test_list_by_key_alias(self, usage_request_repo, db_session): + """List usage requests by key alias.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + for i in range(3): + usage = UsageRequestORM( + litellm_call_id=f"alias-list-{i}", + key_alias="shared-alias", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + results = await usage_request_repo.list_by_key_alias("shared-alias") + assert len(results) == 3 + + async def test_list_by_model(self, usage_request_repo, db_session): + """List usage requests by resolved model.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + for i in range(2): + usage = UsageRequestORM( + 
litellm_call_id=f"model-gpt-{i}", + key_alias="model-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + + usage_other = UsageRequestORM( + litellm_call_id="model-claude-001", + key_alias="model-key", + provider="anthropic", + resolved_model="claude-3", + status="success", + ) + await usage_request_repo.create(usage_other) + db_session.commit() + + results = await usage_request_repo.list_by_model("gpt-4o") + assert len(results) == 2 + + async def test_list_by_provider(self, usage_request_repo, db_session): + """List usage requests by provider.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + for i in range(2): + usage = UsageRequestORM( + litellm_call_id=f"prov-openai-{i}", + key_alias="prov-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + + usage_other = UsageRequestORM( + litellm_call_id="prov-fireworks-001", + key_alias="prov-key", + provider="fireworks", + resolved_model="kimi-k2-5", + status="success", + ) + await usage_request_repo.create(usage_other) + db_session.commit() + + results = await usage_request_repo.list_by_provider("openai") + assert len(results) == 2 + + async def test_list_by_benchmark_session( + self, usage_request_repo, db_session, setup_benchmark_session + ): + """List usage requests linked to a benchmark session.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + session = setup_benchmark_session + for i in range(2): + usage = UsageRequestORM( + litellm_call_id=f"session-link-{i}", + key_alias="session-key", + benchmark_session_id=session.id, + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + results = await usage_request_repo.list_by_benchmark_session(session.id) + assert len(results) == 2 + for r in results: + assert r.benchmark_session_id == session.id + + 
async def test_list_by_time_range(self, usage_request_repo, db_session): + """List usage requests within a time range.""" + from datetime import UTC, datetime + + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + usage1 = UsageRequestORM( + litellm_call_id="time-001", + key_alias="time-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + started_at=datetime(2025, 1, 15, 10, 0, 0, tzinfo=UTC), + ) + usage2 = UsageRequestORM( + litellm_call_id="time-002", + key_alias="time-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + started_at=datetime(2025, 1, 16, 10, 0, 0, tzinfo=UTC), + ) + usage3 = UsageRequestORM( + litellm_call_id="time-003", + key_alias="time-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + started_at=datetime(2025, 1, 17, 10, 0, 0, tzinfo=UTC), + ) + await usage_request_repo.create(usage1) + await usage_request_repo.create(usage2) + await usage_request_repo.create(usage3) + db_session.commit() + + results = await usage_request_repo.list_by_time_range( + datetime(2025, 1, 15, 0, 0, 0, tzinfo=UTC), + datetime(2025, 1, 16, 23, 59, 59, tzinfo=UTC), + ) + assert len(results) == 2 + call_ids = {r.litellm_call_id for r in results} + assert "time-001" in call_ids + assert "time-002" in call_ids + + async def test_list_by_status(self, usage_request_repo, db_session): + """List usage requests by status.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + success = UsageRequestORM( + litellm_call_id="status-success-001", + key_alias="status-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + failure = UsageRequestORM( + litellm_call_id="status-failure-001", + key_alias="status-key", + provider="openai", + resolved_model="gpt-4o", + status="failure", + ) + await usage_request_repo.create(success) + await usage_request_repo.create(failure) + db_session.commit() + + results = await 
usage_request_repo.list_by_status("success") + assert len(results) == 1 + assert results[0].litellm_call_id == "status-success-001" + + async def test_list_by_error_code(self, usage_request_repo, db_session): + """List usage requests by error code.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + rate_limited = UsageRequestORM( + litellm_call_id="error-429-001", + key_alias="error-key", + provider="openai", + resolved_model="gpt-4o", + status="failure", + error_code="429", + ) + server_error = UsageRequestORM( + litellm_call_id="error-500-001", + key_alias="error-key", + provider="openai", + resolved_model="gpt-4o", + status="failure", + error_code="500", + ) + await usage_request_repo.create(rate_limited) + await usage_request_repo.create(server_error) + db_session.commit() + + results = await usage_request_repo.list_by_error_code("429") + assert len(results) == 1 + assert results[0].litellm_call_id == "error-429-001" + + async def test_count_by_key_alias(self, usage_request_repo, db_session): + """Count usage requests by key alias.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + for i in range(5): + usage = UsageRequestORM( + litellm_call_id=f"count-{i}", + key_alias="countable-alias", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + count = await usage_request_repo.count_by_key_alias("countable-alias") + assert count == 5 + + async def test_count_by_model(self, usage_request_repo, db_session): + """Count usage requests by model.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + for i in range(3): + usage = UsageRequestORM( + litellm_call_id=f"model-count-{i}", + key_alias="count-key", + provider="openai", + resolved_model="gpt-4o-mini", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + count = await usage_request_repo.count_by_model("gpt-4o-mini") 
+ assert count == 3 + + async def test_delete_usage_request(self, usage_request_repo, db_session): + """Delete a usage request by ID.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + usage = UsageRequestORM( + litellm_call_id="delete-me-001", + key_alias="delete-key", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + created = await usage_request_repo.create(usage) + db_session.commit() + + deleted = await usage_request_repo.delete(created.id) + db_session.commit() + + assert deleted is True + not_found = await usage_request_repo.get_by_id(created.id) + assert not_found is None + + async def test_delete_by_benchmark_session( + self, usage_request_repo, db_session, setup_benchmark_session + ): + """Delete all usage requests linked to a benchmark session.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + session = setup_benchmark_session + for i in range(3): + usage = UsageRequestORM( + litellm_call_id=f"batch-delete-{i}", + key_alias="batch-delete-key", + benchmark_session_id=session.id, + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + deleted_count = await usage_request_repo.delete_by_benchmark_session(session.id) + db_session.commit() + + assert deleted_count == 3 + results = await usage_request_repo.list_by_benchmark_session(session.id) + assert len(results) == 0 + + async def test_list_by_litellm_key_id(self, usage_request_repo, db_session): + """List usage requests by LiteLLM key ID.""" + from benchmark_core.db.models import UsageRequest as UsageRequestORM + + for i in range(2): + usage = UsageRequestORM( + litellm_call_id=f"key-id-{i}", + key_alias="key-id-alias", + litellm_key_id="litellm-key-abc", + provider="openai", + resolved_model="gpt-4o", + status="success", + ) + await usage_request_repo.create(usage) + db_session.commit() + + results = await 
usage_request_repo.list_by_litellm_key_id("litellm-key-abc") + assert len(results) == 2 diff --git a/tests/validation/test_migrations.py b/tests/validation/test_migrations.py index 99cb1c7..1c7f5be 100644 --- a/tests/validation/test_migrations.py +++ b/tests/validation/test_migrations.py @@ -66,6 +66,7 @@ def test_init_db_creates_all_tables(self, temp_db) -> None: "artifacts", "proxy_credentials", "proxy_keys", + "usage_requests", ] for table in expected_tables: