22from __future__ import annotations
33
44import json
5- from typing import TYPE_CHECKING , Any , Generic , Iterator , AsyncIterator
6- from typing_extensions import override
5+ from types import TracebackType
6+ from typing import TYPE_CHECKING , Any , Generic , TypeVar , Iterator , AsyncIterator , cast
7+ from typing_extensions import Self , override
78
89import httpx
910
10- from ._types import ResponseT
1111from ._utils import is_mapping
1212from ._exceptions import APIError
1313
1414if TYPE_CHECKING :
1515 from ._client import OpenAI , AsyncOpenAI
1616
1717
18- class Stream (Generic [ResponseT ]):
18+ _T = TypeVar ("_T" )
19+
20+
21+ class Stream (Generic [_T ]):
1922 """Provides the core interface to iterate over a synchronous stream response."""
2023
2124 response : httpx .Response
2225
2326 def __init__ (
2427 self ,
2528 * ,
26- cast_to : type [ResponseT ],
29+ cast_to : type [_T ],
2730 response : httpx .Response ,
2831 client : OpenAI ,
2932 ) -> None :
@@ -33,18 +36,18 @@ def __init__(
3336 self ._decoder = SSEDecoder ()
3437 self ._iterator = self .__stream__ ()
3538
36- def __next__ (self ) -> ResponseT :
39+ def __next__ (self ) -> _T :
3740 return self ._iterator .__next__ ()
3841
39- def __iter__ (self ) -> Iterator [ResponseT ]:
42+ def __iter__ (self ) -> Iterator [_T ]:
4043 for item in self ._iterator :
4144 yield item
4245
4346 def _iter_events (self ) -> Iterator [ServerSentEvent ]:
4447 yield from self ._decoder .iter (self .response .iter_lines ())
4548
46- def __stream__ (self ) -> Iterator [ResponseT ]:
47- cast_to = self ._cast_to
49+ def __stream__ (self ) -> Iterator [_T ]:
50+ cast_to = cast ( Any , self ._cast_to )
4851 response = self .response
4952 process_data = self ._client ._process_response_data
5053 iterator = self ._iter_events ()
@@ -68,16 +71,35 @@ def __stream__(self) -> Iterator[ResponseT]:
6871 for _sse in iterator :
6972 ...
7073
74+ def __enter__ (self ) -> Self :
75+ return self
76+
77+ def __exit__ (
78+ self ,
79+ exc_type : type [BaseException ] | None ,
80+ exc : BaseException | None ,
81+ exc_tb : TracebackType | None ,
82+ ) -> None :
83+ self .close ()
84+
85+ def close (self ) -> None :
86+ """
87+ Close the response and release the connection.
88+
89+ Automatically called if the response body is read to completion.
90+ """
91+ self .response .close ()
7192
72- class AsyncStream (Generic [ResponseT ]):
93+
94+ class AsyncStream (Generic [_T ]):
7395 """Provides the core interface to iterate over an asynchronous stream response."""
7496
7597 response : httpx .Response
7698
7799 def __init__ (
78100 self ,
79101 * ,
80- cast_to : type [ResponseT ],
102+ cast_to : type [_T ],
81103 response : httpx .Response ,
82104 client : AsyncOpenAI ,
83105 ) -> None :
@@ -87,19 +109,19 @@ def __init__(
87109 self ._decoder = SSEDecoder ()
88110 self ._iterator = self .__stream__ ()
89111
90- async def __anext__ (self ) -> ResponseT :
112+ async def __anext__ (self ) -> _T :
91113 return await self ._iterator .__anext__ ()
92114
93- async def __aiter__ (self ) -> AsyncIterator [ResponseT ]:
115+ async def __aiter__ (self ) -> AsyncIterator [_T ]:
94116 async for item in self ._iterator :
95117 yield item
96118
97119 async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
98120 async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
99121 yield sse
100122
101- async def __stream__ (self ) -> AsyncIterator [ResponseT ]:
102- cast_to = self ._cast_to
123+ async def __stream__ (self ) -> AsyncIterator [_T ]:
124+ cast_to = cast ( Any , self ._cast_to )
103125 response = self .response
104126 process_data = self ._client ._process_response_data
105127 iterator = self ._iter_events ()
@@ -123,6 +145,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
123145 async for _sse in iterator :
124146 ...
125147
148+ async def __aenter__ (self ) -> Self :
149+ return self
150+
151+ async def __aexit__ (
152+ self ,
153+ exc_type : type [BaseException ] | None ,
154+ exc : BaseException | None ,
155+ exc_tb : TracebackType | None ,
156+ ) -> None :
157+ await self .close ()
158+
159+ async def close (self ) -> None :
160+ """
161+ Close the response and release the connection.
162+
163+ Automatically called if the response body is read to completion.
164+ """
165+ await self .response .aclose ()
166+
126167
127168class ServerSentEvent :
128169 def __init__ (
0 commit comments