1212
1313from fastapi_app .api_models import (
1414 AIChatRoles ,
15+ ItemPublic ,
1516 Message ,
1617 RAGContext ,
1718 RetrievalResponse ,
@@ -50,6 +51,14 @@ class BrandFilter(TypedDict):
5051 """The brand name to compare against (e.g., 'AirStrider')"""
5152
5253
54+ class SearchResults (TypedDict ):
55+ items : list [ItemPublic ]
56+ """List of items that match the search query and filters"""
57+
58+ filters : list [Union [PriceFilter , BrandFilter ]]
59+ """List of filters applied to the search results"""
60+
61+
5362class AdvancedRAGChat (RAGChatBase ):
5463 def __init__ (
5564 self ,
@@ -71,7 +80,7 @@ async def search_database(
7180 search_query : str ,
7281 price_filter : Optional [PriceFilter ] = None ,
7382 brand_filter : Optional [BrandFilter ] = None ,
74- ) -> list [ str ] :
83+ ) -> SearchResults :
7584 """
7685 Search PostgreSQL database for relevant products based on user query
7786
@@ -83,7 +92,6 @@ async def search_database(
8392 Returns:
8493 List of formatted items that match the search query and filters
8594 """
86- print (search_query , price_filter , brand_filter )
8795 # Only send non-None filters
8896 filters = []
8997 if price_filter :
@@ -97,9 +105,9 @@ async def search_database(
97105 enable_text_search = ctx .deps .enable_text_search ,
98106 filters = filters ,
99107 )
100- return [ f"[ { (item .id ) } ]: { item . to_str_for_rag () } \n \n " for item in results ]
108+ return SearchResults ( items = [ ItemPublic . model_validate (item .to_dict ()) for item in results ], filters = filters )
101109
102- async def prepare_context (self , chat_params : ChatParams ) -> tuple [str , list [Item ], list [ThoughtStep ]]:
110+ async def prepare_context (self , chat_params : ChatParams ) -> tuple [list [ItemPublic ], list [ThoughtStep ]]:
103111 model = OpenAIModel (
104112 os .environ ["AZURE_OPENAI_CHAT_DEPLOYMENT" ], provider = OpenAIProvider (openai_client = self .openai_chat_client )
105113 )
@@ -108,17 +116,15 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
108116 model_settings = ModelSettings (temperature = 0.0 , max_tokens = 500 , seed = chat_params .seed ),
109117 system_prompt = self .query_prompt_template ,
110118 tools = [self .search_database ],
111- output_type = list [ str ] ,
119+ output_type = SearchResults ,
112120 )
113121 # TODO: Provide few-shot examples
114122 results = await agent .run (
115123 f"Find search results for user query: { chat_params .original_user_query } " ,
116124 # message_history=chat_params.past_messages, # TODO
117125 deps = chat_params ,
118126 )
119- if not isinstance (results , list ):
120- raise ValueError ("Search results should be a list of strings" )
121-
127+ items = results .output .items
122128 thoughts = [
123129 ThoughtStep (
124130 title = "Prompt to generate search arguments" ,
@@ -144,12 +150,12 @@ async def prepare_context(self, chat_params: ChatParams) -> tuple[str, list[Item
144150 description = "" , # TODO
145151 ),
146152 ]
147- return results , thoughts
153+ return items , thoughts
148154
149155 async def answer (
150156 self ,
151157 chat_params : ChatParams ,
152- results : list [str ],
158+ items : list [ItemPublic ],
153159 earlier_thoughts : list [ThoughtStep ],
154160 ) -> RetrievalResponse :
155161 agent = Agent (
@@ -163,15 +169,16 @@ async def answer(
163169 ),
164170 )
165171
172+ item_references = [item .to_str_for_rag () for item in items ]
166173 response = await agent .run (
167- user_prompt = chat_params .original_user_query + "Sources:\n " + "\n " .join (results ),
174+ user_prompt = chat_params .original_user_query + "Sources:\n " + "\n " .join (item_references ),
168175 message_history = chat_params .past_messages ,
169176 )
170177
171178 return RetrievalResponse (
172179 message = Message (content = str (response .output ), role = AIChatRoles .ASSISTANT ),
173180 context = RAGContext (
174- data_points = {item . id : item . to_dict () for item in results },
181+ data_points = {}, # TODO
175182 thoughts = earlier_thoughts
176183 + [
177184 ThoughtStep (
0 commit comments