|
2 | 2 | from collections.abc import AsyncGenerator |
3 | 3 | from typing import Optional, Union |
4 | 4 |
|
5 | | -from agents import Agent, ModelSettings, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled |
6 | | -from openai import AsyncAzureOpenAI, AsyncOpenAI |
7 | | -from openai.types.chat import ( |
8 | | - ChatCompletionMessageParam, |
9 | | -) |
10 | | -from openai.types.responses import ( |
11 | | - EasyInputMessageParam, |
12 | | - ResponseFunctionToolCallParam, |
13 | | - ResponseTextDeltaEvent, |
| 5 | +from agents import ( |
| 6 | + Agent, |
| 7 | + ModelSettings, |
| 8 | + OpenAIChatCompletionsModel, |
| 9 | + Runner, |
| 10 | + ToolCallOutputItem, |
| 11 | + function_tool, |
| 12 | + set_tracing_disabled, |
14 | 13 | ) |
15 | | -from openai.types.responses.response_input_item_param import FunctionCallOutput |
| 14 | +from openai import AsyncAzureOpenAI, AsyncOpenAI |
| 15 | +from openai.types.responses import EasyInputMessageParam, ResponseInputItemParam, ResponseTextDeltaEvent |
16 | 16 |
|
17 | 17 | from fastapi_app.api_models import ( |
18 | 18 | AIChatRoles, |
@@ -41,7 +41,7 @@ class AdvancedRAGChat(RAGChatBase): |
41 | 41 | def __init__( |
42 | 42 | self, |
43 | 43 | *, |
44 | | - messages: list[ChatCompletionMessageParam], |
| 44 | + messages: list[ResponseInputItemParam], |
45 | 45 | overrides: ChatRequestOverrides, |
46 | 46 | searcher: PostgresSearcher, |
47 | 47 | openai_chat_client: Union[AsyncOpenAI, AsyncAzureOpenAI], |
@@ -109,34 +109,17 @@ async def search_database( |
109 | 109 | ) |
110 | 110 |
|
111 | 111 | async def prepare_context(self) -> tuple[list[ItemPublic], list[ThoughtStep]]: |
112 | | - few_shots = json.loads(self.query_fewshots) |
113 | | - few_shot_inputs = [] |
114 | | - for few_shot in few_shots: |
115 | | - if few_shot["role"] == "user": |
116 | | - message = EasyInputMessageParam(role="user", content=few_shot["content"]) |
117 | | - elif few_shot["role"] == "assistant" and few_shot["tool_calls"] is not None: |
118 | | - message = ResponseFunctionToolCallParam( |
119 | | - id="madeup", |
120 | | - call_id=few_shot["tool_calls"][0]["id"], |
121 | | - name=few_shot["tool_calls"][0]["function"]["name"], |
122 | | - arguments=few_shot["tool_calls"][0]["function"]["arguments"], |
123 | | - type="function_call", |
124 | | - ) |
125 | | - elif few_shot["role"] == "tool" and few_shot["tool_call_id"] is not None: |
126 | | - message = FunctionCallOutput( |
127 | | - id="madeupoutput", |
128 | | - call_id=few_shot["tool_call_id"], |
129 | | - output=few_shot["content"], |
130 | | - type="function_call_output", |
131 | | - ) |
132 | | - few_shot_inputs.append(message) |
133 | | - |
| 112 | + few_shots: list[ResponseInputItemParam] = json.loads(self.query_fewshots) |
134 | 113 | user_query = f"Find search results for user query: {self.chat_params.original_user_query}" |
135 | 114 | new_user_message = EasyInputMessageParam(role="user", content=user_query) |
136 | | - all_messages = few_shot_inputs + self.chat_params.past_messages + [new_user_message] |
| 115 | + all_messages = few_shots + self.chat_params.past_messages + [new_user_message] |
137 | 116 |
|
138 | 117 | run_results = await Runner.run(self.search_agent, input=all_messages) |
139 | | - search_results = run_results.new_items[-1].output |
| 118 | + most_recent_response = run_results.new_items[-1] |
| 119 | + if isinstance(most_recent_response, ToolCallOutputItem): |
| 120 | + search_results = most_recent_response.output |
| 121 | + else: |
| 122 | + raise ValueError("Error retrieving search results, model did not call tool properly") |
140 | 123 |
|
141 | 124 | thoughts = [ |
142 | 125 | ThoughtStep( |
|
0 commit comments