diff --git a/README.md b/README.md index 13736ce..c13216d 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,8 @@ client = Client(config) # Client construction does not start browser login automatically. # Opens the system browser and completes PKCE on loopback: try: - client.authenticate_oauth_pkce() + if not client.has_cached_oauth_token(): + client.authenticate_oauth_pkce() except OAuthPkceError as e: print(f"Login failed: {e}") ``` diff --git a/dir-sdk-python/agntcy/dir_sdk/__init__.py b/dir-sdk-python/agntcy/dir_sdk/__init__.py index e69de29..9ae9aeb 100644 --- a/dir-sdk-python/agntcy/dir_sdk/__init__.py +++ b/dir-sdk-python/agntcy/dir_sdk/__init__.py @@ -0,0 +1,2 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 diff --git a/dir-sdk-python/agntcy/dir_sdk/client/__init__.py b/dir-sdk-python/agntcy/dir_sdk/client/__init__.py index 59b856c..a6b3d4c 100644 --- a/dir-sdk-python/agntcy/dir_sdk/client/__init__.py +++ b/dir-sdk-python/agntcy/dir_sdk/client/__init__.py @@ -3,4 +3,4 @@ from agntcy.dir_sdk.client.client import Client from agntcy.dir_sdk.client.config import Config -from agntcy.dir_sdk.client.oauth_pkce import OAuthPkceError as OAuthPkceError +from agntcy.dir_sdk.client.auth.oauth_pkce import OAuthPkceError as OAuthPkceError diff --git a/dir-sdk-python/agntcy/dir_sdk/client/auth/__init__.py b/dir-sdk-python/agntcy/dir_sdk/client/auth/__init__.py new file mode 100644 index 0000000..1d2e3c9 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/auth/__init__.py @@ -0,0 +1,10 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Authentication/session helpers for the Directory client.""" + +from agntcy.dir_sdk.client.auth.oauth_pkce import OAuthTokenHolder +from agntcy.dir_sdk.client.auth.session import OAuthSessionManager, cached_token_from_response +from agntcy.dir_sdk.client.auth.token_cache import CachedToken, TokenCache + +__all__ = ["CachedToken", "OAuthSessionManager", "OAuthTokenHolder", "TokenCache", "cached_token_from_response"] diff --git a/dir-sdk-python/agntcy/dir_sdk/client/oauth_pkce.py b/dir-sdk-python/agntcy/dir_sdk/client/auth/oauth_pkce.py similarity index 100% rename from dir-sdk-python/agntcy/dir_sdk/client/oauth_pkce.py rename to dir-sdk-python/agntcy/dir_sdk/client/auth/oauth_pkce.py diff --git a/dir-sdk-python/agntcy/dir_sdk/client/auth/session.py b/dir-sdk-python/agntcy/dir_sdk/client/auth/session.py new file mode 100644 index 0000000..1369d8b --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/auth/session.py @@ -0,0 +1,93 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""OAuth session management and token cache integration.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +from agntcy.dir_sdk.client.config import Config +from agntcy.dir_sdk.client.auth.oauth_pkce import ( + OAuthTokenHolder, + fetch_openid_configuration, + run_loopback_pkce_login, +) +from agntcy.dir_sdk.client.auth.token_cache import CachedToken, TokenCache + + +def cached_token_from_response(config: Config, payload: dict[str, object]) -> CachedToken: + expires_at = None + expires_in = payload.get("expires_in") + if expires_in is not None: + expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in)) + + refresh_token = payload.get("refresh_token") + token_type = payload.get("token_type") + return CachedToken( + access_token=str(payload["access_token"]), + token_type=str(token_type) if isinstance(token_type, str) else "", + provider="oidc", + issuer=config.oidc_issuer, + refresh_token=str(refresh_token) if isinstance(refresh_token, str) else "", + expires_at=expires_at, + created_at=datetime.now(UTC), + ) + + +class OAuthSessionManager: + """Coordinates OIDC token state with interactive PKCE flow and cache.""" + + def __init__( + self, + config: Config, + token_cache: TokenCache | None = None, + ) -> None: + self.config = config + self._token_cache = token_cache or TokenCache() + self._oauth_holder: OAuthTokenHolder | None = None + + if self.config.auth_mode == "oidc": + self._oauth_holder = OAuthTokenHolder() + if self.config.auth_token: + self._oauth_holder.set_tokens(self.config.auth_token) + else: + cached_token = self._token_cache.get_valid_token() + if cached_token is not None: + self._oauth_holder.set_tokens(cached_token.access_token) + + @property + def oauth_holder(self) -> OAuthTokenHolder | None: + return self._oauth_holder + + def has_access_token(self) -> bool: + if self._oauth_holder is None: + return False + try: + self._oauth_holder.get_access_token() + return True + except RuntimeError: + return False + + def authenticate(self) -> None: + if self.config.auth_mode != "oidc": + msg = "authenticate_oauth_pkce() requires auth_mode='oidc'" + raise ValueError(msg) + if not self.config.oidc_issuer: + msg = "oidc_issuer is required for authenticate_oauth_pkce()" + raise ValueError(msg) + if not self.config.oidc_client_id: + msg = "oidc_client_id is required for authenticate_oauth_pkce()" + raise ValueError(msg) + if self._oauth_holder is None: + msg = "OAuth token holder not initialized" + raise RuntimeError(msg) + + meta = fetch_openid_configuration( + self.config.oidc_issuer, + verify=not self.config.tls_skip_verify, + timeout=min(30.0, self.config.oidc_auth_timeout), + ) + payload = run_loopback_pkce_login(self.config, metadata=meta) + self._oauth_holder.update_from_token_response(payload) + self._token_cache.save(cached_token_from_response(self.config, payload)) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/token_cache.py b/dir-sdk-python/agntcy/dir_sdk/client/auth/token_cache.py similarity index 100% rename from dir-sdk-python/agntcy/dir_sdk/client/token_cache.py rename to dir-sdk-python/agntcy/dir_sdk/client/auth/token_cache.py diff --git a/dir-sdk-python/agntcy/dir_sdk/client/client.py b/dir-sdk-python/agntcy/dir_sdk/client/client.py index 1848ad2..9778f64 100644 --- a/dir-sdk-python/agntcy/dir_sdk/client/client.py +++ b/dir-sdk-python/agntcy/dir_sdk/client/client.py @@ -1,212 +1,74 @@ # Copyright AGNTCY Contributors (https://github.com/agntcy) # SPDX-License-Identifier: Apache-2.0 -"""Client module for the AGNTCY Directory service. +"""High-level facade for AGNTCY Directory client operations.""" -This module provides a high-level Python client for interacting with the AGNTCY -Directory services including routing, search, store, and signing operations. -""" +from __future__ import annotations import builtins import logging -import os -import json -import subprocess -import tempfile -from collections.abc import Callable, Sequence -from datetime import UTC, datetime, timedelta +from collections.abc import Sequence import grpc -from google.protobuf import json_format -from cryptography.hazmat.primitives import serialization -from spiffe import WorkloadApiClient, X509Source +from agntcy.dir_sdk.client.auth.session import OAuthSessionManager from agntcy.dir_sdk.client.config import Config -from agntcy.dir_sdk.client.oauth_pkce import ( +from agntcy.dir_sdk.client.auth.oauth_pkce import ( OAuthTokenHolder, fetch_openid_configuration, run_loopback_pkce_login, ) -from agntcy.dir_sdk.client.token_cache import CachedToken, TokenCache +from agntcy.dir_sdk.client.services.events import EventService +from agntcy.dir_sdk.client.services.naming import NamingService +from agntcy.dir_sdk.client.services.publication import PublicationService +from agntcy.dir_sdk.client.services.routing import RoutingService +from agntcy.dir_sdk.client.services.search import SearchService +from agntcy.dir_sdk.client.services.signing import SignService +from agntcy.dir_sdk.client.services.store import StoreService +from agntcy.dir_sdk.client.services.sync import SyncService +from agntcy.dir_sdk.client.auth.token_cache import CachedToken, TokenCache +from agntcy.dir_sdk.client.transport.channels import create_grpc_channel +from agntcy.dir_sdk.client.transport.interceptors import ( + BearerAuthInterceptor, + JWTAuthInterceptor, +) from agntcy.dir_sdk.models import ( core_v1, events_v1, - routing_v1, - store_v1, naming_v1, - sign_v1, + routing_v1, search_v1, + sign_v1, + store_v1, ) logger = logging.getLogger("client") - -class JWTAuthInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): - """gRPC interceptor that adds JWT-SVID authentication to requests.""" - - def __init__(self, socket_path: str, audience: str) -> None: - """Initialize the JWT auth interceptor. - - Args: - socket_path: Path to the SPIFFE Workload API socket - audience: JWT audience claim for token validation - - """ - self.socket_path = socket_path - self.audience = audience - self._workload_client = WorkloadApiClient(socket_path=socket_path) - - def _get_jwt_token(self) -> str: - """Fetch a JWT-SVID from the SPIRE Workload API. - - Returns: - JWT token string - - Raises: - RuntimeError: If unable to fetch JWT-SVID - - """ - try: - # Fetch JWT-SVID with the configured audience - jwt_svid = self._workload_client.fetch_jwt_svid(audience=[self.audience]) - if jwt_svid and jwt_svid.token: - return jwt_svid.token - msg = "Failed to fetch JWT-SVID: empty token" - raise RuntimeError(msg) - except Exception as e: - msg = f"Failed to fetch JWT-SVID: {e}" - raise RuntimeError(msg) from e - - def _add_jwt_metadata(self, client_call_details): - """Add JWT token to request metadata.""" - token = self._get_jwt_token() - metadata = [] - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - metadata.append(("authorization", f"Bearer {token}")) - - return grpc._interceptor._ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready, - compression=client_call_details.compression, - ) - - def intercept_unary_unary(self, continuation, client_call_details, request): - """Intercept unary-unary RPC calls.""" - new_details = self._add_jwt_metadata(client_call_details) - return continuation(new_details, request) - - def intercept_unary_stream(self, continuation, client_call_details, request): - """Intercept unary-stream RPC calls.""" - new_details = self._add_jwt_metadata(client_call_details) - return continuation(new_details, request) - - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): - """Intercept stream-unary RPC calls.""" - new_details = self._add_jwt_metadata(client_call_details) - return continuation(new_details, request_iterator) - - def intercept_stream_stream(self, continuation, client_call_details, request_iterator): - """Intercept stream-stream RPC calls.""" - new_details = self._add_jwt_metadata(client_call_details) - return continuation(new_details, request_iterator) - - -class BearerAuthInterceptor( - grpc.UnaryUnaryClientInterceptor, - grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor, -): - """gRPC interceptor that adds a static OAuth Bearer access token to requests.""" - - def __init__(self, token_supplier: Callable[[], str]) -> None: - self._token_supplier = token_supplier - - def _add_bearer_metadata(self, client_call_details): - token = self._token_supplier() - metadata = [] - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - metadata.append(("authorization", f"Bearer {token}")) - - return grpc._interceptor._ClientCallDetails( - method=client_call_details.method, - timeout=client_call_details.timeout, - metadata=metadata, - credentials=client_call_details.credentials, - wait_for_ready=client_call_details.wait_for_ready, - compression=client_call_details.compression, - ) - - def intercept_unary_unary(self, continuation, client_call_details, request): - new_details = self._add_bearer_metadata(client_call_details) - return continuation(new_details, request) - - def intercept_unary_stream(self, continuation, client_call_details, request): - new_details = self._add_bearer_metadata(client_call_details) - return continuation(new_details, request) - - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): - new_details = self._add_bearer_metadata(client_call_details) - return continuation(new_details, request_iterator) - - def intercept_stream_stream(self, continuation, client_call_details, request_iterator): - new_details = self._add_bearer_metadata(client_call_details) - return continuation(new_details, request_iterator) +__all__ = [ + "BearerAuthInterceptor", + "CachedToken", + "Client", + "JWTAuthInterceptor", + "OAuthTokenHolder", + "TokenCache", + "fetch_openid_configuration", + "run_loopback_pkce_login", +] class Client: - """High-level client for interacting with AGNTCY Directory services. - - This client provides a unified interface for operations across Dir API. - It handles gRPC communication and provides convenient methods for common operations. - - Example: - >>> config = Config.load_from_env() - >>> client = Client(config) - >>> # Use client for operations... - - """ - - def __init__( - self, - config: Config | None = None, - ) -> None: - """Initialize the client with the given configuration. - - Args: - config: Optional client configuration. If None, loads from environment - variables using Config.load_from_env(). - - Raises: - grpc.RpcError: If unable to establish connection to the server - ValueError: If configuration is invalid + """High-level client for interacting with AGNTCY Directory services.""" - """ - # Load config if unset - if config is None: - config = Config.load_from_env() - self.config = config - self._oauth_holder: OAuthTokenHolder | None = None + def __init__(self, config: Config | None = None) -> None: + self.config = config or Config.load_from_env() + self.oauth_session = OAuthSessionManager(self.config) - if config.auth_mode == "oidc": - self._oauth_holder = OAuthTokenHolder() - if self.config.auth_token: - self._oauth_holder.set_tokens(self.config.auth_token) - else: - cached_token = TokenCache().get_valid_token() - if cached_token is not None: - self._oauth_holder.set_tokens(cached_token.access_token) - - # Create gRPC channel - channel = self.__create_grpc_channel() + channel = create_grpc_channel( + self.config, + oauth_holder=self.oauth_session.oauth_holder, + ) - # Initialize service clients + # Expose raw stubs for advanced callers. self.store_client = store_v1.StoreServiceStub(channel) self.routing_client = routing_v1.RoutingServiceStub(channel) self.publication_client = routing_v1.PublicationServiceStub(channel) @@ -216,1005 +78,163 @@ def __init__( self.event_client = events_v1.EventServiceStub(channel) self.naming_client = naming_v1.NamingServiceStub(channel) - def __create_grpc_channel(self) -> grpc.Channel: - # Handle different authentication modes - if self.config.auth_mode == "": - return grpc.insecure_channel(self.config.server_address) - elif self.config.auth_mode == "jwt": - return self.__create_jwt_channel() - elif self.config.auth_mode == "x509": - return self.__create_x509_channel() - elif self.config.auth_mode == "tls": - return self.__create_tls_channel() - elif self.config.auth_mode == "oidc": - return self.__create_oauth_pkce_channel() - else: - msg = f"Unsupported auth mode: {self.config.auth_mode}" - raise ValueError(msg) - - def __create_x509_channel(self) -> grpc.Channel: - """Create a secure gRPC channel using SPIFFE X.509.""" - if self.config.spiffe_socket_path == "": - msg = "SPIFFE socket path is required for X.509 authentication" - raise ValueError(msg) - - # Create secure gRPC channel using SPIFFE X.509 - workload_client = WorkloadApiClient(socket_path=self.config.spiffe_socket_path) - x509_src = X509Source( - workload_api_client=workload_client, - socket_path=self.config.spiffe_socket_path, - timeout_in_seconds=60, - ) - - root_ca = b"" - for b in x509_src.bundles: - for a in b.x509_authorities: - root_ca += a.public_bytes(encoding=serialization.Encoding.PEM) - - private_key = x509_src.svid.private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - - public_leaf = x509_src.svid.leaf.public_bytes( - encoding=serialization.Encoding.PEM - ) - - credentials = grpc.ssl_channel_credentials( - root_certificates=root_ca, - private_key=private_key, - certificate_chain=public_leaf, - ) - - channel = grpc.secure_channel( - target=self.config.server_address, - credentials=credentials, - options=self._grpc_channel_options(), - ) - - return channel - - def __create_jwt_channel(self) -> grpc.Channel: - """Create a gRPC channel with JWT authentication.""" - if self.config.spiffe_socket_path == "": - msg = "SPIFFE socket path is required for JWT authentication" - raise ValueError(msg) - - if self.config.jwt_audience == "": - msg = "JWT audience is required for JWT authentication" - raise ValueError(msg) - - # Create X509Source to get the SPIFFE bundle for TLS verification - # In JWT mode, the server presents its X.509-SVID via TLS for transport security - # The X509Source will handle fetching the bundle from the Workload API - workload_client = WorkloadApiClient(socket_path=self.config.spiffe_socket_path) - x509_source = X509Source( - workload_api_client=workload_client, - socket_path=self.config.spiffe_socket_path, - timeout_in_seconds=60, - ) - - # Extract the CA certificates from all bundles - root_ca = b"" - for bundle in x509_source.bundles: - for authority in bundle.x509_authorities: - root_ca += authority.public_bytes(encoding=serialization.Encoding.PEM) - - if not root_ca: - msg = "Failed to fetch X.509 bundle from SPIRE: no bundles returned" - raise RuntimeError(msg) - - # Create JWT interceptor - jwt_interceptor = JWTAuthInterceptor( - socket_path=self.config.spiffe_socket_path, - audience=self.config.jwt_audience - ) - - # Create secure channel with JWT interceptor and TLS using SPIFFE bundle - # For JWT mode: Server presents X.509-SVID via TLS, clients authenticate with JWT-SVID - credentials = grpc.ssl_channel_credentials(root_certificates=root_ca) - channel = grpc.secure_channel( - target=self.config.server_address, - credentials=credentials, - options=self._grpc_channel_options(), - ) - channel = grpc.intercept_channel(channel, jwt_interceptor) - - # Close the X509Source since we only needed it to get the bundle - x509_source.close() - - return channel - - def __create_tls_channel(self) -> grpc.Channel: - if not self.config.tls_ca_file: - msg = "TLS CA file is required for TLS authentication" - raise ValueError(msg) - if not self.config.tls_cert_file: - msg = "TLS certificate file is required for TLS authentication" - raise ValueError(msg) - if not self.config.tls_key_file: - msg = "TLS key file is required for TLS authentication" - raise ValueError(msg) - - try: - with open(self.config.tls_ca_file, "rb") as f: - root_ca = f.read() - with open(self.config.tls_cert_file, "rb") as f: - cert_chain = f.read() - with open(self.config.tls_key_file, "rb") as f: - private_key = f.read() - except OSError as e: - msg = f"Failed to read TLS files: {e}" - raise RuntimeError(msg) from e - - credentials = grpc.ssl_channel_credentials( - root_certificates=root_ca, - private_key=private_key, - certificate_chain=cert_chain, - ) - - channel = grpc.secure_channel( - target=self.config.server_address, - credentials=credentials, - options=self._grpc_channel_options(), - ) - - return channel - - def __create_oauth_pkce_channel(self) -> grpc.Channel: - if self._oauth_holder is None: + # Service-layer adapters grouped by technical area. + self.store_service = StoreService(self.store_client, logger) + self.routing_service = RoutingService(self.routing_client, logger) + self.publication_service = PublicationService(self.publication_client, logger) + self.search_service = SearchService(self.search_client, logger) + self.sign_service = SignService(self.config, self.sign_client, logger) + self.sync_service = SyncService(self.sync_client, logger) + self.event_service = EventService(self.event_client, logger) + self.naming_service = NamingService(self.naming_client, logger) + + def has_cached_oauth_token(self) -> bool: + return self.oauth_session.has_access_token() + + def get_access_token(self) -> str: + oauth_holder = self.oauth_session.oauth_holder + if oauth_holder is None: msg = "OAuth token holder not initialized" raise RuntimeError(msg) - - root_ca = None - if self.config.tls_ca_file: - try: - with open(self.config.tls_ca_file, "rb") as f: - root_ca = f.read() - except OSError as e: - msg = f"Failed to read TLS CA file: {e}" - raise RuntimeError(msg) from e - - credentials = grpc.ssl_channel_credentials(root_certificates=root_ca) - - channel = grpc.secure_channel( - target=self.config.server_address, - credentials=credentials, - options=self._grpc_channel_options(), - ) - - bearer = BearerAuthInterceptor(self._oauth_holder.get_access_token) - return grpc.intercept_channel(channel, bearer) + return oauth_holder.get_access_token() def authenticate_oauth_pkce(self) -> None: - """Run browser-based OAuth 2.0 Authorization Code + PKCE login (loopback callback). - - Requires ``auth_mode=\"oidc\"``, ``oidc_issuer``, and ``oidc_client_id``. - After success, gRPC calls use the returned access token for bearer auth. - - Raises: - ValueError: If auth mode or required OIDC settings are missing. - OAuthPkceError: If the authorization or token exchange fails. - - """ - if self.config.auth_mode != "oidc": - msg = "authenticate_oauth_pkce() requires auth_mode='oidc'" - raise ValueError(msg) - if not self.config.oidc_issuer: - msg = "oidc_issuer is required for authenticate_oauth_pkce()" - raise ValueError(msg) - if not self.config.oidc_client_id: - msg = "oidc_client_id is required for authenticate_oauth_pkce()" - raise ValueError(msg) - if self._oauth_holder is None: - msg = "OAuth token holder not initialized" - raise RuntimeError(msg) - - meta = fetch_openid_configuration( - self.config.oidc_issuer, - verify=not self.config.tls_skip_verify, - timeout=min(30.0, self.config.oidc_auth_timeout), - ) - - payload = run_loopback_pkce_login(self.config, metadata=meta) - self._oauth_holder.update_from_token_response(payload) - TokenCache().save(self._cached_token_from_response(payload)) - + self.oauth_session.authenticate() print("Authenticated with OAuth PKCE") - # Do not print raw OAuth credentials to stdout/logs. print("Access token acquired.") - def _cached_token_from_response(self, payload: dict[str, object]) -> CachedToken: - expires_at = None - expires_in = payload.get("expires_in") - if expires_in is not None: - expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in)) - - refresh_token = payload.get("refresh_token") - token_type = payload.get("token_type") - - return CachedToken( - access_token=str(payload["access_token"]), - token_type=str(token_type) if isinstance(token_type, str) else "", - provider="oidc", - issuer=self.config.oidc_issuer, - refresh_token=str(refresh_token) if isinstance(refresh_token, str) else "", - expires_at=expires_at, - created_at=datetime.now(UTC), - ) - - def _grpc_channel_options(self) -> list[tuple[str, str]]: - server_name = self.config.tls_server_name.strip() - if not server_name: - return [] - return [ - ("grpc.ssl_target_name_override", server_name), - ("grpc.default_authority", server_name), - ] - - def _server_name_from_addr(self, addr: str) -> str: - # "host:port" -> "host" - return addr.rsplit(":", 1)[0] - def publish( self, req: routing_v1.PublishRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> None: - """Publish objects to the Routing API matching the specified criteria. - - Makes the specified objects available for discovery and retrieval by other - clients in the network. The objects must already exist in the store before - they can be published. - - Args: - req: Publish request containing the query for the objects to publish - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the object is not found or cannot be published - - Example: - >>> ref = routing_v1.RecordRef(cid="QmExample123") - >>> req = routing_v1.PublishRequest(record_refs=[ref]) - >>> client.publish(req) - - """ - try: - self.routing_client.Publish(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during publish: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during publish: %s", e) - msg = f"Failed to publish object: {e}" - raise RuntimeError(msg) from e + self.routing_service.publish(req, metadata=metadata) def list( self, req: routing_v1.ListRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> list[routing_v1.ListResponse]: - """List objects from the Routing API matching the specified criteria. - - Returns a list of objects that match the filtering and - query criteria specified in the request. - - Args: - req: List request specifying filtering criteria, pagination, etc. - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[routing_v1.ListResponse]: List of items matching the criteria - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the list operation fails - - Example: - >>> req = routing_v1.ListRequest(limit=10) - >>> responses = client.list(req) - >>> for response in responses: - ... print(f"Found object: {response.cid}") - - """ - results: list[routing_v1.ListResponse] = [] - - try: - stream = self.routing_client.List(req, metadata=metadata) - results.extend(stream) - except grpc.RpcError as e: - logger.exception("gRPC error during list: %s", e) - raise - except Exception as e: - logger.exception("Error receiving objects: %s", e) - msg = f"Failed to list objects: {e}" - raise RuntimeError(msg) from e - - return results + return self.routing_service.list(req, metadata=metadata) def search_cids( self, req: search_v1.SearchCIDsRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[search_v1.SearchCIDsResponse]: - """Search for record CIDs matching the specified queries. - - Performs a search across the storage using the provided search queries - and returns a list of matching CIDs. This is efficient for lookups - where only the CIDs are needed. - - Args: - req: Search request containing queries, filters, and search options - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[search_v1.SearchCIDsResponse]: List of CIDs matching the queries - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the search operation fails - - Example: - >>> req = search_v1.SearchCIDsRequest(queries=[query], limit=10) - >>> responses = client.search_cids(req) - >>> for response in responses: - ... print(f"Found CID: {response.record_cid}") - - """ - results: list[search_v1.SearchCIDsResponse] = [] - - try: - stream = self.search_client.SearchCIDs(req, metadata=metadata) - results.extend(stream) - except grpc.RpcError as e: - logger.exception("gRPC error during search: %s", e) - raise - except Exception as e: - logger.exception("Error receiving search results: %s", e) - msg = f"Failed to search CIDs: {e}" - raise RuntimeError(msg) from e - - return results + return self.search_service.search_cids(req, metadata=metadata) def search_records( self, req: search_v1.SearchRecordsRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[search_v1.SearchRecordsResponse]: - """Search for full records matching the specified queries. - - Performs a search across the storage using the provided search queries - and returns a list of full records with all metadata. - - Args: - req: Search request containing queries, filters, and search options - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[search_v1.SearchRecordsResponse]: List of records matching the queries - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the search operation fails - - Example: - >>> req = search_v1.SearchRecordsRequest(queries=[query], limit=10) - >>> responses = client.search_records(req) - >>> for response in responses: - ... print(f"Found: {response.record.name}") - - """ - results: list[search_v1.SearchRecordsResponse] = [] - - try: - stream = self.search_client.SearchRecords(req, metadata=metadata) - results.extend(stream) - except grpc.RpcError as e: - logger.exception("gRPC error during search: %s", e) - raise - except Exception as e: - logger.exception("Error receiving search results: %s", e) - msg = f"Failed to search records: {e}" - raise RuntimeError(msg) from e - - return results + return self.search_service.search_records(req, metadata=metadata) def unpublish( self, req: routing_v1.UnpublishRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> None: - """Unpublish objects from the Routing API matching the specified criteria. - - Removes the specified objects from the public network, making them no - longer discoverable by other clients. The objects remain in the local - store but are not available for network discovery. - - Args: - req: Unpublish request containing the query for the objects to unpublish - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the objects cannot be unpublished - - Example: - >>> ref = routing_v1.RecordRef(cid="QmExample123") - >>> req = routing_v1.UnpublishRequest(record_refs=[ref]) - >>> client.unpublish(req) - - """ - try: - self.routing_client.Unpublish(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during unpublish: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during unpublish: %s", e) - msg = f"Failed to unpublish object: {e}" - raise RuntimeError(msg) from e + self.routing_service.unpublish(req, metadata=metadata) def push( self, records: builtins.list[core_v1.Record], metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[core_v1.RecordRef]: - """Push records to the Store API. - - Uploads one or more records to the content store, making them available - for retrieval and reference. Each record is assigned a unique content - identifier (CID) based on its content hash. - - Args: - records: List of Record objects to push to the store - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[core_v1.RecordRef]: List of objects containing the CIDs of the pushed records - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the push operation fails - - Example: - >>> records = [create_record("example")] - >>> refs = client.push(records) - >>> print(f"Pushed with CID: {refs[0].cid}") - - """ - results: list[core_v1.RecordRef] = [] - - try: - response = self.store_client.Push(iter(records), metadata=metadata) - results.extend(response) - except grpc.RpcError as e: - logger.exception("gRPC error during push: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during push: %s", e) - msg = f"Failed to push object: {e}" - raise RuntimeError(msg) from e - - return results + return self.store_service.push(records, metadata=metadata) def push_referrer( self, req: builtins.list[store_v1.PushReferrerRequest], metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[store_v1.PushReferrerResponse]: - """Push records with referrer metadata to the Store API. - - Uploads records along with optional artifacts and referrer information. - This is useful for pushing complex objects that include additional - metadata or associated artifacts. - - Args: - req: List of PushReferrerRequest objects containing records and - optional artifacts - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[store_v1.PushReferrerResponse]: List of objects containing the details of pushed artifacts - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the push operation fails - - Example: - >>> requests = [store_v1.PushReferrerRequest(record=record)] - >>> responses = client.push_referrer(requests) - - """ - results: list[store_v1.PushReferrerResponse] = [] - - try: - response = self.store_client.PushReferrer(iter(req), metadata=metadata) - results.extend(response) - except grpc.RpcError as e: - logger.exception("gRPC error during push_referrer: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during push_referrer: %s", e) - msg = f"Failed to push object: {e}" - raise RuntimeError(msg) from e - - return results + return self.store_service.push_referrer(req, metadata=metadata) def pull( self, refs: builtins.list[core_v1.RecordRef], metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[core_v1.Record]: - """Pull records from the Store API by their references. - - Retrieves one or more records from the content store using their - content identifiers (CIDs). - - Args: - refs: List of RecordRef objects containing the CIDs to retrieve - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[core_v1.Record]: List of record objects retrieved from the store - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the pull operation fails - - Example: - >>> refs = [core_v1.RecordRef(cid="QmExample123")] - >>> records = client.pull(refs) - >>> for record in records: - ... print(f"Retrieved record: {record}") - - """ - results: list[core_v1.Record] = [] - - try: - response = self.store_client.Pull(iter(refs), metadata=metadata) - results.extend(response) - except grpc.RpcError as e: - logger.exception("gRPC error during pull: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during pull: %s", e) - msg = f"Failed to pull object: {e}" - raise RuntimeError(msg) from e - - return results + return self.store_service.pull(refs, metadata=metadata) def pull_referrer( self, req: builtins.list[store_v1.PullReferrerRequest], metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[store_v1.PullReferrerResponse]: - """Pull records with referrer metadata from the Store API. - - Retrieves records along with their associated artifacts and referrer - information. This provides access to complex objects that include - additional metadata or associated artifacts. - - Args: - req: List of PullReferrerRequest objects containing records and - optional artifacts for pull operations - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[store_v1.PullReferrerResponse]: List of objects containing the retrieved records - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the pull operation fails - - Example: - >>> requests = [store_v1.PullReferrerRequest(ref=ref)] - >>> responses = client.pull_referrer(requests) - >>> for response in responses: - ... print(f"Retrieved: {response}") - - """ - results: list[store_v1.PullReferrerResponse] = [] - - try: - response = self.store_client.PullReferrer(iter(req), metadata=metadata) - results.extend(response) - except grpc.RpcError as e: - logger.exception("gRPC error during pull_referrer: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during pull_referrer: %s", e) - msg = f"Failed to pull referrer object: {e}" - raise RuntimeError(msg) from e - - return results + return self.store_service.pull_referrer(req, metadata=metadata) def lookup( self, refs: builtins.list[core_v1.RecordRef], metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[core_v1.RecordMeta]: - """Look up metadata for records in the Store API. - - Retrieves metadata information for one or more records without - downloading the full record content. This is useful for checking - if records exist and getting basic information about them. - - Args: - refs: List of RecordRef objects containing the CIDs to look up - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List[core_v1.RecordMeta]: List of objects containing metadata for the records - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the lookup operation fails - - Example: - >>> refs = [core_v1.RecordRef(cid="QmExample123")] - >>> metadatas = client.lookup(refs) - >>> for meta in metadatas: - ... print(f"Record size: {meta.size}") - - """ - results: list[core_v1.RecordMeta] = [] - - try: - response = self.store_client.Lookup(iter(refs), metadata=metadata) - results.extend(response) - except grpc.RpcError as e: - logger.exception("gRPC error during lookup: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during lookup: %s", e) - msg = f"Failed to lookup object: {e}" - raise RuntimeError(msg) from e - - return results + return self.store_service.lookup(refs, metadata=metadata) def delete( self, refs: builtins.list[core_v1.RecordRef], metadata: Sequence[tuple[str, str]] | None = None, ) -> None: - """Delete records from the Store API. - - Permanently removes one or more records from the content store using - their content identifiers (CIDs). This operation cannot be undone. - - Args: - refs: List of RecordRef objects containing the CIDs to delete - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the delete operation fails - - Example: - >>> refs = [core_v1.RecordRef(cid="QmExample123")] - >>> client.delete(refs) - - """ - try: - self.store_client.Delete(iter(refs), metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during delete: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during delete: %s", e) - msg = f"Failed to delete object: {e}" - raise RuntimeError(msg) from e + self.store_service.delete(refs, metadata=metadata) def create_sync( self, req: store_v1.CreateSyncRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> store_v1.CreateSyncResponse: - """Create a new synchronization configuration. - - Creates a new sync configuration that defines how data should be - synchronized between different Directory servers. This allows for - automated data replication and consistency across multiple locations. - - Args: - req: CreateSyncRequest containing the sync configuration details - including source, target, and synchronization parameters - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - store_v1.CreateSyncResponse: Response containing the created sync details - including the sync ID and configuration - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the sync creation fails - - Example: - >>> req = store_v1.CreateSyncRequest() - >>> response = client.create_sync(req) - >>> print(f"Created sync with ID: {response.sync_id}") - - """ - try: - response = self.sync_client.CreateSync(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during create_sync: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during create_sync: %s", e) - msg = f"Failed to create sync: {e}" - raise RuntimeError(msg) from e - - return response + return self.sync_service.create_sync(req, metadata=metadata) def list_syncs( self, req: store_v1.ListSyncsRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[store_v1.ListSyncsItem]: - """List existing synchronization configurations. - - Retrieves a list of all sync configurations that have been created, - with optional filtering and pagination support. This allows you to - monitor and manage multiple synchronization processes. - - Args: - req: ListSyncsRequest containing filtering criteria, pagination options, - and other query parameters - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - list[store_v1.ListSyncsItem]: List of sync configuration items with - their details including ID, name, status, - and configuration parameters - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the list operation fails - - Example: - >>> req = store_v1.ListSyncsRequest(limit=10) - >>> syncs = client.list_syncs(req) - >>> for sync in syncs: - ... print(f"Sync: {sync}") - - """ - results: list[store_v1.ListSyncsItem] = [] - - try: - stream = self.sync_client.ListSyncs(req, metadata=metadata) - results.extend(stream) - except grpc.RpcError as e: - logger.exception("gRPC error during list_syncs: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during list_syncs: %s", e) - msg = f"Failed to list syncs: {e}" - raise RuntimeError(msg) from e - - return results + return self.sync_service.list_syncs(req, metadata=metadata) def get_sync( self, req: store_v1.GetSyncRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> store_v1.GetSyncResponse: - """Retrieve detailed information about a specific synchronization configuration. - - Gets comprehensive details about a specific sync configuration including - its current status, configuration parameters, performance metrics, - and any recent errors or warnings. - - Args: - req: GetSyncRequest containing the sync ID or identifier to retrieve - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - store_v1.GetSyncResponse: Detailed information about the sync configuration - including status, metrics, configuration, and logs - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the get operation fails - - Example: - >>> req = store_v1.GetSyncRequest(sync_id="sync-123") - >>> response = client.get_sync(req) - >>> print(f"Sync status: {response.status}") - >>> print(f"Last update: {response.last_update_time}") - - """ - try: - response = self.sync_client.GetSync(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during get_sync: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during get_sync: %s", e) - msg = f"Failed to get sync: {e}" - raise RuntimeError(msg) from e - - return response + return self.sync_service.get_sync(req, metadata=metadata) def delete_sync( self, req: store_v1.DeleteSyncRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> None: - """Delete a synchronization configuration. - - Permanently removes a sync configuration and stops any ongoing - synchronization processes. This operation cannot be undone and - will halt all data synchronization for the specified configuration. - - Args: - req: DeleteSyncRequest containing the sync ID or identifier to delete - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the delete operation fails - - Example: - >>> req = store_v1.DeleteSyncRequest(sync_id="sync-123") - >>> client.delete_sync(req) - >>> print(f"Sync deleted") - - """ - try: - self.sync_client.DeleteSync(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during delete_sync: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during delete_sync: %s", e) - msg = f"Failed to delete sync: {e}" - raise RuntimeError(msg) from e + self.sync_service.delete_sync(req, metadata=metadata) def listen( self, req: events_v1.ListenRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> grpc.UnaryStreamMultiCallable: - """ - Listen establishes a streaming connection to receive events. - Events are only delivered while the stream is active. - On disconnect, missed events are not recoverable. - - Args: - req: ListenRequest specifies filters for event subscription. - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - A grpc stream which can read and closed. - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the listen operation fails - """ - - try: - stream = self.event_client.Listen(req, metadata=metadata) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.CANCELLED: - logger.exception("gRPC listen stream was canceled: %s", e) - raise - else: - logger.exception("gRPC error during listen: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during listen: %s", e) - msg = f"Failed to listen: {e}" - raise RuntimeError(msg) from e - - return stream + return self.event_service.listen(req, metadata=metadata) def create_publication( self, req: routing_v1.PublishRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> routing_v1.CreatePublicationResponse: - """ - Create publication creates a new publication request that will be processed by the PublicationWorker. - The publication request can specify either a query, a list of specific CIDs, - or all records to be announced to the DHT. - - Args: - req: PublishRequest specifies the record references and queries for publication. - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - CreatePublicationResponse returns the result of creating a publication request. - This includes the publication ID and any relevant metadata. - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the publication operation create fails - """ - try: - response = self.publication_client.CreatePublication(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during create_publication: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during create_publication: %s", e) - msg = f"Failed to create publication: {e}" - raise RuntimeError(msg) from e - - return response + return self.publication_service.create_publication(req, metadata=metadata) def get_publication( self, req: routing_v1.GetPublicationRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> routing_v1.GetPublicationResponse: - """ - GetPublication retrieves details of a specific publication request by its identifier. - This includes the current status and any associated metadata. - - Args: - req: GetPublicationRequest specifies which publication to retrieve by its identifier. - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - GetPublicationResponse contains the full details of a specific publication request. - Includes status, progress information, and any error details if applicable. - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the publication get operation fails - """ - try: - response = self.publication_client.GetPublication(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during get_publication: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during get_publication: %s", e) - msg = f"Failed to get publication: {e}" - raise RuntimeError(msg) from e - - return response + return self.publication_service.get_publication(req, metadata=metadata) def list_publication( self, req: routing_v1.ListPublicationsRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> builtins.list[routing_v1.ListPublicationsItem]: - """ - ListPublications returns a stream of all publication requests in the system. - This allows monitoring of pending, processing, and completed publication requests. - - Args: - req: ListPublicationsRequest contains optional filters for listing publication requests. - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - List of ListPublicationsItem represents a single publication request in the list response. - Contains publication details including ID, status, and creation timestamp. - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the publication list operation fails - """ - - results: list[routing_v1.ListPublicationsItem] = [] - - try: - stream = self.publication_client.ListPublications(req, metadata=metadata) - results.extend(stream) - except grpc.RpcError as e: - logger.exception("gRPC error during list_publication: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during list_publication: %s", e) - msg = f"Failed to list publication: {e}" - raise RuntimeError(msg) from e - - return results + return self.publication_service.list_publication(req, metadata=metadata) def resolve( self, @@ -1222,46 +242,7 @@ def resolve( version: str | None = None, metadata: Sequence[tuple[str, str]] | None = None, ) -> naming_v1.ResolveResponse: - """Resolve a record name to CIDs. - - Resolves a record reference (name with optional version) to content identifiers (CIDs). - When no version is specified, returns all versions sorted by creation time (newest first). - - Args: - name: The name of the record to resolve (e.g., "cisco.com/agent") - version: Optional version to resolve to (e.g., "v1.0.0") - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - naming_v1.ResolveResponse: Response containing the resolved record references - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the resolve operation fails - - Example: - >>> # Resolve latest version - >>> response = client.resolve("cisco.com/agent") - >>> print(f"Latest CID: {response.records[0].cid}") - >>> - >>> # Resolve specific version - >>> response = client.resolve("cisco.com/agent", "v1.0.0") - - """ - try: - req = naming_v1.ResolveRequest(name=name) - if version: - req.version = version - response = self.naming_client.Resolve(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during resolve: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during resolve: %s", e) - msg = f"Failed to resolve name: {e}" - raise RuntimeError(msg) from e - - return response + return self.naming_service.resolve(name, version=version, metadata=metadata) def get_verification_info( self, @@ -1270,500 +251,19 @@ def get_verification_info( version: str | None = None, metadata: Sequence[tuple[str, str]] | None = None, ) -> naming_v1.GetVerificationInfoResponse: - """Get verification info for a record. - - Retrieves the name verification status for a record. Can look up by CID directly - or by name (with optional version) which will be resolved first. - - Args: - cid: Optional CID of the record to check - name: Optional name of the record to check (e.g., "cisco.com/agent") - version: Optional version when looking up by name (e.g., "v1.0.0") - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - naming_v1.GetVerificationInfoResponse: Response containing verification status - - Raises: - grpc.RpcError: If the gRPC call fails (includes InvalidArgument, NotFound, etc.) - RuntimeError: If the operation fails - - Example: - >>> # Check by CID - >>> response = client.get_verification_info(cid="bafyreib...") - >>> - >>> # Check by name (latest version) - >>> response = client.get_verification_info(name="cisco.com/agent") - >>> - >>> # Check by name with specific version - >>> response = client.get_verification_info(name="cisco.com/agent", version="v1.0.0") - - """ - try: - req = naming_v1.GetVerificationInfoRequest() - if cid: - req.cid = cid - if name: - req.name = name - if version: - req.version = version - response = self.naming_client.GetVerificationInfo(req, metadata=metadata) - except grpc.RpcError as e: - logger.exception("gRPC error during get_verification_info: %s", e) - raise - except Exception as e: - logger.exception("Unexpected error during get_verification_info: %s", e) - msg = f"Failed to get verification info: {e}" - raise RuntimeError(msg) from e - - return response + return self.naming_service.get_verification_info( + cid=cid, + name=name, + version=version, + metadata=metadata, + ) def verify( self, req: sign_v1.VerifyRequest, metadata: Sequence[tuple[str, str]] | None = None, ) -> sign_v1.VerifyResponse: - """Verify a cryptographic signature on a record. - - Validates the cryptographic signature of a previously signed record - to ensure its authenticity and integrity. This operation verifies - that the record has not been tampered with since signing. - - - The verification process uses the external dirctl command-line tool - to perform the actual cryptographic operations. - - When useServerVerification is true, uses the server's cached result (from the reconciler). - - Args: - req: VerifyRequest containing the record reference and verification - parameters. The provider can specify either key-based verification - (with a public key) or OIDC-based verification - metadata: Optional gRPC metadata headers as sequence of key-value pairs - - Returns: - VerifyResponse containing the verification result and details - - Raises: - RuntimeError: If the verification operation fails - - Example: - >>> req = sign_v1.VerifyRequest( - ... record_ref=core_v1.RecordRef(cid="QmExample123") - ... ) - >>> response = client.verify(req) - >>> print(f"Signature valid: {response.success}") - - """ - if req.from_server: - if req.record_ref is None or not req.record_ref.cid: - msg = "VerifyRequest.record_ref with cid is required" - raise RuntimeError(msg) - try: - resp = self.sign_client.Verify(req, metadata=metadata or ()) - return resp - except grpc.RpcError as e: - logger.exception("gRPC error during verify: %s", e) - raise RuntimeError(f"Verify failed: {e}") from e - except Exception as e: - logger.exception("Verification failed: %s", e) - raise RuntimeError(f"Verify failed: {e}") from e - - # Client-side verification via dirctl (same as main branch) - fd, output_path = tempfile.mkstemp(suffix=".json", prefix="dirctl-verify-") - os.close(fd) - _output_path = output_path - - if self.config.docker_config: - basename = os.path.basename(output_path) - self.config.docker_config.mounts.append(f"type=bind,src={output_path},dst=/{basename}") - _output_path = basename - - try: - provider: sign_v1.VerifyRequestProvider = req.provider - if provider is None: - self._verify_with_any(req.record_ref, None, _output_path) - elif provider.HasField("key"): - self._verify_with_key(req.record_ref, provider.key, _output_path) - elif provider.HasField("oidc"): - self._verify_with_oidc(req.record_ref, provider.oidc, _output_path) - elif provider.HasField("any"): - self._verify_with_any(req.record_ref, provider.any, _output_path) - else: - self._verify_with_any(req.record_ref, None, _output_path) - - # Read and parse the output file - with open(output_path, "rb") as f: - return self._parse_verify_response(f.read()) - except RuntimeError as e: - msg = f"Failed to verify the object: {e}" - raise RuntimeError(msg) from e - except Exception as e: - logger.exception("Verification operation failed: %s", e) - msg = f"Failed to verify the object: {e}" - raise RuntimeError(msg) from e - finally: - # Clean up the temp file - try: - os.unlink(output_path) - except OSError: - # Ignore cleanup errors - pass - - def _verify_with_key( - self, - record_ref: core_v1.RecordRef, - key_verifier: sign_v1.VerifyWithKey, - output_path: str, - ) -> None: - """Verify a record using a public key. - - This private method handles key-based verification by passing the public key - reference to the dirctl command. The key can be a file path, URL, or KMS URI. - - Args: - record_ref: Reference to the record to verify - key_verifier: VerifyWithKey containing the public key reference - output_path: Path to the output file for verification result - - Raises: - RuntimeError: If any error occurs during verification - - """ - try: - # Build and execute the verification command - # The key reference is passed directly to dirctl which handles - # file paths, URLs, KMS URIs, etc. - command = [ - *self.config.get_dirctl(), - "verify", - record_ref.cid, - "--key", - key_verifier.public_key, - "--output-file", - output_path, - ] - - subprocess.run( - command, - check=True, - capture_output=True, - timeout=60, # 1 minute timeout - ) - - except subprocess.CalledProcessError as e: - msg = f"dirctl verification failed with return code {e.returncode}: {e.stderr.decode('utf-8', errors='ignore')}" - raise RuntimeError(msg) from e - except subprocess.TimeoutExpired as e: - msg = "dirctl verification timed out" - raise RuntimeError(msg) from e - except Exception as e: - msg = f"Unexpected error during key-based verification: {e}" - raise RuntimeError(msg) from e - - def _verify_with_any( - self, - record_ref: core_v1.RecordRef, - any_verifier: sign_v1.VerifyWithAny | None, - output_path: str, - ) -> None: - """Verify a record using any valid signature. - - This private method handles verification of any signature on the record, - with optional OIDC options for OIDC-based signatures. - - Args: - record_ref: Reference to the record to verify - any_verifier: VerifyWithAny containing optional OIDC verification options, - or None for default verification - output_path: Path to the output file for verification result - - Raises: - RuntimeError: If any error occurs during verification - - """ - try: - # Build base command - command = [*self.config.get_dirctl(), "verify", record_ref.cid, "--output-file", output_path] - - # Add OIDC verification options if present - if any_verifier is not None and any_verifier.HasField("oidc_options"): - opts = any_verifier.oidc_options - if opts.tuf_mirror_url: - command.extend(["--tuf-mirror-url", opts.tuf_mirror_url]) - if opts.trusted_root_path: - command.extend(["--trusted-root-path", opts.trusted_root_path]) - if opts.ignore_tlog: - command.append("--ignore-tlog") - if opts.ignore_tsa: - command.append("--ignore-tsa") - if opts.ignore_sct: - command.append("--ignore-sct") - - # Execute the verification command - subprocess.run( - command, - check=True, - capture_output=True, - timeout=60, # 1 minute timeout - ) - - except subprocess.CalledProcessError as e: - msg = f"dirctl verification failed with return code {e.returncode}: {e.stderr.decode('utf-8', errors='ignore')}" - raise RuntimeError(msg) from e - except subprocess.TimeoutExpired as e: - msg = "dirctl verification timed out" - raise RuntimeError(msg) from e - except Exception as e: - msg = f"Unexpected error during verification: {e}" - raise RuntimeError(msg) from e - - def _verify_with_oidc( - self, - record_ref: core_v1.RecordRef, - oidc_verifier: sign_v1.VerifyWithOIDC | None, - output_path: str, - ) -> None: - """Verify a record using OIDC-based verification. - - This private method handles OIDC-based verification by building the appropriate - dirctl command with OIDC parameters and executing it. - - Args: - record_ref: Reference to the record to verify - oidc_verifier: VerifyWithOIDC containing the OIDC verification options, - or None for default verification - output_path: Path to the output file for verification result - - Raises: - RuntimeError: If any error occurs during verification - - """ - try: - # Build base command - command = [*self.config.get_dirctl(), "verify", record_ref.cid, "--output-file", output_path] - - # Add OIDC-specific parameters if provided - if oidc_verifier is not None: - if oidc_verifier.issuer: - command.extend(["--oidc-issuer", oidc_verifier.issuer]) - if oidc_verifier.subject: - command.extend(["--oidc-subject", oidc_verifier.subject]) - - # Add verification options if present - if oidc_verifier.HasField("options"): - opts = oidc_verifier.options - if opts.tuf_mirror_url: - command.extend(["--tuf-mirror-url", opts.tuf_mirror_url]) - if opts.trusted_root_path: - command.extend(["--trusted-root-path", opts.trusted_root_path]) - if opts.ignore_tlog: - command.append("--ignore-tlog") - if opts.ignore_tsa: - command.append("--ignore-tsa") - if opts.ignore_sct: - command.append("--ignore-sct") - - # Execute the verification command - subprocess.run( - command, - check=True, - capture_output=True, - timeout=60, # 1 minute timeout - ) - - except subprocess.CalledProcessError as e: - msg = f"dirctl verification failed with return code {e.returncode}: {e.stderr.decode('utf-8', errors='ignore')}" - raise RuntimeError(msg) from e - except subprocess.TimeoutExpired as e: - msg = "dirctl verification timed out" - raise RuntimeError(msg) from e - except Exception as e: - msg = f"Unexpected error during OIDC verification: {e}" - raise RuntimeError(msg) from e - - def _parse_verify_response(self, output: bytes) -> sign_v1.VerifyResponse: - """Parse the JSON output from dirctl verify command. - - Args: - output: Raw bytes output from the dirctl command - - Returns: - VerifyResponse parsed from the JSON output - - Raises: - RuntimeError: If the output cannot be parsed - - """ - try: - json_str = output.decode("utf-8") - json_data = json.loads(json_str) - - # The CLI outputs the response directly as JSON - response = sign_v1.VerifyResponse() - json_format.ParseDict(json_data, response) - return response - except (json.JSONDecodeError, UnicodeDecodeError) as e: - msg = f"Failed to parse verification response: {e}" - raise RuntimeError(msg) from e - - def sign( - self, - req: sign_v1.SignRequest, - ) -> None: - """Sign a record with a cryptographic signature. - - Creates a cryptographic signature for a record using either a private - key or OIDC-based signing. The signing process uses the external dirctl - command-line tool to perform the actual cryptographic operations. - - Args: - req: SignRequest containing the record reference and signing provider - configuration. The provider can specify either key-based signing - (with a private key) or OIDC-based signing - oidc_client_id: OIDC client identifier for OIDC-based signing. - Defaults to "sigstore" - - Raises: - RuntimeError: If the signing operation fails - - Example: - >>> req = sign_v1.SignRequest( - ... record_ref=core_v1.RecordRef(cid="QmExample123"), - ... provider=sign_v1.SignProvider(key=key_config) - ... ) - >>> client.sign(req) - >>> print(f"Signing completed!") - - """ - try: - if req.provider is None: - msg = "No signing provider specified in the request" - raise RuntimeError(msg) - elif req.provider.HasField("key"): - self._sign_with_key(req.record_ref, req.provider.key) - elif req.provider.HasField("oidc"): - self._sign_with_oidc(req.record_ref, req.provider.oidc) - except RuntimeError as e: - msg = f"Failed to sign the object: {e}" - raise RuntimeError(msg) from e - except Exception as e: - logger.exception("Signing operation failed: %s", e) - msg = f"Failed to sign the object: {e}" - raise RuntimeError(msg) from e - - def _sign_with_key( - self, - record_ref: core_v1.RecordRef, - key_signer: sign_v1.SignWithKey, - ) -> None: - """Sign a record using a private key. - - This private method handles key-based signing by passing the key reference - directly to the dirctl command. The key can be a file path, URL, or KMS URI. - - Args: - record_ref: Reference to the record to sign - key_signer: SignWithKey containing the private key reference and password - - Raises: - RuntimeError: If any error occurs during signing - - """ - try: - # Set up environment with password - # Always set COSIGN_PASSWORD (even if empty) to avoid terminal prompts - shell_env = os.environ.copy() - password = "" - if key_signer.password: - password = key_signer.password.decode("utf-8") - shell_env["COSIGN_PASSWORD"] = password - - # Build and execute the signing command - # The key reference is passed directly to dirctl which handles - # file paths, URLs, KMS URIs, etc. - command = [ - *self.config.get_dirctl(), - "sign", - record_ref.cid, - "--key", - key_signer.private_key, - ] - - subprocess.run( - command, - check=True, - capture_output=True, - env=shell_env, - timeout=60, # 1 minute timeout - ) - - except subprocess.CalledProcessError as e: - msg = f"dirctl signing failed with return code {e.returncode}: {e.stderr.decode('utf-8', errors='ignore')}" - raise RuntimeError(msg) from e - except subprocess.TimeoutExpired as e: - msg = "dirctl signing timed out" - raise RuntimeError(msg) from e - except Exception as e: - msg = f"Unexpected error during key-based signing: {e}" - raise RuntimeError(msg) from e - - def _sign_with_oidc( - self, - record_ref: core_v1.RecordRef, - oidc_signer: sign_v1.SignWithOIDC, - ) -> None: - """Sign a record using OIDC-based authentication. - - This private method handles OIDC-based signing by building the appropriate - dirctl command with OIDC parameters and executing it. - - Args: - req: SignRequest containing the record reference and OIDC provider - - Raises: - RuntimeError: If any other error occurs during signing - - """ - try: - shell_env = os.environ.copy() - - # Build base command - command = [*self.config.get_dirctl(), "sign", record_ref.cid] - - # Add OIDC-specific parameters - if oidc_signer.id_token: - command.extend(["--oidc-token", oidc_signer.id_token]) - if oidc_signer.options.oidc_provider_url: - command.extend(["--oidc-provider-url", oidc_signer.options.oidc_provider_url]) - if oidc_signer.options.oidc_client_id: - command.extend(["--oidc-client-id", oidc_signer.options.oidc_client_id]) - if oidc_signer.options.oidc_client_secret: - command.extend(["--oidc-client-secret", oidc_signer.options.oidc_client_secret]) - if oidc_signer.options.fulcio_url: - command.extend(["--fulcio-url", oidc_signer.options.fulcio_url]) - if oidc_signer.options.rekor_url: - command.extend(["--rekor-url", oidc_signer.options.rekor_url]) - if oidc_signer.options.timestamp_url: - command.extend(["--timestamp-url", oidc_signer.options.timestamp_url]) - if oidc_signer.options.skip_tlog: - command.append("--skip-tlog") - - # Execute the signing command - subprocess.run( - command, - check=True, - capture_output=True, - env=shell_env, - timeout=60, # 1 minute timeout - ) + return self.sign_service.verify(req, metadata=metadata) - except subprocess.CalledProcessError as e: - msg = f"dirctl signing failed with return code {e.returncode}: {e.stderr.decode('utf-8', errors='ignore')}" - raise RuntimeError(msg) from e - except subprocess.TimeoutExpired as e: - msg = "dirctl signing timed out" - raise RuntimeError(msg) from e - except Exception as e: - msg = f"Unexpected error during OIDC signing: {e}" - raise RuntimeError(msg) from e + def sign(self, req: sign_v1.SignRequest) -> None: + self.sign_service.sign(req) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/dirctl/__init__.py b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/__init__.py new file mode 100644 index 0000000..76d3a9b --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/__init__.py @@ -0,0 +1,9 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""dirctl command execution helpers.""" + +from agntcy.dir_sdk.client.dirctl.signing import sign_record +from agntcy.dir_sdk.client.dirctl.verification import verify_record + +__all__ = ["sign_record", "verify_record"] diff --git a/dir-sdk-python/agntcy/dir_sdk/client/dirctl/runner.py b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/runner.py new file mode 100644 index 0000000..e240032 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/runner.py @@ -0,0 +1,70 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Low-level helpers for invoking the dirctl CLI safely.""" + +from __future__ import annotations + +import os +import subprocess +from collections.abc import Sequence + +from agntcy.dir_sdk.client.config import Config, DockerConfig + + +def _copy_docker_config(docker_config: DockerConfig) -> DockerConfig: + return DockerConfig( + dirctl_image=docker_config.dirctl_image, + dirctl_image_tag=docker_config.dirctl_image_tag, + envs=dict(docker_config.envs), + mounts=list(docker_config.mounts), + user=docker_config.user, + ) + + +def build_dirctl_base_command( + config: Config, + env: dict[str, str] | None = None, + extra_mounts: Sequence[str] | None = None, +) -> list[str]: + extra_mounts = list(extra_mounts or []) + if config.dirctl_path: + return [config.dirctl_path] + if config.docker_config is None: + msg = "Either dirctl_path or docker_config must be configured" + raise RuntimeError(msg) + + docker_config = _copy_docker_config(config.docker_config) + docker_config.envs.update(env) + docker_config.mounts.extend(extra_mounts) + return docker_config.get_commands() + + +def run_dirctl( + config: Config, + args: Sequence[str], + *, + env: dict[str, str] | None = None, + timeout: int = 60, + extra_mounts: Sequence[str] | None = None, +) -> None: + command = [*build_dirctl_base_command(config, env=env, extra_mounts=extra_mounts), *args] + shell_env = os.environ.copy() + if env: + shell_env.update(env) + + try: + subprocess.run( + command, + check=True, + capture_output=True, + env=shell_env, + timeout=timeout, + ) + except subprocess.CalledProcessError as e: + stderr = e.stderr.decode("utf-8", errors="ignore") + msg = f"dirctl command failed with return code {e.returncode}: {stderr}" + raise RuntimeError(msg) from e + except subprocess.TimeoutExpired as e: + msg = "dirctl command timed out" + raise RuntimeError(msg) from e diff --git a/dir-sdk-python/agntcy/dir_sdk/client/dirctl/signing.py b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/signing.py new file mode 100644 index 0000000..0d1dfd1 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/signing.py @@ -0,0 +1,64 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Signing helpers for dirctl-backed signature creation.""" + +from __future__ import annotations + +from agntcy.dir_sdk.client.config import Config +from agntcy.dir_sdk.client.dirctl.runner import run_dirctl +from agntcy.dir_sdk.models import core_v1, sign_v1 + + +def sign_record(config: Config, req: sign_v1.SignRequest) -> None: + if req.provider.HasField("key"): + _sign_with_key(config, req.record_ref, req.provider.key) + return + elif req.provider.HasField("oidc"): + _sign_with_oidc(config, req.record_ref, req.provider.oidc) + return + else: + msg = "Unsupported signing provider in request" + raise RuntimeError(msg) + + +def _sign_with_key( + config: Config, + record_ref: core_v1.RecordRef, + key_signer: sign_v1.SignWithKey, +) -> None: + password = "" + if key_signer.password: + password = key_signer.password.decode("utf-8") + run_dirctl( + config, + ["sign", record_ref.cid, "--key", key_signer.private_key], + env={"COSIGN_PASSWORD": password, + "DIRECTORY_CLIENT_SERVER_ADDRESS": config.server_address}, + ) + + +def _sign_with_oidc( + config: Config, + record_ref: core_v1.RecordRef, + oidc_signer: sign_v1.SignWithOIDC, +) -> None: + command = ["sign", record_ref.cid] + if oidc_signer.id_token: + command.extend(["--oidc-token", oidc_signer.id_token]) + if oidc_signer.options.oidc_provider_url: + command.extend(["--oidc-provider-url", oidc_signer.options.oidc_provider_url]) + if oidc_signer.options.oidc_client_id: + command.extend(["--oidc-client-id", oidc_signer.options.oidc_client_id]) + if oidc_signer.options.oidc_client_secret: + command.extend(["--oidc-client-secret", oidc_signer.options.oidc_client_secret]) + if oidc_signer.options.fulcio_url: + command.extend(["--fulcio-url", oidc_signer.options.fulcio_url]) + if oidc_signer.options.rekor_url: + command.extend(["--rekor-url", oidc_signer.options.rekor_url]) + if oidc_signer.options.timestamp_url: + command.extend(["--timestamp-url", oidc_signer.options.timestamp_url]) + if oidc_signer.options.skip_tlog: + command.append("--skip-tlog") + + run_dirctl(config, command, env={"DIRECTORY_CLIENT_SERVER_ADDRESS": config.server_address}) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/dirctl/verification.py b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/verification.py new file mode 100644 index 0000000..6811404 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/dirctl/verification.py @@ -0,0 +1,138 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Verification helpers for dirctl-backed signature validation.""" + +from __future__ import annotations + +import json +import os +import tempfile + +from google.protobuf import json_format + +from agntcy.dir_sdk.client.config import Config +from agntcy.dir_sdk.client.dirctl.runner import run_dirctl +from agntcy.dir_sdk.models import core_v1, sign_v1 + + +def verify_record(config: Config, req: sign_v1.VerifyRequest) -> sign_v1.VerifyResponse: + if req.record_ref is None or not req.record_ref.cid: + msg = "VerifyRequest.record_ref with cid is required" + raise RuntimeError(msg) + + fd, output_path = tempfile.mkstemp(suffix=".json", prefix="dirctl-verify-") + os.close(fd) + try: + _run_verify(config, req, output_path) + return parse_verify_response(output_path) + finally: + try: + os.unlink(output_path) + except OSError: + pass + + +def _run_verify(config: Config, req: sign_v1.VerifyRequest, output_path: str) -> None: + extra_mounts: list[str] = [] + effective_output_path = output_path + if config.docker_config: + basename = os.path.basename(output_path) + extra_mounts.append(f"type=bind,src={output_path},dst=/{basename}") + effective_output_path = f"/{basename}" + + provider = req.provider + + if provider.HasField("key"): + _verify_with_key(config, req.record_ref, provider.key, effective_output_path, extra_mounts=extra_mounts) + elif provider.HasField("oidc"): + _verify_with_oidc(config, req.record_ref, provider.oidc, effective_output_path, extra_mounts=extra_mounts) + elif provider.HasField("any"): + _verify_with_any(config, req.record_ref, provider.any, effective_output_path, extra_mounts=extra_mounts) + else: + msg = "Unsupported verification provider in request" + raise RuntimeError(msg) + + +def _verify_with_key( + config: Config, + record_ref: core_v1.RecordRef, + key_verifier: sign_v1.VerifyWithKey, + output_path: str, + *, + extra_mounts: list[str], +) -> None: + run_dirctl( + config, + ["verify", record_ref.cid, "--key", key_verifier.public_key, "--output-file", output_path], + extra_mounts=extra_mounts, + env={"DIRECTORY_CLIENT_SERVER_ADDRESS": config.server_address}, + ) + + +def _verify_with_any( + config: Config, + record_ref: core_v1.RecordRef, + any_verifier: sign_v1.VerifyWithAny | None, + output_path: str, + *, + extra_mounts: list[str], +) -> None: + command = ["verify", record_ref.cid, "--output-file", output_path] + if any_verifier is not None and any_verifier.HasField("oidc_options"): + opts = any_verifier.oidc_options + if opts.tuf_mirror_url: + command.extend(["--tuf-mirror-url", opts.tuf_mirror_url]) + if opts.trusted_root_path: + command.extend(["--trusted-root-path", opts.trusted_root_path]) + if opts.ignore_tlog: + command.append("--ignore-tlog") + if opts.ignore_tsa: + command.append("--ignore-tsa") + if opts.ignore_sct: + command.append("--ignore-sct") + run_dirctl(config, command, extra_mounts=extra_mounts, env={"DIRECTORY_CLIENT_SERVER_ADDRESS": config.server_address}) + + +def _verify_with_oidc( + config: Config, + record_ref: core_v1.RecordRef, + oidc_verifier: sign_v1.VerifyWithOIDC | None, + output_path: str, + *, + extra_mounts: list[str], +) -> None: + command = ["verify", record_ref.cid, "--output-file", output_path] + if oidc_verifier is not None: + if oidc_verifier.issuer: + command.extend(["--oidc-issuer", oidc_verifier.issuer]) + if oidc_verifier.subject: + command.extend(["--oidc-subject", oidc_verifier.subject]) + if oidc_verifier.HasField("options"): + opts = oidc_verifier.options + if opts.tuf_mirror_url: + command.extend(["--tuf-mirror-url", opts.tuf_mirror_url]) + if opts.trusted_root_path: + command.extend(["--trusted-root-path", opts.trusted_root_path]) + if opts.ignore_tlog: + command.append("--ignore-tlog") + if opts.ignore_tsa: + command.append("--ignore-tsa") + if opts.ignore_sct: + command.append("--ignore-sct") + run_dirctl(config, command, extra_mounts=extra_mounts, env={"DIRECTORY_CLIENT_SERVER_ADDRESS": config.server_address}) + + +def parse_verify_response(output_path: str) -> sign_v1.VerifyResponse: + try: + with open(output_path, "rb") as f: + output = f.read().decode("utf-8") + + json_data = json.loads(output) + response = sign_v1.VerifyResponse() + json_format.ParseDict(json_data, response) + + return response + except (json.JSONDecodeError, UnicodeDecodeError) as e: + msg = f"Failed to parse verification response: {e}" + raise RuntimeError(msg) from e diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/__init__.py b/dir-sdk-python/agntcy/dir_sdk/client/services/__init__.py new file mode 100644 index 0000000..9687f55 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/__init__.py @@ -0,0 +1,24 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Service-layer wrappers around generated gRPC stubs.""" + +from agntcy.dir_sdk.client.services.events import EventService +from agntcy.dir_sdk.client.services.naming import NamingService +from agntcy.dir_sdk.client.services.publication import PublicationService +from agntcy.dir_sdk.client.services.routing import RoutingService +from agntcy.dir_sdk.client.services.search import SearchService +from agntcy.dir_sdk.client.services.signing import SignService +from agntcy.dir_sdk.client.services.store import StoreService +from agntcy.dir_sdk.client.services.sync import SyncService + +__all__ = [ + "EventService", + "NamingService", + "PublicationService", + "RoutingService", + "SearchService", + "SignService", + "StoreService", + "SyncService", +] diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/base.py b/dir-sdk-python/agntcy/dir_sdk/client/services/base.py new file mode 100644 index 0000000..f42acc5 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/base.py @@ -0,0 +1,42 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Shared service-layer helpers for RPC invocation and error mapping.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterable +from typing import TypeVar + +import grpc + +T = TypeVar("T") + + +class RpcServiceBase: + def __init__(self, logger: logging.Logger) -> None: + self._logger = logger + + def _invoke(self, op_name: str, error_message: str, call: Callable[[], T]) -> T: + try: + return call() + except grpc.RpcError as e: + self._logger.exception("gRPC error during %s: %s", op_name, e) + raise + except Exception as e: + self._logger.exception("Unexpected error during %s: %s", op_name, e) + msg = f"{error_message}: {e}" + raise RuntimeError(msg) from e + + def _collect_stream( + self, + op_name: str, + error_message: str, + stream_call: Callable[[], Iterable[T]], + ) -> list[T]: + return self._invoke( + op_name, + error_message, + lambda: list(stream_call()), + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/events.py b/dir-sdk-python/agntcy/dir_sdk/client/services/events.py new file mode 100644 index 0000000..5c9c454 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/events.py @@ -0,0 +1,37 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Events service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import grpc + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import events_v1 + + +class EventService(RpcServiceBase): + def __init__(self, event_client: events_v1.EventServiceStub, logger) -> None: + super().__init__(logger) + self._event_client = event_client + + def listen( + self, + req: events_v1.ListenRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> grpc.UnaryStreamMultiCallable: + try: + return self._event_client.Listen(req, metadata=metadata) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.CANCELLED: + self._logger.exception("gRPC listen stream was canceled: %s", e) + else: + self._logger.exception("gRPC error during listen: %s", e) + raise + except Exception as e: + self._logger.exception("Unexpected error during listen: %s", e) + msg = f"Failed to listen: {e}" + raise RuntimeError(msg) from e diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/naming.py b/dir-sdk-python/agntcy/dir_sdk/client/services/naming.py new file mode 100644 index 0000000..5b7cb1e --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/naming.py @@ -0,0 +1,54 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Naming service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import naming_v1 + + +class NamingService(RpcServiceBase): + def __init__(self, naming_client: naming_v1.NamingServiceStub, logger) -> None: + super().__init__(logger) + self._naming_client = naming_client + + def resolve( + self, + name: str, + version: str | None = None, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> naming_v1.ResolveResponse: + def call(): + req = naming_v1.ResolveRequest(name=name) + if version: + req.version = version + return self._naming_client.Resolve(req, metadata=metadata) + + return self._invoke("resolve", "Failed to resolve name", call) + + def get_verification_info( + self, + cid: str | None = None, + name: str | None = None, + version: str | None = None, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> naming_v1.GetVerificationInfoResponse: + def call(): + req = naming_v1.GetVerificationInfoRequest() + if cid: + req.cid = cid + if name: + req.name = name + if version: + req.version = version + return self._naming_client.GetVerificationInfo(req, metadata=metadata) + + return self._invoke( + "get_verification_info", + "Failed to get verification info", + call, + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/publication.py b/dir-sdk-python/agntcy/dir_sdk/client/services/publication.py new file mode 100644 index 0000000..a1bbf38 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/publication.py @@ -0,0 +1,54 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Publication service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import routing_v1 + + +class PublicationService(RpcServiceBase): + def __init__( + self, + publication_client: routing_v1.PublicationServiceStub, + logger, + ) -> None: + super().__init__(logger) + self._publication_client = publication_client + + def create_publication( + self, + req: routing_v1.PublishRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> routing_v1.CreatePublicationResponse: + return self._invoke( + "create_publication", + "Failed to create publication", + lambda: self._publication_client.CreatePublication(req, metadata=metadata), + ) + + def get_publication( + self, + req: routing_v1.GetPublicationRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> routing_v1.GetPublicationResponse: + return self._invoke( + "get_publication", + "Failed to get publication", + lambda: self._publication_client.GetPublication(req, metadata=metadata), + ) + + def list_publication( + self, + req: routing_v1.ListPublicationsRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> list[routing_v1.ListPublicationsItem]: + return self._collect_stream( + "list_publication", + "Failed to list publication", + lambda: self._publication_client.ListPublications(req, metadata=metadata), + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/routing.py b/dir-sdk-python/agntcy/dir_sdk/client/services/routing.py new file mode 100644 index 0000000..baa4714 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/routing.py @@ -0,0 +1,50 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Routing service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import routing_v1 + + +class RoutingService(RpcServiceBase): + def __init__(self, routing_client: routing_v1.RoutingServiceStub, logger) -> None: + super().__init__(logger) + self._routing_client = routing_client + + def publish( + self, + req: routing_v1.PublishRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> None: + self._invoke( + "publish", + "Failed to publish object", + lambda: self._routing_client.Publish(req, metadata=metadata), + ) + + def list( + self, + req: routing_v1.ListRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> list[routing_v1.ListResponse]: + return self._collect_stream( + "list", + "Failed to list objects", + lambda: self._routing_client.List(req, metadata=metadata), + ) + + def unpublish( + self, + req: routing_v1.UnpublishRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> None: + self._invoke( + "unpublish", + "Failed to unpublish object", + lambda: self._routing_client.Unpublish(req, metadata=metadata), + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/search.py b/dir-sdk-python/agntcy/dir_sdk/client/services/search.py new file mode 100644 index 0000000..5bc88e1 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/search.py @@ -0,0 +1,39 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Search service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import search_v1 + + +class SearchService(RpcServiceBase): + def __init__(self, search_client: search_v1.SearchServiceStub, logger) -> None: + super().__init__(logger) + self._search_client = search_client + + def search_cids( + self, + req: search_v1.SearchCIDsRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> list[search_v1.SearchCIDsResponse]: + return self._collect_stream( + "search", + "Failed to search CIDs", + lambda: self._search_client.SearchCIDs(req, metadata=metadata), + ) + + def search_records( + self, + req: search_v1.SearchRecordsRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> list[search_v1.SearchRecordsResponse]: + return self._collect_stream( + "search", + "Failed to search records", + lambda: self._search_client.SearchRecords(req, metadata=metadata), + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/signing.py b/dir-sdk-python/agntcy/dir_sdk/client/services/signing.py new file mode 100644 index 0000000..1841cfe --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/signing.py @@ -0,0 +1,55 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Sign/verify service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import grpc + +from agntcy.dir_sdk.client.config import Config +from agntcy.dir_sdk.client.dirctl.signing import sign_record +from agntcy.dir_sdk.client.dirctl.verification import verify_record +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import sign_v1 + + +class SignService(RpcServiceBase): + def __init__(self, config: Config, sign_client: sign_v1.SignServiceStub, logger) -> None: + super().__init__(logger) + self._config = config + self._sign_client = sign_client + + def verify( + self, + req: sign_v1.VerifyRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> sign_v1.VerifyResponse: + if req.from_server: + if req.record_ref is None or not req.record_ref.cid: + msg = "VerifyRequest.record_ref with cid is required" + raise RuntimeError(msg) + try: + return self._sign_client.Verify(req, metadata=metadata or ()) + except grpc.RpcError as e: + self._logger.exception("gRPC error during verify: %s", e) + raise RuntimeError(f"Verify failed: {e}") from e + except Exception as e: + self._logger.exception("Verification failed: %s", e) + raise RuntimeError(f"Verify failed: {e}") from e + try: + return verify_record(self._config, req) + except Exception as e: + self._logger.exception("Verification operation failed: %s", e) + raise RuntimeError(f"Failed to verify the object: {e}") from e + + def sign(self, req: sign_v1.SignRequest) -> None: + try: + sign_record(self._config, req) + except RuntimeError as e: + raise RuntimeError(f"Failed to sign the object: {e}") from e + except Exception as e: + self._logger.exception("Signing operation failed: %s", e) + raise RuntimeError(f"Failed to sign the object: {e}") from e diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/store.py b/dir-sdk-python/agntcy/dir_sdk/client/services/store.py new file mode 100644 index 0000000..addc58c --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/store.py @@ -0,0 +1,84 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Store service wrappers.""" + +from __future__ import annotations + +import builtins +from collections.abc import Sequence + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import core_v1, store_v1 + + +class StoreService(RpcServiceBase): + def __init__(self, store_client: store_v1.StoreServiceStub, logger) -> None: + super().__init__(logger) + self._store_client = store_client + + def push( + self, + records: builtins.list[core_v1.Record], + metadata: Sequence[tuple[str, str]] | None = None, + ) -> builtins.list[core_v1.RecordRef]: + return self._collect_stream( + "push", + "Failed to push object", + lambda: self._store_client.Push(iter(records), metadata=metadata), + ) + + def push_referrer( + self, + req: builtins.list[store_v1.PushReferrerRequest], + metadata: Sequence[tuple[str, str]] | None = None, + ) -> builtins.list[store_v1.PushReferrerResponse]: + return self._collect_stream( + "push_referrer", + "Failed to push object", + lambda: self._store_client.PushReferrer(iter(req), metadata=metadata), + ) + + def pull( + self, + refs: builtins.list[core_v1.RecordRef], + metadata: Sequence[tuple[str, str]] | None = None, + ) -> builtins.list[core_v1.Record]: + return self._collect_stream( + "pull", + "Failed to pull object", + lambda: self._store_client.Pull(iter(refs), metadata=metadata), + ) + + def pull_referrer( + self, + req: builtins.list[store_v1.PullReferrerRequest], + metadata: Sequence[tuple[str, str]] | None = None, + ) -> builtins.list[store_v1.PullReferrerResponse]: + return self._collect_stream( + "pull_referrer", + "Failed to pull referrer object", + lambda: self._store_client.PullReferrer(iter(req), metadata=metadata), + ) + + def lookup( + self, + refs: builtins.list[core_v1.RecordRef], + metadata: Sequence[tuple[str, str]] | None = None, + ) -> builtins.list[core_v1.RecordMeta]: + return self._collect_stream( + "lookup", + "Failed to lookup object", + lambda: self._store_client.Lookup(iter(refs), metadata=metadata), + ) + + def delete( + self, + refs: builtins.list[core_v1.RecordRef], + metadata: Sequence[tuple[str, str]] | None = None, + ) -> None: + self._invoke( + "delete", + "Failed to delete object", + lambda: self._store_client.Delete(iter(refs), metadata=metadata), + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/services/sync.py b/dir-sdk-python/agntcy/dir_sdk/client/services/sync.py new file mode 100644 index 0000000..353af51 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/services/sync.py @@ -0,0 +1,61 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Sync service wrappers.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from agntcy.dir_sdk.client.services.base import RpcServiceBase +from agntcy.dir_sdk.models import store_v1 + + +class SyncService(RpcServiceBase): + def __init__(self, sync_client: store_v1.SyncServiceStub, logger) -> None: + super().__init__(logger) + self._sync_client = sync_client + + def create_sync( + self, + req: store_v1.CreateSyncRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> store_v1.CreateSyncResponse: + return self._invoke( + "create_sync", + "Failed to create sync", + lambda: self._sync_client.CreateSync(req, metadata=metadata), + ) + + def list_syncs( + self, + req: store_v1.ListSyncsRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> list[store_v1.ListSyncsItem]: + return self._collect_stream( + "list_syncs", + "Failed to list syncs", + lambda: self._sync_client.ListSyncs(req, metadata=metadata), + ) + + def get_sync( + self, + req: store_v1.GetSyncRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> store_v1.GetSyncResponse: + return self._invoke( + "get_sync", + "Failed to get sync", + lambda: self._sync_client.GetSync(req, metadata=metadata), + ) + + def delete_sync( + self, + req: store_v1.DeleteSyncRequest, + metadata: Sequence[tuple[str, str]] | None = None, + ) -> None: + self._invoke( + "delete_sync", + "Failed to delete sync", + lambda: self._sync_client.DeleteSync(req, metadata=metadata), + ) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/transport/__init__.py b/dir-sdk-python/agntcy/dir_sdk/client/transport/__init__.py new file mode 100644 index 0000000..37bcdad --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/transport/__init__.py @@ -0,0 +1,16 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Transport layer for gRPC channel and interceptors.""" + +from agntcy.dir_sdk.client.transport.channels import create_grpc_channel +from agntcy.dir_sdk.client.transport.interceptors import ( + BearerAuthInterceptor, + JWTAuthInterceptor, +) + +__all__ = [ + "BearerAuthInterceptor", + "JWTAuthInterceptor", + "create_grpc_channel", +] diff --git a/dir-sdk-python/agntcy/dir_sdk/client/transport/channels.py b/dir-sdk-python/agntcy/dir_sdk/client/transport/channels.py new file mode 100644 index 0000000..f128ad0 --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/transport/channels.py @@ -0,0 +1,179 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""gRPC channel factory helpers.""" + +from __future__ import annotations + +from pathlib import Path + +import grpc +from cryptography.hazmat.primitives import serialization +from spiffe import WorkloadApiClient, X509Source + +from agntcy.dir_sdk.client.config import Config +from agntcy.dir_sdk.client.auth.oauth_pkce import OAuthTokenHolder +from agntcy.dir_sdk.client.transport.interceptors import ( + BearerAuthInterceptor, + JWTAuthInterceptor, +) + + +def grpc_channel_options(config: Config) -> list[tuple[str, str]]: + server_name = config.tls_server_name.strip() + if not server_name: + return [] + return [ + ("grpc.ssl_target_name_override", server_name), + ("grpc.default_authority", server_name), + ] + + +def create_grpc_channel( + config: Config, + oauth_holder: OAuthTokenHolder | None = None, +) -> grpc.Channel: + if config.auth_mode == "": + return grpc.insecure_channel(config.server_address) + if config.auth_mode == "jwt": + return create_jwt_channel(config) + if config.auth_mode == "x509": + return create_x509_channel(config) + if config.auth_mode == "tls": + return create_tls_channel(config) + if config.auth_mode == "oidc": + return create_oauth_pkce_channel(config, oauth_holder) + msg = f"Unsupported auth mode: {config.auth_mode}" + raise ValueError(msg) + + +def create_x509_channel(config: Config) -> grpc.Channel: + if config.spiffe_socket_path == "": + msg = "SPIFFE socket path is required for X.509 authentication" + raise ValueError(msg) + + workload_client = WorkloadApiClient(socket_path=config.spiffe_socket_path) + x509_src = X509Source( + workload_api_client=workload_client, + socket_path=config.spiffe_socket_path, + timeout_in_seconds=60, + ) + + root_ca = b"" + for bundle in x509_src.bundles: + for authority in bundle.x509_authorities: + root_ca += authority.public_bytes(encoding=serialization.Encoding.PEM) + + private_key = x509_src.svid.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_leaf = x509_src.svid.leaf.public_bytes(encoding=serialization.Encoding.PEM) + + credentials = grpc.ssl_channel_credentials( + root_certificates=root_ca, + private_key=private_key, + certificate_chain=public_leaf, + ) + return grpc.secure_channel( + target=config.server_address, + credentials=credentials, + options=grpc_channel_options(config), + ) + + +def create_jwt_channel(config: Config) -> grpc.Channel: + if config.spiffe_socket_path == "": + msg = "SPIFFE socket path is required for JWT authentication" + raise ValueError(msg) + if config.jwt_audience == "": + msg = "JWT audience is required for JWT authentication" + raise ValueError(msg) + + workload_client = WorkloadApiClient(socket_path=config.spiffe_socket_path) + x509_source = X509Source( + workload_api_client=workload_client, + socket_path=config.spiffe_socket_path, + timeout_in_seconds=60, + ) + try: + root_ca = b"" + for bundle in x509_source.bundles: + for authority in bundle.x509_authorities: + root_ca += authority.public_bytes(encoding=serialization.Encoding.PEM) + if not root_ca: + msg = "Failed to fetch X.509 bundle from SPIRE: no bundles returned" + raise RuntimeError(msg) + + credentials = grpc.ssl_channel_credentials(root_certificates=root_ca) + channel = grpc.secure_channel( + target=config.server_address, + credentials=credentials, + options=grpc_channel_options(config), + ) + finally: + x509_source.close() + + jwt_interceptor = JWTAuthInterceptor( + socket_path=config.spiffe_socket_path, + audience=config.jwt_audience, + ) + return grpc.intercept_channel(channel, jwt_interceptor) + + +def create_tls_channel(config: Config) -> grpc.Channel: + if not config.tls_ca_file: + msg = "TLS CA file is required for TLS authentication" + raise ValueError(msg) + if not config.tls_cert_file: + msg = "TLS certificate file is required for TLS authentication" + raise ValueError(msg) + if not config.tls_key_file: + msg = "TLS key file is required for TLS authentication" + raise ValueError(msg) + + try: + root_ca = Path(config.tls_ca_file).read_bytes() + cert_chain = Path(config.tls_cert_file).read_bytes() + private_key = Path(config.tls_key_file).read_bytes() + except OSError as e: + msg = f"Failed to read TLS files: {e}" + raise RuntimeError(msg) from e + + credentials = grpc.ssl_channel_credentials( + root_certificates=root_ca, + private_key=private_key, + certificate_chain=cert_chain, + ) + return grpc.secure_channel( + target=config.server_address, + credentials=credentials, + options=grpc_channel_options(config), + ) + + +def create_oauth_pkce_channel( + config: Config, + oauth_holder: OAuthTokenHolder | None, +) -> grpc.Channel: + if oauth_holder is None: + msg = "OAuth token holder not initialized" + raise RuntimeError(msg) + + root_ca = None + if config.tls_ca_file: + try: + root_ca = Path(config.tls_ca_file).read_bytes() + except OSError as e: + msg = f"Failed to read TLS CA file: {e}" + raise RuntimeError(msg) from e + + credentials = grpc.ssl_channel_credentials(root_certificates=root_ca) + channel = grpc.secure_channel( + target=config.server_address, + credentials=credentials, + options=grpc_channel_options(config), + ) + bearer = BearerAuthInterceptor(oauth_holder.get_access_token) + return grpc.intercept_channel(channel, bearer) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/transport/interceptors.py b/dir-sdk-python/agntcy/dir_sdk/client/transport/interceptors.py new file mode 100644 index 0000000..300554f --- /dev/null +++ b/dir-sdk-python/agntcy/dir_sdk/client/transport/interceptors.py @@ -0,0 +1,92 @@ +# Copyright AGNTCY Contributors (https://github.com/agntcy) +# SPDX-License-Identifier: Apache-2.0 + +"""Authentication interceptors for gRPC client channels.""" + +from __future__ import annotations + +from collections.abc import Callable + +import grpc +from spiffe import WorkloadApiClient + + +def _build_call_details(client_call_details, metadata: list[tuple[str, str]]): + return grpc._interceptor._ClientCallDetails( + method=client_call_details.method, + timeout=client_call_details.timeout, + metadata=metadata, + credentials=client_call_details.credentials, + wait_for_ready=client_call_details.wait_for_ready, + compression=client_call_details.compression, + ) + + +class JWTAuthInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + """Add SPIFFE JWT-SVID authorization metadata to outgoing requests.""" + + def __init__(self, socket_path: str, audience: str) -> None: + self._audience = audience + self._workload_client = WorkloadApiClient(socket_path=socket_path) + + def _get_jwt_token(self) -> str: + try: + jwt_svid = self._workload_client.fetch_jwt_svid(audience=[self._audience]) + if jwt_svid and jwt_svid.token: + return jwt_svid.token + msg = "Failed to fetch JWT-SVID: empty token" + raise RuntimeError(msg) + except Exception as e: + msg = f"Failed to fetch JWT-SVID: {e}" + raise RuntimeError(msg) from e + + def _add_jwt_metadata(self, client_call_details): + metadata = list(client_call_details.metadata or []) + metadata.append(("authorization", f"Bearer {self._get_jwt_token()}")) + return _build_call_details(client_call_details, metadata) + + def intercept_unary_unary(self, continuation, client_call_details, request): + return continuation(self._add_jwt_metadata(client_call_details), request) + + def intercept_unary_stream(self, continuation, client_call_details, request): + return continuation(self._add_jwt_metadata(client_call_details), request) + + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + return continuation(self._add_jwt_metadata(client_call_details), request_iterator) + + def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + return continuation(self._add_jwt_metadata(client_call_details), request_iterator) + + +class BearerAuthInterceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + """Add static bearer authorization metadata to outgoing requests.""" + + def __init__(self, token_supplier: Callable[[], str]) -> None: + self._token_supplier = token_supplier + + def _add_bearer_metadata(self, client_call_details): + metadata = list(client_call_details.metadata or []) + metadata.append(("authorization", f"Bearer {self._token_supplier()}")) + return _build_call_details(client_call_details, metadata) + + def intercept_unary_unary(self, continuation, client_call_details, request): + return continuation(self._add_bearer_metadata(client_call_details), request) + + def intercept_unary_stream(self, continuation, client_call_details, request): + return continuation(self._add_bearer_metadata(client_call_details), request) + + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + return continuation(self._add_bearer_metadata(client_call_details), request_iterator) + + def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + return continuation(self._add_bearer_metadata(client_call_details), request_iterator) diff --git a/dir-sdk-python/agntcy/dir_sdk/client/test_client.py b/dir-sdk-python/agntcy/dir_sdk/tests/test_client.py similarity index 98% rename from dir-sdk-python/agntcy/dir_sdk/client/test_client.py rename to dir-sdk-python/agntcy/dir_sdk/tests/test_client.py index 40582bb..11bd104 100644 --- a/dir-sdk-python/agntcy/dir_sdk/client/test_client.py +++ b/dir-sdk-python/agntcy/dir_sdk/tests/test_client.py @@ -282,14 +282,16 @@ def test_sign_and_verify(self) -> None: provider_url = shell_env.get("OIDC_PROVIDER_URL", "") client_id = shell_env.get("OIDC_CLIENT_ID", "sigstore") - oidc_options = sign_v1.SignOptionsOIDC(oidc_provider_url=provider_url, oidc_client_id=client_id) - oidc_provider = sign_v1.SignWithOIDC(id_token=token, options=oidc_options) + sign_oidc_options = sign_v1.SignOptionsOIDC(oidc_provider_url=provider_url, oidc_client_id=client_id) + oidc_provider = sign_v1.SignWithOIDC(id_token=token, options=sign_oidc_options) request_oidc_provider = sign_v1.SignRequestProvider(oidc=oidc_provider) oidc_request = sign_v1.SignRequest( record_ref=record_refs[1], provider=request_oidc_provider, ) + verify_oidc_options = sign_v1.VerifyOptionsOIDC() + try: # Sign and verify using Key signing self.client.sign(key_request) @@ -305,7 +307,7 @@ def test_sign_and_verify(self) -> None: verify_index = 0 for ref in record_refs: - response = self.client.verify(sign_v1.VerifyRequest(record_ref=ref)) + response = self.client.verify(sign_v1.VerifyRequest(record_ref=ref, provider=sign_v1.VerifyRequestProvider(any=sign_v1.VerifyWithAny(oidc_options=verify_oidc_options)))) assert response.success is True diff --git a/dir-sdk-python/agntcy/dir_sdk/client/test_oidc_auth.py b/dir-sdk-python/agntcy/dir_sdk/tests/test_oidc_auth.py similarity index 79% rename from dir-sdk-python/agntcy/dir_sdk/client/test_oidc_auth.py rename to dir-sdk-python/agntcy/dir_sdk/tests/test_oidc_auth.py index 0b58c47..f755a9c 100644 --- a/dir-sdk-python/agntcy/dir_sdk/client/test_oidc_auth.py +++ b/dir-sdk-python/agntcy/dir_sdk/tests/test_oidc_auth.py @@ -12,8 +12,9 @@ from datetime import UTC, datetime, timedelta from agntcy.dir_sdk.client import Client, Config -from agntcy.dir_sdk.client.oauth_pkce import OAuthTokenHolder -from agntcy.dir_sdk.client.token_cache import TOKEN_CACHE_FILE, TokenCache +from agntcy.dir_sdk.client.auth.oauth_pkce import OAuthTokenHolder +from agntcy.dir_sdk.client.auth.token_cache import TOKEN_CACHE_FILE, TokenCache +from agntcy.dir_sdk.client.transport.channels import create_oauth_pkce_channel class OIDCAuthConfigTests(unittest.TestCase): @@ -71,19 +72,20 @@ def test_constructor_uses_preissued_token_without_pkce(self) -> None: with ( unittest.mock.patch( - "agntcy.dir_sdk.client.client.fetch_openid_configuration", + "agntcy.dir_sdk.client.auth.session.fetch_openid_configuration", ) as fetch_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.run_loopback_pkce_login", + "agntcy.dir_sdk.client.auth.session.run_loopback_pkce_login", ) as login_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.TokenCache.get_valid_token", + "agntcy.dir_sdk.client.auth.session.TokenCache.get_valid_token", return_value=None, ), ): client = Client(config) - self.assertEqual(client._oauth_holder.get_access_token(), "preissued-token") + self.assertTrue(client.has_cached_oauth_token()) + self.assertEqual(client.get_access_token(), "preissued-token") fetch_mock.assert_not_called() login_mock.assert_not_called() @@ -95,20 +97,19 @@ def test_constructor_without_token_does_not_start_pkce(self) -> None: with ( unittest.mock.patch( - "agntcy.dir_sdk.client.client.fetch_openid_configuration", + "agntcy.dir_sdk.client.auth.session.fetch_openid_configuration", ) as fetch_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.run_loopback_pkce_login", + "agntcy.dir_sdk.client.auth.session.run_loopback_pkce_login", ) as login_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.TokenCache.get_valid_token", + "agntcy.dir_sdk.client.auth.session.TokenCache.get_valid_token", return_value=None, ), ): client = Client(config) - with self.assertRaisesRegex(RuntimeError, "DIRECTORY_CLIENT_AUTH_TOKEN"): - client._oauth_holder.get_access_token() + self.assertFalse(client.has_cached_oauth_token()) fetch_mock.assert_not_called() login_mock.assert_not_called() @@ -139,15 +140,16 @@ def test_constructor_uses_cached_token_without_pkce(self) -> None: with ( unittest.mock.patch.dict("os.environ", {"XDG_CONFIG_HOME": tmp_dir}, clear=True), unittest.mock.patch( - "agntcy.dir_sdk.client.client.fetch_openid_configuration", + "agntcy.dir_sdk.client.auth.session.fetch_openid_configuration", ) as fetch_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.run_loopback_pkce_login", + "agntcy.dir_sdk.client.auth.session.run_loopback_pkce_login", ) as login_mock, ): client = Client(config) - self.assertEqual(client._oauth_holder.get_access_token(), "cached-token") + self.assertTrue(client.has_cached_oauth_token()) + self.assertEqual(client.get_access_token(), "cached-token") fetch_mock.assert_not_called() login_mock.assert_not_called() @@ -164,14 +166,14 @@ def test_authenticate_oauth_pkce_updates_access_token(self) -> None: with ( unittest.mock.patch( - "agntcy.dir_sdk.client.client.fetch_openid_configuration", + "agntcy.dir_sdk.client.auth.session.fetch_openid_configuration", return_value={ "authorization_endpoint": "https://issuer.example.com/auth", "token_endpoint": "https://issuer.example.com/token", }, ) as fetch_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.run_loopback_pkce_login", + "agntcy.dir_sdk.client.auth.session.run_loopback_pkce_login", return_value={ "access_token": "fresh-token", "refresh_token": "ignored-refresh-token", @@ -181,7 +183,7 @@ def test_authenticate_oauth_pkce_updates_access_token(self) -> None: ): client.authenticate_oauth_pkce() - self.assertEqual(client._oauth_holder.get_access_token(), "fresh-token") + self.assertEqual(client.get_access_token(), "fresh-token") fetch_mock.assert_called_once() login_mock.assert_called_once() @@ -198,14 +200,14 @@ def test_authenticate_oauth_pkce_saves_go_compatible_cache_entry(self) -> None: client = Client(config) with ( unittest.mock.patch( - "agntcy.dir_sdk.client.client.fetch_openid_configuration", + "agntcy.dir_sdk.client.auth.session.fetch_openid_configuration", return_value={ "authorization_endpoint": "https://issuer.example.com/auth", "token_endpoint": "https://issuer.example.com/token", }, ), unittest.mock.patch( - "agntcy.dir_sdk.client.client.run_loopback_pkce_login", + "agntcy.dir_sdk.client.auth.session.run_loopback_pkce_login", return_value={ "access_token": "fresh-token", "refresh_token": "refresh-token", @@ -228,35 +230,34 @@ def test_authenticate_oauth_pkce_saves_go_compatible_cache_entry(self) -> None: self.assertIsNotNone(cached_token.expires_at) def test_oauth_channel_uses_configured_tls_ca(self) -> None: - client = Client.__new__(Client) - client.config = Config( + config = Config( server_address="directory.example.com:443", auth_mode="oidc", tls_ca_file="", ) - client._oauth_holder = OAuthTokenHolder() - client._oauth_holder.set_tokens("token") + oauth_holder = OAuthTokenHolder() + oauth_holder.set_tokens("token") with tempfile.NamedTemporaryFile() as ca_file: ca_file.write(b"test-ca") ca_file.flush() - client.config.tls_ca_file = ca_file.name + config.tls_ca_file = ca_file.name with ( unittest.mock.patch( - "agntcy.dir_sdk.client.client.grpc.ssl_channel_credentials", + "agntcy.dir_sdk.client.transport.channels.grpc.ssl_channel_credentials", return_value="creds", ) as creds_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.grpc.secure_channel", + "agntcy.dir_sdk.client.transport.channels.grpc.secure_channel", return_value="channel", ) as secure_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.grpc.intercept_channel", + "agntcy.dir_sdk.client.transport.channels.grpc.intercept_channel", return_value="intercepted-channel", ) as intercept_mock, ): - channel = client._Client__create_oauth_pkce_channel() + channel = create_oauth_pkce_channel(config, oauth_holder) self.assertEqual(channel, "intercepted-channel") creds_mock.assert_called_once_with(root_certificates=b"test-ca") @@ -268,30 +269,29 @@ def test_oauth_channel_uses_configured_tls_ca(self) -> None: intercept_mock.assert_called_once() def test_oauth_channel_uses_tls_server_name_override(self) -> None: - client = Client.__new__(Client) - client.config = Config( + config = Config( server_address="directory.example.com:443", auth_mode="oidc", tls_server_name="override.example.com", ) - client._oauth_holder = OAuthTokenHolder() - client._oauth_holder.set_tokens("token") + oauth_holder = OAuthTokenHolder() + oauth_holder.set_tokens("token") with ( unittest.mock.patch( - "agntcy.dir_sdk.client.client.grpc.ssl_channel_credentials", + "agntcy.dir_sdk.client.transport.channels.grpc.ssl_channel_credentials", return_value="creds", ) as creds_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.grpc.secure_channel", + "agntcy.dir_sdk.client.transport.channels.grpc.secure_channel", return_value="channel", ) as secure_mock, unittest.mock.patch( - "agntcy.dir_sdk.client.client.grpc.intercept_channel", + "agntcy.dir_sdk.client.transport.channels.grpc.intercept_channel", return_value="intercepted-channel", ) as intercept_mock, ): - channel = client._Client__create_oauth_pkce_channel() + channel = create_oauth_pkce_channel(config, oauth_holder) self.assertEqual(channel, "intercepted-channel") creds_mock.assert_called_once_with(root_certificates=None) diff --git a/examples/example_interactive_oidc.py b/examples/example_interactive_oidc.py index 8f3ccd1..996b78f 100644 --- a/examples/example_interactive_oidc.py +++ b/examples/example_interactive_oidc.py @@ -64,14 +64,9 @@ def build_client() -> Client: ), ) client = Client(config) - holder = getattr(client, "_oauth_holder", None) - if holder is not None: - try: - holder.get_access_token() - print("Using cached OIDC token.") - return client - except RuntimeError: - pass + if client.has_cached_oauth_token(): + print("Using cached OIDC token.") + return client print("No cached OIDC token found. Starting interactive login.") client.authenticate_oauth_pkce() diff --git a/pyproject.toml b/pyproject.toml index c2e67c2..f6c3231 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ packages = ["dir-sdk-python/agntcy"] [tool.pytest.ini_options] pythonpath = ["dir-sdk-python"] +addopts = ["--import-mode=importlib",] [tool.ruff] line-length = 88