1+ import os
12from collections .abc import AsyncGenerator
2- from typing import Any , Final , Optional , Union
3+ from typing import Optional , TypedDict , Union
34
45from openai import AsyncAzureOpenAI , AsyncOpenAI , AsyncStream
5- from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessageParam
6- from openai_messages_token_helper import build_messages , get_token_limit
6+ from openai .types .chat import ChatCompletionChunk , ChatCompletionMessageParam
7+ from openai_messages_token_helper import get_token_limit
8+ from pydantic_ai import Agent , RunContext
9+ from pydantic_ai .models .openai import OpenAIModel
10+ from pydantic_ai .providers .openai import OpenAIProvider
11+ from pydantic_ai .settings import ModelSettings
712
813from fastapi_app .api_models import (
914 AIChatRoles ,
1520)
1621from fastapi_app .postgres_models import Item
1722from fastapi_app .postgres_searcher import PostgresSearcher
18- from fastapi_app .query_rewriter import build_search_function , extract_search_arguments
1923from fastapi_app .rag_base import ChatParams , RAGChatBase
2024
25+ # Experiment #1: Annotated did not work!
26+ # Experiment #2: Function-level docstring, Inline docstrings next to attributes
27+ # Function -level docstring leads to XML like this: <summary>Search ...
28+ # Experiment #3: Move the docstrings below the attributes in triple-quoted strings - SUCCESS!!!
29+
30+
31+ class PriceFilter (TypedDict ):
32+ column : str = "price"
33+ """The column to filter on (always 'price' for this filter)"""
34+
35+ comparison_operator : str
36+ """The operator for price comparison ('>', '<', '>=', '<=', '=')"""
37+
38+ value : float
39+ """ The price value to compare against (e.g., 30.00) """
40+
41+
42+ class BrandFilter (TypedDict ):
43+ column : str = "brand"
44+ """The column to filter on (always 'brand' for this filter)"""
45+
46+ comparison_operator : str
47+ """The operator for brand comparison ('=' or '!=')"""
48+
49+ value : str
50+ """The brand name to compare against (e.g., 'AirStrider')"""
51+
2152
2253class AdvancedRAGChat (RAGChatBase ):
2354 def __init__ (
@@ -34,82 +65,64 @@ def __init__(
3465 self .chat_deployment = chat_deployment
3566 self .chat_token_limit = get_token_limit (chat_model , default_to_minimum = True )
3667
37- async def generate_search_query (
68+ async def search_database (
3869 self ,
39- original_user_query : str ,
40- past_messages : list [ChatCompletionMessageParam ],
41- query_response_token_limit : int ,
42- seed : Optional [int ] = None ,
43- ) -> tuple [list [ChatCompletionMessageParam ], Union [Any , str , None ], list ]:
44- """Generate an optimized keyword search query based on the chat history and the last question"""
45-
46- tools = build_search_function ()
47- tool_choice : Final = "auto"
48-
49- query_messages : list [ChatCompletionMessageParam ] = build_messages (
50- model = self .chat_model ,
51- system_prompt = self .query_prompt_template ,
52- few_shots = self .query_fewshots ,
53- new_user_content = original_user_query ,
54- past_messages = past_messages ,
55- max_tokens = self .chat_token_limit - query_response_token_limit ,
56- tools = tools ,
57- tool_choice = tool_choice ,
58- fallback_to_default = True ,
59- )
60-
61- chat_completion : ChatCompletion = await self .openai_chat_client .chat .completions .create (
62- messages = query_messages ,
63- # Azure OpenAI takes the deployment name as the model name
64- model = self .chat_deployment if self .chat_deployment else self .chat_model ,
65- temperature = 0.0 , # Minimize creativity for search query generation
66- max_tokens = query_response_token_limit , # Setting too low risks malformed JSON, too high risks performance
67- n = 1 ,
68- tools = tools ,
69- tool_choice = tool_choice ,
70- seed = seed ,
71- )
72-
73- query_text , filters = extract_search_arguments (original_user_query , chat_completion )
74-
75- return query_messages , query_text , filters
76-
77- async def prepare_context (
78- self , chat_params : ChatParams
79- ) -> tuple [list [ChatCompletionMessageParam ], list [Item ], list [ThoughtStep ]]:
80- query_messages , query_text , filters = await self .generate_search_query (
81- original_user_query = chat_params .original_user_query ,
82- past_messages = chat_params .past_messages ,
83- query_response_token_limit = 500 ,
84- seed = chat_params .seed ,
85- )
86-
87- # Retrieve relevant rows from the database with the GPT optimized query
70+ ctx : RunContext [ChatParams ],
71+ search_query : str ,
72+ price_filter : Optional [PriceFilter ] = None ,
73+ brand_filter : Optional [BrandFilter ] = None ,
74+ ) -> list [str ]:
75+ """
76+ Search PostgreSQL database for relevant products based on user query
77+
78+ Args:
79+ search_query: Query string to use for full text search, e.g. 'red shoes'
80+ price_filter: Filter search results based on price of the product
81+ brand_filter: Filter search results based on brand of the product
82+
83+ Returns:
84+ List of formatted items that match the search query and filters
85+ """
86+ print (search_query , price_filter , brand_filter )
87+ # Only send non-None filters
88+ filters = []
89+ if price_filter :
90+ filters .append (price_filter )
91+ if brand_filter :
92+ filters .append (brand_filter )
8893 results = await self .searcher .search_and_embed (
89- query_text ,
90- top = chat_params .top ,
91- enable_vector_search = chat_params .enable_vector_search ,
92- enable_text_search = chat_params .enable_text_search ,
94+ search_query ,
95+ top = ctx . deps .top ,
96+ enable_vector_search = ctx . deps .enable_vector_search ,
97+ enable_text_search = ctx . deps .enable_text_search ,
9398 filters = filters ,
9499 )
100+ return [f"[{ (item .id )} ]:{ item .to_str_for_rag ()} \n \n " for item in results ]
95101
96- sources_content = [f"[{ (item .id )} ]:{ item .to_str_for_rag ()} \n \n " for item in results ]
97- content = "\n " .join (sources_content )
98-
99- # Generate a contextual and content specific answer using the search results and chat history
100- contextual_messages : list [ChatCompletionMessageParam ] = build_messages (
101- model = self .chat_model ,
102- system_prompt = chat_params .prompt_template ,
103- new_user_content = chat_params .original_user_query + "\n \n Sources:\n " + content ,
104- past_messages = chat_params .past_messages ,
105- max_tokens = self .chat_token_limit - chat_params .response_token_limit ,
106- fallback_to_default = True ,
102+ async def prepare_context (self , chat_params : ChatParams ) -> tuple [str , list [Item ], list [ThoughtStep ]]:
103+ model = OpenAIModel (
104+ os .environ ["AZURE_OPENAI_CHAT_DEPLOYMENT" ], provider = OpenAIProvider (openai_client = self .openai_chat_client )
105+ )
106+ agent = Agent (
107+ model ,
108+ model_settings = ModelSettings (temperature = 0.0 , max_tokens = 500 , seed = chat_params .seed ),
109+ system_prompt = self .query_prompt_template ,
110+ tools = [self .search_database ],
111+ output_type = list [str ],
112+ )
113+ # TODO: Provide few-shot examples
114+ results = await agent .run (
115+ f"Find search results for user query: { chat_params .original_user_query } " ,
116+ # message_history=chat_params.past_messages, # TODO
117+ deps = chat_params ,
107118 )
119+ if not isinstance (results , list ):
120+ raise ValueError ("Search results should be a list of strings" )
108121
109122 thoughts = [
110123 ThoughtStep (
111124 title = "Prompt to generate search arguments" ,
112- description = query_messages ,
125+ description = chat_params . past_messages , # TODO: update this
113126 props = (
114127 {"model" : self .chat_model , "deployment" : self .chat_deployment }
115128 if self .chat_deployment
@@ -118,50 +131,52 @@ async def prepare_context(
118131 ),
119132 ThoughtStep (
120133 title = "Search using generated search arguments" ,
121- description = query_text ,
134+ description = chat_params . original_user_query , # TODO:
122135 props = {
123136 "top" : chat_params .top ,
124137 "vector_search" : chat_params .enable_vector_search ,
125138 "text_search" : chat_params .enable_text_search ,
126- "filters" : filters ,
139+ "filters" : [], # TODO
127140 },
128141 ),
129142 ThoughtStep (
130143 title = "Search results" ,
131- description = [ result . to_dict () for result in results ],
144+ description = "" , # TODO
132145 ),
133146 ]
134- return contextual_messages , results , thoughts
147+ return results , thoughts
135148
136149 async def answer (
137150 self ,
138151 chat_params : ChatParams ,
139- contextual_messages : list [ChatCompletionMessageParam ],
140- results : list [Item ],
152+ results : list [str ],
141153 earlier_thoughts : list [ThoughtStep ],
142154 ) -> RetrievalResponse :
143- chat_completion_response : ChatCompletion = await self .openai_chat_client .chat .completions .create (
144- # Azure OpenAI takes the deployment name as the model name
145- model = self .chat_deployment if self .chat_deployment else self .chat_model ,
146- messages = contextual_messages ,
147- temperature = chat_params .temperature ,
148- max_tokens = chat_params .response_token_limit ,
149- n = 1 ,
150- stream = False ,
151- seed = chat_params .seed ,
155+ agent = Agent (
156+ OpenAIModel (
157+ os .environ ["AZURE_OPENAI_CHAT_DEPLOYMENT" ],
158+ provider = OpenAIProvider (openai_client = self .openai_chat_client ),
159+ ),
160+ system_prompt = self .answer_prompt_template ,
161+ model_settings = ModelSettings (
162+ temperature = chat_params .temperature , max_tokens = chat_params .response_token_limit , seed = chat_params .seed
163+ ),
164+ )
165+
166+ response = await agent .run (
167+ user_prompt = chat_params .original_user_query + "Sources:\n " + "\n " .join (results ),
168+ message_history = chat_params .past_messages ,
152169 )
153170
154171 return RetrievalResponse (
155- message = Message (
156- content = str (chat_completion_response .choices [0 ].message .content ), role = AIChatRoles .ASSISTANT
157- ),
172+ message = Message (content = str (response .output ), role = AIChatRoles .ASSISTANT ),
158173 context = RAGContext (
159174 data_points = {item .id : item .to_dict () for item in results },
160175 thoughts = earlier_thoughts
161176 + [
162177 ThoughtStep (
163178 title = "Prompt to generate answer" ,
164- description = contextual_messages ,
179+ description = "" , # TODO: update
165180 props = (
166181 {"model" : self .chat_model , "deployment" : self .chat_deployment }
167182 if self .chat_deployment
0 commit comments