
Commit d649176

Merge pull request #985 from tisnik/lcore-1051-final-updates
LCORE-1051: final updates
2 parents 06686ba + 1b7ddb1; commit d649176

File tree

6 files changed: +49 -41 lines
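
Every hunk in this commit makes the same mechanical substitution: PEP 604 unions
written as "X | None" become typing.Optional[X]. On Python 3.10+ the two
spellings are interchangeable at runtime, so the change is purely notational;
a quick standard-library check (no project code involved):

from typing import Optional, Union

# Optional[X] is shorthand for Union[X, None] ...
assert Optional[str] == Union[str, None]
# ... and on Python 3.10+ it also compares equal to the PEP 604 spelling.
assert Optional[str] == (str | None)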

src/app/endpoints/a2a.py

Lines changed: 7 additions & 7 deletions
@@ -5,7 +5,7 @@
 import logging
 import uuid
 from datetime import datetime, timezone
-from typing import Annotated, Any, AsyncIterator, MutableMapping
+from typing import Annotated, Any, AsyncIterator, MutableMapping, Optional

 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from llama_stack.apis.agents.openai_responses import (
@@ -65,8 +65,8 @@
 # Task store and context store are created lazily based on configuration.
 # For multi-worker deployments, configure 'a2a_state' with 'sqlite' or 'postgres'
 # to share state across workers.
-_TASK_STORE: TaskStore | None = None
-_CONTEXT_STORE: A2AContextStore | None = None
+_TASK_STORE: Optional[TaskStore] = None
+_CONTEXT_STORE: Optional[A2AContextStore] = None


 async def _get_task_store() -> TaskStore:
@@ -120,7 +120,7 @@ class TaskResultAggregator:
     def __init__(self) -> None:
         """Initialize the task result aggregator with default state."""
         self._task_state: TaskState = TaskState.working
-        self._task_status_message: Message | None = None
+        self._task_status_message: Optional[Message] = None

     def process_event(
         self, event: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Any
@@ -169,7 +169,7 @@ def task_state(self) -> TaskState:
         return self._task_state

     @property
-    def task_status_message(self) -> Message | None:
+    def task_status_message(self) -> Optional[Message]:
         """Return the current task status message."""
         return self._task_status_message

@@ -185,7 +185,7 @@ class A2AAgentExecutor(AgentExecutor):
     """

     def __init__(
-        self, auth_token: str, mcp_headers: dict[str, dict[str, str]] | None = None
+        self, auth_token: str, mcp_headers: Optional[dict[str, dict[str, str]]] = None
     ):
         """Initialize the A2A agent executor.

@@ -413,7 +413,7 @@ async def _convert_stream_to_events(  # pylint: disable=too-many-branches,too-ma
     stream: AsyncIterator[OpenAIResponseObjectStream],
     task_id: str,
     context_id: str,
-    conversation_id: str | None,
+    conversation_id: Optional[str],
 ) -> AsyncIterator[Any]:
     """Convert Responses API stream chunks to A2A events.
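The _TASK_STORE and _CONTEXT_STORE globals start out as None and are created
on first access, per the comment in the hunk above. A minimal sketch of that
lazy-singleton pattern with the new annotations; the bare TaskStore placeholder
below is an assumption, not the repo's actual store wiring:

from typing import Optional

class TaskStore:
    """Stand-in for the real task store interface."""

_TASK_STORE: Optional[TaskStore] = None

async def _get_task_store() -> TaskStore:
    """Create the store on first use so configuration is read lazily."""
    global _TASK_STORE  # module-level singleton
    if _TASK_STORE is None:
        # The real code would choose sqlite/postgres here based on the
        # 'a2a_state' configuration mentioned in the comment above.
        _TASK_STORE = TaskStore()
    return _TASK_STORE
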
src/app/endpoints/query.py

Lines changed: 9 additions & 9 deletions
@@ -156,12 +156,12 @@ def persist_user_conversation_details(


 def evaluate_model_hints(
-    user_conversation: UserConversation | None,
+    user_conversation: Optional[UserConversation],
     query_request: QueryRequest,
-) -> tuple[str | None, str | None]:
+) -> tuple[Optional[str], Optional[str]]:
     """Evaluate model hints from user conversation."""
-    model_id: str | None = query_request.model
-    provider_id: str | None = query_request.provider
+    model_id: Optional[str] = query_request.model
+    provider_id: Optional[str] = query_request.provider

     if user_conversation is not None:
         if query_request.model is not None:
@@ -271,7 +271,7 @@ async def query_endpoint_handler_base(  # pylint: disable=R0914
     user_id, _, _skip_userid_check, token = auth

     started_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
-    user_conversation: UserConversation | None = None
+    user_conversation: Optional[UserConversation] = None
     if query_request.conversation_id:
         logger.debug(
             "Conversation ID specified in query: %s", query_request.conversation_id
@@ -483,7 +483,7 @@ async def query_endpoint_handler(


 def select_model_and_provider_id(
-    models: ModelListResponse, model_id: str | None, provider_id: str | None
+    models: ModelListResponse, model_id: Optional[str], provider_id: Optional[str]
 ) -> tuple[str, str, str]:
     """
     Select the model ID and provider ID based on the request or available models.
@@ -663,7 +663,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
     *,
     provider_id: str = "",
 ) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
@@ -859,7 +859,7 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:

 def get_rag_toolgroups(
     vector_db_ids: list[str],
-) -> list[Toolgroup] | None:
+) -> Optional[list[Toolgroup]]:
     """
     Return a list of RAG Tool groups if the given vector DB list is not empty.

@@ -870,7 +870,7 @@ def get_rag_toolgroups(
         vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup.

     Returns:
-        list[Toolgroup] | None: A list with a single RAG toolgroup if
+        Optional[list[Toolgroup]]: A list with a single RAG toolgroup if
         vector_db_ids is non-empty; otherwise, None.
     """
     return (
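
The evaluate_model_hints hunk shows the precedence rule: values from the
request win, and the stored conversation only fills gaps. A sketch of that
resolution logic; UserConversation's attribute names here are assumptions
for illustration:

from dataclasses import dataclass
from typing import Optional

@dataclass
class UserConversation:
    """Stand-in for the persisted conversation record."""
    last_used_model: str
    last_used_provider: str

def resolve_hints(
    conversation: Optional[UserConversation],
    requested_model: Optional[str],
    requested_provider: Optional[str],
) -> tuple[Optional[str], Optional[str]]:
    model = requested_model
    provider = requested_provider
    if conversation is not None:
        # Fall back to what the conversation used previously.
        model = model or conversation.last_used_model
        provider = provider or conversation.last_used_provider
    return model, provider

assert resolve_hints(None, None, None) == (None, None)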

src/app/endpoints/query_v2.py

Lines changed: 13 additions & 13 deletions
@@ -4,7 +4,7 @@

 import json
 import logging
-from typing import Annotated, Any, cast
+from typing import Annotated, Any, Optional, cast

 from fastapi import APIRouter, Depends, Request
 from llama_stack.apis.agents.openai_responses import (
@@ -74,7 +74,7 @@

 def _build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-many-branches
     output_item: Any,
-) -> tuple[ToolCallSummary | None, ToolResultSummary | None]:
+) -> tuple[Optional[ToolCallSummary], Optional[ToolResultSummary]]:
    """Translate applicable Responses API tool outputs into ``ToolCallSummary`` records.

    The OpenAI ``response.output`` array may contain any ``OpenAIResponseOutput`` variant:
@@ -110,7 +110,7 @@ def _build_tool_call_summary(  # pylint: disable=too-many-return-statements,too-
         "status": getattr(output_item, "status", None),
     }
     results = getattr(output_item, "results", None)
-    response_payload: Any | None = None
+    response_payload: Optional[Any] = None
     if results is not None:
         # Store only the essential result metadata to avoid large payloads
         response_payload = {
@@ -294,7 +294,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branche
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
     *,
     provider_id: str = "",
 ) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
@@ -505,7 +505,7 @@ def parse_referenced_documents_from_responses_api(
     """
     documents: list[ReferencedDocument] = []
     # Use a set to track unique documents by (doc_url, doc_title) tuple
-    seen_docs: set[tuple[str | None, str | None]] = set()
+    seen_docs: set[tuple[Optional[str], Optional[str]]] = set()

     if not response.output:
         return documents
@@ -535,7 +535,7 @@ def parse_referenced_documents_from_responses_api(

         # If we have at least a filename or url
         if filename or doc_url:
-            # Treat empty string as None for URL to satisfy AnyUrl | None
+            # Treat empty string as None for URL to satisfy Optional[AnyUrl]
             final_url = doc_url if doc_url else None
             if (final_url, filename) not in seen_docs:
                 documents.append(
@@ -692,15 +692,15 @@ def _increment_llm_call_metric(provider: str, model: str) -> None:
         logger.warning("Failed to update LLM call metric: %s", e)


-def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:
+def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
     """
     Convert vector store IDs to tools format for Responses API.

     Args:
         vector_store_ids: List of vector store identifiers

     Returns:
-        list[dict[str, Any]] | None: List containing file_search tool configuration,
+        Optional[list[dict[str, Any]]]: List containing file_search tool configuration,
         or None if no vector stores provided
     """
     if not vector_store_ids:
@@ -717,8 +717,8 @@ def get_rag_tools(vector_store_ids: list[str]) -> list[dict[str, Any]] | None:

 def get_mcp_tools(
     mcp_servers: list,
-    token: str | None = None,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    token: Optional[str] = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
 ) -> list[dict[str, Any]]:
     """
     Convert MCP servers to tools format for Responses API.
@@ -762,8 +762,8 @@ async def prepare_tools_for_responses_api(
     query_request: QueryRequest,
     token: str,
     config: AppConfig,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
-) -> list[dict[str, Any]] | None:
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
+) -> Optional[list[dict[str, Any]]]:
     """
     Prepare tools for Responses API including RAG and MCP tools.

@@ -778,7 +778,7 @@ async def prepare_tools_for_responses_api(
         mcp_headers: Per-request headers for MCP servers

     Returns:
-        list[dict[str, Any]] | None: List of tool configurations for the
+        Optional[list[dict[str, Any]]]: List of tool configurations for the
         Responses API, or None if no_tools is True or no tools are available
     """
     if query_request.no_tools:
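
Per its docstring, get_rag_tools returns either a single file_search tool or
None. A hedged reconstruction of the body (the diff truncates after the
"if not vector_store_ids:" guard; the payload keys below follow the OpenAI
Responses file_search schema and are an assumption):

from typing import Any, Optional

def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]:
    if not vector_store_ids:
        return None  # no vector stores -> no tools, as the docstring states
    return [{"type": "file_search", "vector_store_ids": vector_store_ids}]

assert get_rag_tools([]) is None
assert get_rag_tools(["vs_1"]) == [
    {"type": "file_search", "vector_store_ids": ["vs_1"]}
]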

src/app/endpoints/streaming_query.py

Lines changed: 14 additions & 6 deletions
@@ -7,7 +7,15 @@
 import uuid
 from collections.abc import Callable
 from datetime import UTC, datetime
-from typing import Annotated, Any, AsyncGenerator, AsyncIterator, Iterator, cast
+from typing import (
+    Annotated,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Iterator,
+    Optional,
+    cast,
+)

 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
@@ -231,7 +239,7 @@ def stream_build_event(
     chunk_id: int,
     metadata_map: dict,
     media_type: str = MEDIA_TYPE_JSON,
-    conversation_id: str | None = None,
+    conversation_id: Optional[str] = None,
 ) -> Iterator[str]:
     """Build a streaming event from a chunk response.

@@ -384,7 +392,7 @@ async def stream_http_error(error: AbstractErrorResponse) -> AsyncGenerator[str,
 def _handle_turn_start_event(
     _chunk_id: int,
     media_type: str = MEDIA_TYPE_JSON,
-    conversation_id: str | None = None,
+    conversation_id: Optional[str] = None,
 ) -> Iterator[str]:
     """
     Yield turn start event.
@@ -734,7 +742,7 @@ async def response_generator(
     # Send start event at the beginning of the stream
     yield stream_start_event(context.conversation_id)

-    latest_turn: Any | None = None
+    latest_turn: Optional[Any] = None

     async for chunk in turn_response:
         if chunk.event is None:
@@ -850,7 +858,7 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc

     user_id, _user_name, _skip_userid_check, token = auth

-    user_conversation: UserConversation | None = None
+    user_conversation: Optional[UserConversation] = None
     if query_request.conversation_id:
         user_conversation = validate_conversation_ownership(
             user_id=user_id, conversation_id=query_request.conversation_id
@@ -1001,7 +1009,7 @@ async def retrieve_response(
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
 ) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]:
     """
     Retrieve response from LLMs and agents.
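
Every retrieve_response variant defaults mcp_headers to None rather than {}.
That is standard Python hygiene rather than something this commit introduces:
a mutable default is built once at definition time and shared by all calls.
A minimal illustration:

from typing import Optional

def shared_default(headers: dict[str, str] = {}) -> dict[str, str]:
    return headers  # every call hands back the same dict object

def fresh_default(headers: Optional[dict[str, str]] = None) -> dict[str, str]:
    return {} if headers is None else headers  # new dict per call

assert shared_default() is shared_default()    # one shared object
assert fresh_default() is not fresh_default()  # independent objects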

src/app/endpoints/streaming_query_v2.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 """Streaming query handler using Responses API (v2)."""

 import logging
-from typing import Annotated, Any, AsyncIterator, cast
+from typing import Annotated, Any, AsyncIterator, Optional, cast

 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
@@ -138,7 +138,7 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
     start_event_emitted = False

     # Track the latest response object from response.completed event
-    latest_response_object: Any | None = None
+    latest_response_object: Optional[Any] = None

     logger.debug("Starting streaming response (Responses API) processing")

@@ -372,7 +372,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals
     model_id: str,
     query_request: QueryRequest,
     token: str,
-    mcp_headers: dict[str, dict[str, str]] | None = None,
+    mcp_headers: Optional[dict[str, dict[str, str]]] = None,
 ) -> tuple[AsyncIterator[OpenAIResponseObjectStream], str]:
     """
     Retrieve response from LLMs and agents.
@@ -471,7 +471,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals

 async def create_violation_stream(
     message: str,
-    shield_model: str | None = None,
+    shield_model: Optional[str] = None,
 ) -> AsyncIterator[OpenAIResponseObjectStream]:
     """Generate a minimal streaming response for cases where input is blocked by a shield.
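create_violation_stream swaps in a short synthetic stream when a shield blocks
the input. A minimal sketch of the shape such an async generator takes; the
plain-dict event below is illustrative only, not the actual
OpenAIResponseObjectStream type:

import asyncio
from typing import AsyncIterator, Optional

async def create_violation_stream(
    message: str,
    shield_model: Optional[str] = None,
) -> AsyncIterator[dict]:
    """Yield a single synthetic 'completed' event carrying the refusal."""
    yield {
        "type": "response.completed",
        "refusal": message,
        "shield_model": shield_model,
    }

async def main() -> None:
    async for event in create_violation_stream("Blocked by shield"):
        print(event["refusal"])

asyncio.run(main())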

src/utils/types.py

Lines changed: 2 additions & 2 deletions
@@ -104,8 +104,8 @@ class ShieldModerationResult(BaseModel):
     """Result of shield moderation check."""

     blocked: bool
-    message: str | None = None
-    shield_model: str | None = None
+    message: Optional[str] = None
+    shield_model: Optional[str] = None


 class ToolCallSummary(BaseModel):
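
Reassembled as a runnable snippet, the updated ShieldModerationResult behaves
exactly as before: Pydantic reads Optional[str] = None and str | None = None
as the same optional-field declaration (Pydantic v2 assumed here, though v1
treats them identically as well):

from typing import Optional
from pydantic import BaseModel

class ShieldModerationResult(BaseModel):
    """Result of shield moderation check."""

    blocked: bool
    message: Optional[str] = None
    shield_model: Optional[str] = None

result = ShieldModerationResult(blocked=True, message="flagged by shield")
assert result.shield_model is None  # omitted optional fields default to None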
