11from collections .abc import AsyncGenerator
22from typing import Optional , Union
33
4+ from agents import Agent , ModelSettings , OpenAIChatCompletionsModel , Runner , set_tracing_disabled
45from openai import AsyncAzureOpenAI , AsyncOpenAI
56from openai .types .chat import ChatCompletionMessageParam
6- from pydantic_ai import Agent
7- from pydantic_ai .models .openai import OpenAIModel
8- from pydantic_ai .providers .openai import OpenAIProvider
9- from pydantic_ai .settings import ModelSettings
7+ from openai .types .responses import ResponseTextDeltaEvent
108
119from fastapi_app .api_models import (
1210 AIChatRoles ,
2119from fastapi_app .postgres_searcher import PostgresSearcher
2220from fastapi_app .rag_base import RAGChatBase
2321
# Turn off the OpenAI Agents SDK's built-in tracing so no trace/telemetry
# data is exported for the runs performed by this module.
set_tracing_disabled(disabled=True)

2424
2525class SimpleRAGChat (RAGChatBase ):
2626 def __init__ (
@@ -38,17 +38,17 @@ def __init__(
3838 self .model_for_thoughts = (
3939 {"model" : chat_model , "deployment" : chat_deployment } if chat_deployment else {"model" : chat_model }
4040 )
41- pydantic_chat_model = OpenAIModel (
42- chat_model if chat_deployment is None else chat_deployment ,
43- provider = OpenAIProvider (openai_client = openai_chat_client ),
41+ openai_agents_model = OpenAIChatCompletionsModel (
42+ model = chat_model if chat_deployment is None else chat_deployment , openai_client = openai_chat_client
4443 )
4544 self .answer_agent = Agent (
46- pydantic_chat_model ,
47- system_prompt = self .answer_prompt_template ,
45+ name = "Answerer" ,
46+ instructions = self .answer_prompt_template ,
47+ model = openai_agents_model ,
4848 model_settings = ModelSettings (
4949 temperature = self .chat_params .temperature ,
5050 max_tokens = self .chat_params .response_token_limit ,
51- ** ( {"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {}) ,
51+ extra_body = {"seed" : self .chat_params .seed } if self .chat_params .seed is not None else {},
5252 ),
5353 )
5454
@@ -85,19 +85,21 @@ async def answer(
8585 items : list [ItemPublic ],
8686 earlier_thoughts : list [ThoughtStep ],
8787 ) -> RetrievalResponse :
88- response = await self .answer_agent .run (
89- user_prompt = self .prepare_rag_request (self .chat_params .original_user_query , items ),
90- message_history = self .chat_params .past_messages ,
88+ run_results = await Runner .run (
89+ self .answer_agent ,
90+ input = self .chat_params .past_messages
91+ + [{"content" : self .prepare_rag_request (self .chat_params .original_user_query , items ), "role" : "user" }],
9192 )
93+
9294 return RetrievalResponse (
93- message = Message (content = str (response . output ), role = AIChatRoles .ASSISTANT ),
95+ message = Message (content = str (run_results . final_output ), role = AIChatRoles .ASSISTANT ),
9496 context = RAGContext (
9597 data_points = {item .id : item for item in items },
9698 thoughts = earlier_thoughts
9799 + [
98100 ThoughtStep (
99101 title = "Prompt to generate answer" ,
100- description = response . all_messages () ,
102+ description = run_results . input ,
101103 props = self .model_for_thoughts ,
102104 ),
103105 ],
@@ -109,24 +111,27 @@ async def answer_stream(
109111 items : list [ItemPublic ],
110112 earlier_thoughts : list [ThoughtStep ],
111113 ) -> AsyncGenerator [RetrievalResponseDelta , None ]:
112- async with self .answer_agent .run_stream (
113- self .prepare_rag_request (self .chat_params .original_user_query , items ),
114- message_history = self .chat_params .past_messages ,
115- ) as agent_stream_runner :
116- yield RetrievalResponseDelta (
117- context = RAGContext (
118- data_points = {item .id : item for item in items },
119- thoughts = earlier_thoughts
120- + [
121- ThoughtStep (
122- title = "Prompt to generate answer" ,
123- description = agent_stream_runner .all_messages (),
124- props = self .model_for_thoughts ,
125- ),
126- ],
127- ),
128- )
114+ run_results = Runner .run_streamed (
115+ self .answer_agent ,
116+ input = self .chat_params .past_messages
117+ + [{"content" : self .prepare_rag_request (self .chat_params .original_user_query , items ), "role" : "user" }],
118+ )
119+
120+ yield RetrievalResponseDelta (
121+ context = RAGContext (
122+ data_points = {item .id : item for item in items },
123+ thoughts = earlier_thoughts
124+ + [
125+ ThoughtStep (
126+ title = "Prompt to generate answer" ,
127+ description = run_results .input ,
128+ props = self .model_for_thoughts ,
129+ ),
130+ ],
131+ ),
132+ )
129133
130- async for message in agent_stream_runner .stream_text (delta = True , debounce_by = None ):
131- yield RetrievalResponseDelta (delta = Message (content = str (message ), role = AIChatRoles .ASSISTANT ))
132- return
134+ async for event in run_results .stream_events ():
135+ if event .type == "raw_response_event" and isinstance (event .data , ResponseTextDeltaEvent ):
136+ yield RetrievalResponseDelta (delta = Message (content = str (event .data .delta ), role = AIChatRoles .ASSISTANT ))
137+ return
0 commit comments