diff --git a/mcp_bridge/openai_clients/chatCompletion.py b/mcp_bridge/openai_clients/chatCompletion.py index c3c619e..54c7180 100644 --- a/mcp_bridge/openai_clients/chatCompletion.py +++ b/mcp_bridge/openai_clients/chatCompletion.py @@ -50,7 +50,7 @@ async def chat_completions( request.messages.append(msg) logger.debug(f"finish reason: {response.choices[0].finish_reason}") - if response.choices[0].finish_reason.value in ["stop", "length"]: + if response.choices[0].finish_reason == "length" or not response.choices[0].message.tool_calls: logger.debug("no tool calls found") return response diff --git a/mcp_bridge/openai_clients/streamChatCompletion.py b/mcp_bridge/openai_clients/streamChatCompletion.py index f518cd4..4cceb1b 100644 --- a/mcp_bridge/openai_clients/streamChatCompletion.py +++ b/mcp_bridge/openai_clients/streamChatCompletion.py @@ -52,8 +52,7 @@ async def chat_completions(request: CreateChatCompletionRequest, http_request: R # logger.debug(json_data) - last: Optional[CreateChatCompletionStreamResponse] = None # last message - + tool_call: bool = False tool_call_name: str = "" tool_call_json: str = "" should_forward: bool = True @@ -113,20 +112,16 @@ async def chat_completions(request: CreateChatCompletionRequest, http_request: R content = content if content is not None else "" response_content += content - # handle stop reasons - if len(parsed_data.choices) > 0 and parsed_data.choices[0].finish_reason is not None: - if parsed_data.choices[0].finish_reason.value in [ - "stop", - "length", - ]: - fully_done = True - else: - should_forward = False + # handle stopping for length + if len(parsed_data.choices) > 0 and parsed_data.choices[0].finish_reason == "length": + fully_done = True # this manages the incoming tool call schema # most of this is assertions to please mypy - if len(parsed_data.choices) > 0 and parsed_data.choices[0].delta.tool_calls is not None: + if len(parsed_data.choices) > 0 and parsed_data.choices[0].delta.tool_calls: should_forward = False + tool_call = True + assert ( parsed_data.choices[0].delta.tool_calls[0].function is not None ) @@ -149,15 +144,7 @@ async def chat_completions(request: CreateChatCompletionRequest, http_request: R logger.debug("forwarding message") yield SSEData.model_validate_json(sse.data).model_dump_json() - # save the last message - last = parsed_data - - # ideally we should check this properly - assert last is not None - if len(last.choices) > 0: - assert last.choices[0].finish_reason is not None - - if len(last.choices) > 0 and last.choices[0].finish_reason.value in ["stop", "length"]: + if fully_done or tool_call == False: logger.debug("no tool calls found") fully_done = True continue