44import logging
55import time
66from functools import wraps
7- from typing import Any , Dict , Iterator , Optional , Union
7+ from typing import Any , AsyncIterator , Optional , Union
88
99import openai
1010
@@ -56,7 +56,7 @@ async def traced_create_func(*args, **kwargs):
5656 stream = kwargs .get ("stream" , False )
5757
5858 if stream :
59- return await handle_async_streaming_create (
59+ return handle_async_streaming_create (
6060 * args ,
6161 ** kwargs ,
6262 create_func = create_func ,
@@ -81,7 +81,7 @@ async def handle_async_streaming_create(
8181 is_azure_openai : bool = False ,
8282 inference_id : Optional [str ] = None ,
8383 ** kwargs ,
84- ) -> Iterator [Any ]:
84+ ) -> AsyncIterator [Any ]:
8585 """Handles the create method when streaming is enabled.
8686
8787 Parameters
@@ -95,25 +95,12 @@ async def handle_async_streaming_create(
9595
9696 Returns
9797 -------
98- Iterator [Any]
98+ AsyncIterator [Any]
9999 A generator that yields the chunks of the completion.
100100 """
101101 chunks = await create_func (* args , ** kwargs )
102- return await stream_async_chunks (
103- chunks = chunks ,
104- kwargs = kwargs ,
105- inference_id = inference_id ,
106- is_azure_openai = is_azure_openai ,
107- )
108102
109-
110- async def stream_async_chunks (
111- chunks : Iterator [Any ],
112- kwargs : Dict [str , any ],
113- is_azure_openai : bool = False ,
114- inference_id : Optional [str ] = None ,
115- ):
116- """Streams the chunks of the completion and traces the completion."""
103+ # Create and return a new async generator that processes chunks
117104 collected_output_data = []
118105 collected_function_call = {
119106 "name" : "" ,
@@ -143,9 +130,9 @@ async def stream_async_chunks(
143130 if delta .function_call .name :
144131 collected_function_call ["name" ] += delta .function_call .name
145132 if delta .function_call .arguments :
146- collected_function_call ["arguments" ] += (
147- delta . function_call . arguments
148- )
133+ collected_function_call [
134+ " arguments"
135+ ] += delta . function_call . arguments
149136 elif delta .tool_calls :
150137 if delta .tool_calls [0 ].function .name :
151138 collected_function_call ["name" ] += delta .tool_calls [0 ].function .name
@@ -155,6 +142,7 @@ async def stream_async_chunks(
155142 ].function .arguments
156143
157144 yield chunk
145+
158146 end_time = time .time ()
159147 latency = (end_time - start_time ) * 1000
160148 # pylint: disable=broad-except
0 commit comments