Skip to content

Commit 17f0029

Browse files
fix(stainless): handle [DONE] SSE terminator in streaming responses
1 parent f1a093b commit 17f0029

File tree

5 files changed

+112
-4
lines changed

5 files changed

+112
-4
lines changed

.stats.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
configured_endpoints: 108
22
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/llamastack%2Fllama-stack-client-ef864e1fd05e4fda60d0c67ea0f8d49db03efcfaddb0b28c561962a695510f6e.yml
33
openapi_spec_hash: fd0140251c983c3788c9da642426f1ba
4-
config_hash: 6aa61d4143c3e3df785972c0287d1370
4+
config_hash: ef1f9b33e203c71cfc10d91890c1ed2d

README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,50 @@ async def main() -> None:
128128
asyncio.run(main())
129129
```
130130

131+
## Streaming responses
132+
133+
We provide support for streaming responses using Server-Sent Events (SSE).
134+
135+
```python
136+
from llama_stack_client import LlamaStackClient
137+
138+
client = LlamaStackClient()
139+
140+
stream = client.chat.completions.create(
141+
messages=[
142+
{
143+
"content": "string",
144+
"role": "user",
145+
}
146+
],
147+
model="model",
148+
stream=True,
149+
)
150+
for completion in stream:
151+
print(completion.id)
152+
```
153+
154+
The async client uses the exact same interface.
155+
156+
```python
157+
from llama_stack_client import AsyncLlamaStackClient
158+
159+
client = AsyncLlamaStackClient()
160+
161+
stream = await client.chat.completions.create(
162+
messages=[
163+
{
164+
"content": "string",
165+
"role": "user",
166+
}
167+
],
168+
model="model",
169+
stream=True,
170+
)
171+
async for completion in stream:
172+
print(completion.id)
173+
```
174+
131175
## Using types
132176

133177
Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like:

src/llama_stack_client/_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def __init__(
158158
_strict_response_validation=_strict_response_validation,
159159
)
160160

161+
self._default_stream_cls = Stream
162+
161163
@cached_property
162164
def toolgroups(self) -> ToolgroupsResource:
163165
from .resources.toolgroups import ToolgroupsResource
@@ -515,6 +517,8 @@ def __init__(
515517
_strict_response_validation=_strict_response_validation,
516518
)
517519

520+
self._default_stream_cls = AsyncStream
521+
518522
@cached_property
519523
def toolgroups(self) -> AsyncToolgroupsResource:
520524
from .resources.toolgroups import AsyncToolgroupsResource

src/llama_stack_client/_streaming.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
import httpx
1717

18-
from ._utils import extract_type_var_from_base
18+
from ._utils import is_mapping, extract_type_var_from_base
19+
from ._exceptions import APIError
1920

2021
if TYPE_CHECKING:
2122
from ._client import LlamaStackClient, AsyncLlamaStackClient
@@ -65,7 +66,25 @@ def __stream__(self) -> Iterator[_T]:
6566

6667
try:
6768
for sse in iterator:
68-
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
69+
if sse.data.startswith("[DONE]"):
70+
break
71+
72+
data = sse.json()
73+
if is_mapping(data) and data.get("error"):
74+
message = None
75+
error = data.get("error")
76+
if is_mapping(error):
77+
message = error.get("message")
78+
if not message or not isinstance(message, str):
79+
message = "An error occurred during streaming"
80+
81+
raise APIError(
82+
message=message,
83+
request=self.response.request,
84+
body=data["error"],
85+
)
86+
87+
yield process_data(data=data, cast_to=cast_to, response=response)
6988
finally:
7089
# Ensure the response is closed even if the consumer doesn't read all data
7190
response.close()
@@ -131,7 +150,25 @@ async def __stream__(self) -> AsyncIterator[_T]:
131150

132151
try:
133152
async for sse in iterator:
134-
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
153+
if sse.data.startswith("[DONE]"):
154+
break
155+
156+
data = sse.json()
157+
if is_mapping(data) and data.get("error"):
158+
message = None
159+
error = data.get("error")
160+
if is_mapping(error):
161+
message = error.get("message")
162+
if not message or not isinstance(message, str):
163+
message = "An error occurred during streaming"
164+
165+
raise APIError(
166+
message=message,
167+
request=self.response.request,
168+
body=data["error"],
169+
)
170+
171+
yield process_data(data=data, cast_to=cast_to, response=response)
135172
finally:
136173
# Ensure the response is closed even if the consumer doesn't read all data
137174
await response.aclose()

tests/test_client.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from llama_stack_client._types import Omit
3030
from llama_stack_client._utils import asyncify
3131
from llama_stack_client._models import BaseModel, FinalRequestOptions
32+
from llama_stack_client._streaming import Stream, AsyncStream
3233
from llama_stack_client._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
3334
from llama_stack_client._base_client import (
3435
DEFAULT_TIMEOUT,
@@ -773,6 +774,17 @@ def test_client_max_retries_validation(self) -> None:
773774
with pytest.raises(TypeError, match=r"max_retries cannot be None"):
774775
LlamaStackClient(base_url=base_url, _strict_response_validation=True, max_retries=cast(Any, None))
775776

777+
@pytest.mark.respx(base_url=base_url)
778+
def test_default_stream_cls(self, respx_mock: MockRouter, client: LlamaStackClient) -> None:
779+
class Model(BaseModel):
780+
name: str
781+
782+
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
783+
784+
stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
785+
assert isinstance(stream, Stream)
786+
stream.response.close()
787+
776788
@pytest.mark.respx(base_url=base_url)
777789
def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
778790
class Model(BaseModel):
@@ -1685,6 +1697,17 @@ async def test_client_max_retries_validation(self) -> None:
16851697
with pytest.raises(TypeError, match=r"max_retries cannot be None"):
16861698
AsyncLlamaStackClient(base_url=base_url, _strict_response_validation=True, max_retries=cast(Any, None))
16871699

1700+
@pytest.mark.respx(base_url=base_url)
1701+
async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncLlamaStackClient) -> None:
1702+
class Model(BaseModel):
1703+
name: str
1704+
1705+
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1706+
1707+
stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1708+
assert isinstance(stream, AsyncStream)
1709+
await stream.response.aclose()
1710+
16881711
@pytest.mark.respx(base_url=base_url)
16891712
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
16901713
class Model(BaseModel):

0 commit comments

Comments (0)