1+ import json
12from collections .abc import AsyncGenerator
23from typing import Optional , Union
34
5+ from agents import Agent , ModelSettings , OpenAIChatCompletionsModel , Runner , function_tool , set_tracing_disabled
46from openai import AsyncAzureOpenAI , AsyncOpenAI
5- from openai .types .chat import ChatCompletionMessageParam
6- from pydantic_ai import Agent , RunContext
7- from pydantic_ai .messages import ModelMessagesTypeAdapter
8- from pydantic_ai .models .openai import OpenAIModel
9- from pydantic_ai .providers .openai import OpenAIProvider
10- from pydantic_ai .settings import ModelSettings
7+ from openai .types .chat import (
8+ ChatCompletionMessageParam ,
9+ )
10+ from openai .types .responses import (
11+ EasyInputMessageParam ,
12+ ResponseFunctionToolCallParam ,
13+ ResponseTextDeltaEvent ,
14+ )
15+ from openai .types .responses .response_input_item_param import FunctionCallOutput
1116
1217from fastapi_app .api_models import (
1318 AIChatRoles ,
2429 ThoughtStep ,
2530)
2631from fastapi_app .postgres_searcher import PostgresSearcher
27- from fastapi_app .rag_base import ChatParams , RAGChatBase
32+ from fastapi_app .rag_base import RAGChatBase
33+
34+ set_tracing_disabled (disabled = True )
2835
2936
3037class AdvancedRAGChat (RAGChatBase ):
@@ -46,34 +53,29 @@ def __init__(
4653 self .model_for_thoughts = (
4754 {"model" : chat_model , "deployment" : chat_deployment } if chat_deployment else {"model" : chat_model }
4855 )
49- pydantic_chat_model = OpenAIModel (
50- chat_model if chat_deployment is None else chat_deployment ,
51- provider = OpenAIProvider (openai_client = openai_chat_client ),
56+ openai_agents_model = OpenAIChatCompletionsModel (
57+ model = chat_model if chat_deployment is None else chat_deployment , openai_client = openai_chat_client
5258 )
53- self .search_agent = Agent [ChatParams , SearchResults ](
54- pydantic_chat_model ,
55- model_settings = ModelSettings (
56- temperature = 0.0 ,
57- max_tokens = 500 ,
58- ** ({"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {}),
59- ),
60- system_prompt = self .query_prompt_template ,
61- tools = [self .search_database ],
62- output_type = SearchResults ,
59+ self .search_agent = Agent (
60+ name = "Searcher" ,
61+ instructions = self .query_prompt_template ,
62+ tools = [function_tool (self .search_database )],
63+ tool_use_behavior = "stop_on_first_tool" ,
64+ model = openai_agents_model ,
6365 )
6466 self .answer_agent = Agent (
65- pydantic_chat_model ,
66- system_prompt = self .answer_prompt_template ,
67+ name = "Answerer" ,
68+ instructions = self .answer_prompt_template ,
69+ model = openai_agents_model ,
6770 model_settings = ModelSettings (
6871 temperature = self .chat_params .temperature ,
6972 max_tokens = self .chat_params .response_token_limit ,
70- ** ( {"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {}) ,
73+ extra_body = {"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {},
7174 ),
7275 )
7376
7477 async def search_database (
7578 self ,
76- ctx : RunContext [ChatParams ],
7779 search_query : str ,
7880 price_filter : Optional [PriceFilter ] = None ,
7981 brand_filter : Optional [BrandFilter ] = None ,
@@ -97,66 +99,88 @@ async def search_database(
9799 filters .append (brand_filter )
98100 results = await self .searcher .search_and_embed (
99101 search_query ,
100- top = ctx . deps .top ,
101- enable_vector_search = ctx . deps .enable_vector_search ,
102- enable_text_search = ctx . deps .enable_text_search ,
102+ top = self . chat_params .top ,
103+ enable_vector_search = self . chat_params .enable_vector_search ,
104+ enable_text_search = self . chat_params .enable_text_search ,
103105 filters = filters ,
104106 )
105107 return SearchResults (
106108 query = search_query , items = [ItemPublic .model_validate (item .to_dict ()) for item in results ], filters = filters
107109 )
108110
109111 async def prepare_context (self ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
110- few_shots = ModelMessagesTypeAdapter .validate_json (self .query_fewshots )
112+ few_shots = json .loads (self .query_fewshots )
113+ few_shot_inputs = []
114+ for few_shot in few_shots :
115+ if few_shot ["role" ] == "user" :
116+ message = EasyInputMessageParam (role = "user" , content = few_shot ["content" ])
117+ elif few_shot ["role" ] == "assistant" and few_shot ["tool_calls" ] is not None :
118+ message = ResponseFunctionToolCallParam (
119+ id = "madeup" ,
120+ call_id = few_shot ["tool_calls" ][0 ]["id" ],
121+ name = few_shot ["tool_calls" ][0 ]["function" ]["name" ],
122+ arguments = few_shot ["tool_calls" ][0 ]["function" ]["arguments" ],
123+ type = "function_call" ,
124+ )
125+ elif few_shot ["role" ] == "tool" and few_shot ["tool_call_id" ] is not None :
126+ message = FunctionCallOutput (
127+ id = "madeupoutput" ,
128+ call_id = few_shot ["tool_call_id" ],
129+ output = few_shot ["content" ],
130+ type = "function_call_output" ,
131+ )
132+ few_shot_inputs .append (message )
133+
111134 user_query = f"Find search results for user query: { self .chat_params .original_user_query } "
112- results = await self . search_agent . run (
113- user_query ,
114- message_history = few_shots + self . chat_params . past_messages ,
115- deps = self .chat_params ,
116- )
117- items = results . output . items
135+ new_user_message = EasyInputMessageParam ( role = "user" , content = user_query )
136+ all_messages = few_shot_inputs + self . chat_params . past_messages + [ new_user_message ]
137+
138+ run_results = await Runner . run ( self .search_agent , input = all_messages )
139+ search_results = run_results . new_items [ - 1 ]. output
140+
118141 thoughts = [
119142 ThoughtStep (
120143 title = "Prompt to generate search arguments" ,
121- description = results . all_messages () ,
144+ description = run_results . input ,
122145 props = self .model_for_thoughts ,
123146 ),
124147 ThoughtStep (
125148 title = "Search using generated search arguments" ,
126- description = results . output .query ,
149+ description = search_results .query ,
127150 props = {
128151 "top" : self .chat_params .top ,
129152 "vector_search" : self .chat_params .enable_vector_search ,
130153 "text_search" : self .chat_params .enable_text_search ,
131- "filters" : results . output .filters ,
154+ "filters" : search_results .filters ,
132155 },
133156 ),
134157 ThoughtStep (
135158 title = "Search results" ,
136- description = items ,
159+ description = search_results . items ,
137160 ),
138161 ]
139- return items , thoughts
162+ return search_results . items , thoughts
140163
141164 async def answer (
142165 self ,
143166 items : list [ItemPublic ],
144167 earlier_thoughts : list [ThoughtStep ],
145168 ) -> RetrievalResponse :
146- response = await self .answer_agent .run (
147- user_prompt = self .prepare_rag_request (self .chat_params .original_user_query , items ),
148- message_history = self .chat_params .past_messages ,
169+ run_results = await Runner .run (
170+ self .answer_agent ,
171+ input = self .chat_params .past_messages
172+ + [{"content" : self .prepare_rag_request (self .chat_params .original_user_query , items ), "role" : "user" }],
149173 )
150174
151175 return RetrievalResponse (
152- message = Message (content = str (response . output ), role = AIChatRoles .ASSISTANT ),
176+ message = Message (content = str (run_results . final_output ), role = AIChatRoles .ASSISTANT ),
153177 context = RAGContext (
154178 data_points = {item .id : item for item in items },
155179 thoughts = earlier_thoughts
156180 + [
157181 ThoughtStep (
158182 title = "Prompt to generate answer" ,
159- description = response . all_messages () ,
183+ description = run_results . input ,
160184 props = self .model_for_thoughts ,
161185 ),
162186 ],
@@ -168,24 +192,27 @@ async def answer_stream(
168192 items : list [ItemPublic ],
169193 earlier_thoughts : list [ThoughtStep ],
170194 ) -> AsyncGenerator [RetrievalResponseDelta , None ]:
171- async with self .answer_agent .run_stream (
172- self .prepare_rag_request (self .chat_params .original_user_query , items ),
173- message_history = self .chat_params .past_messages ,
174- ) as agent_stream_runner :
175- yield RetrievalResponseDelta (
176- context = RAGContext (
177- data_points = {item .id : item for item in items },
178- thoughts = earlier_thoughts
179- + [
180- ThoughtStep (
181- title = "Prompt to generate answer" ,
182- description = agent_stream_runner .all_messages (),
183- props = self .model_for_thoughts ,
184- ),
185- ],
186- ),
187- )
188-
189- async for message in agent_stream_runner .stream_text (delta = True , debounce_by = None ):
190- yield RetrievalResponseDelta (delta = Message (content = str (message ), role = AIChatRoles .ASSISTANT ))
191- return
195+ run_results = Runner .run_streamed (
196+ self .answer_agent ,
197+ input = self .chat_params .past_messages
198+ + [{"content" : self .prepare_rag_request (self .chat_params .original_user_query , items ), "role" : "user" }],
199+ )
200+
201+ yield RetrievalResponseDelta (
202+ context = RAGContext (
203+ data_points = {item .id : item for item in items },
204+ thoughts = earlier_thoughts
205+ + [
206+ ThoughtStep (
207+ title = "Prompt to generate answer" ,
208+ description = run_results .input ,
209+ props = self .model_for_thoughts ,
210+ ),
211+ ],
212+ ),
213+ )
214+
215+ async for event in run_results .stream_events ():
216+ if event .type == "raw_response_event" and isinstance (event .data , ResponseTextDeltaEvent ):
217+ yield RetrievalResponseDelta (delta = Message (content = str (event .data .delta ), role = AIChatRoles .ASSISTANT ))
218+ return
0 commit comments