88from fastapi import FastAPI , Request
99from fastapi .responses import JSONResponse , StreamingResponse
1010from pydantic import BaseModel
11+ from asyncio import CancelledError , Task
1112
1213
1314class ChatCompletionRequest (BaseModel ):
@@ -35,14 +36,27 @@ def __init__(self, interpreter):
3536 # Setup routes
3637 self .app .post ("/chat/completions" )(self .chat_completion )
3738
39+ # Add a field to track the current request task
40+ self ._current_request : Optional [Task ] = None
41+
3842 async def chat_completion (self , request : Request ):
3943 """Main chat completion endpoint"""
44+ # Cancel any existing request
45+ if self ._current_request and not self ._current_request .done ():
46+ self ._current_request .cancel ()
47+ try :
48+ await self ._current_request
49+ except CancelledError :
50+ pass
51+
4052 body = await request .json ()
53+ if self .interpreter .debug :
54+ print ("Request body:" , body )
4155 try :
4256 req = ChatCompletionRequest (** body )
4357 except Exception as e :
44- print ("Validation error:" , str (e )) # Debug print
45- print ("Request body:" , body ) # Print the request body
58+ print ("Validation error:" , str (e ))
59+ print ("Request body:" , body )
4660 raise
4761
4862 # Filter out system message
@@ -75,18 +89,6 @@ async def _stream_response(self):
7589 delta ["function_call" ] = choice .delta .function_call
7690 if choice .delta .tool_calls is not None :
7791 pass
78- # Convert tool_calls to dict representation
79- # delta["tool_calls"] = [
80- # {
81- # "index": tool_call.index,
82- # "id": tool_call.id,
83- # "type": tool_call.type,
84- # "function": {
85- # "name": tool_call.function.name,
86- # "arguments": tool_call.function.arguments
87- # }
88- # } for tool_call in choice.delta.tool_calls
89- # ]
9092
9193 choices .append (
9294 {
@@ -108,11 +110,16 @@ async def _stream_response(self):
108110 data ["system_fingerprint" ] = chunk .system_fingerprint
109111
110112 yield f"data: { json .dumps (data )} \n \n "
111- except asyncio .CancelledError :
112- # Set stop flag when stream is cancelled
113- self .interpreter ._stop_flag = True
113+
114+ except CancelledError :
115+ # Handle cancellation gracefully
116+ print ("Request cancelled - cleaning up..." )
117+
114118 raise
119+ except Exception as e :
120+ print (f"Error in stream: { str (e )} " )
115121 finally :
122+ # Always send DONE message and cleanup
116123 yield "data: [DONE]\n \n "
117124
118125 def run (self ):
0 commit comments