66import threading
77import time
88import warnings
9- from contextlib import asynccontextmanager
109from json import JSONDecodeError
1110from typing import (
11+ AsyncContextManager ,
1212 AsyncGenerator ,
13- AsyncIterator ,
1413 Callable ,
1514 Dict ,
1615 Iterator ,
@@ -368,8 +367,9 @@ async def arequest(
368367 request_id : Optional [str ] = None ,
369368 request_timeout : Optional [Union [float , Tuple [float , float ]]] = None ,
370369 ) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
371- ctx = aiohttp_session ()
370+ ctx = AioHTTPSession ()
372371 session = await ctx .__aenter__ ()
372+ result = None
373373 try :
374374 result = await self .arequest_raw (
375375 method .lower (),
@@ -383,6 +383,9 @@ async def arequest(
383383 )
384384 resp , got_stream = await self ._interpret_async_response (result , stream )
385385 except Exception :
386+ # Close the request before exiting session context.
387+ if result is not None :
388+ result .release ()
386389 await ctx .__aexit__ (None , None , None )
387390 raise
388391 if got_stream :
@@ -393,10 +396,15 @@ async def wrap_resp():
393396 async for r in resp :
394397 yield r
395398 finally :
399+ # Close the request before exiting session context. Important to do it here
400+ # as if stream is not fully exhausted, we need to close the request nevertheless.
401+ result .release ()
396402 await ctx .__aexit__ (None , None , None )
397403
398404 return wrap_resp (), got_stream , self .api_key
399405 else :
406+ # Close the request before exiting session context.
407+ result .release ()
400408 await ctx .__aexit__ (None , None , None )
401409 return resp , got_stream , self .api_key
402410
@@ -770,11 +778,22 @@ def _interpret_response_line(
770778 return resp
771779
772780
773- @asynccontextmanager
774- async def aiohttp_session () -> AsyncIterator [aiohttp .ClientSession ]:
775- user_set_session = openai .aiosession .get ()
776- if user_set_session :
777- yield user_set_session
778- else :
779- async with aiohttp .ClientSession () as session :
780- yield session
781+ class AioHTTPSession (AsyncContextManager ):
782+ def __init__ (self ):
783+ self ._session = None
784+ self ._should_close_session = False
785+
786+ async def __aenter__ (self ):
787+ self ._session = openai .aiosession .get ()
788+ if self ._session is None :
789+ self ._session = await aiohttp .ClientSession ().__aenter__ ()
790+ self ._should_close_session = True
791+
792+ return self ._session
793+
794+ async def __aexit__ (self , exc_type , exc_value , traceback ):
795+ if self ._session is None :
796+ raise RuntimeError ("Session is not initialized" )
797+
798+ if self ._should_close_session :
799+ await self ._session .__aexit__ (exc_type , exc_value , traceback )
0 commit comments