-import pathlib
 from collections.abc import AsyncGenerator
-from typing import (
-    Any,
-)
+from typing import Any
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI
-from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AsyncStream
+from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam
 from openai_messages_token_helper import build_messages, get_token_limit
 
-from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
-from .postgres_searcher import PostgresSearcher
-from .query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.api_models import (
+    AIChatRoles,
+    Message,
+    RAGContext,
+    RetrievalResponse,
+    RetrievalResponseDelta,
+    ThoughtStep,
+)
+from fastapi_app.postgres_models import Item
+from fastapi_app.postgres_searcher import PostgresSearcher
+from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
+from fastapi_app.rag_base import ChatParams, RAGChatBase
 
 
-class AdvancedRAGChat:
+class AdvancedRAGChat(RAGChatBase):
     def __init__(
         self,
         *,
@@ -27,24 +33,11 @@ def __init__(
         self.chat_model = chat_model
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
-        current_dir = pathlib.Path(__file__).parent
-        self.query_prompt_template = open(current_dir / "prompts/query.txt").read()
-        self.answer_prompt_template = open(current_dir / "prompts/answer.txt").read()
-
-    async def run(
-        self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any] = {}
-    ) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
-        text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
-        vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
-        top = overrides.get("top", 3)
-
-        original_user_query = messages[-1]["content"]
-        if not isinstance(original_user_query, str):
-            raise ValueError("The most recent message content must be a string.")
-        past_messages = messages[:-1]
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
+
+    async def generate_search_query(
+        self, original_user_query: str, past_messages: list[ChatCompletionMessageParam], query_response_token_limit: int
+    ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
+        """Generate an optimized keyword search query based on the chat history and the last question"""
         query_messages: list[ChatCompletionMessageParam] = build_messages(
            model=self.chat_model,
            system_prompt=self.query_prompt_template,
@@ -67,68 +60,128 @@ async def run(
 
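+        # Pull the rewritten query text and any structured filters out of the model's function-call response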
         query_text, filters = extract_search_arguments(original_user_query, chat_completion)
 
+        return query_messages, query_text, filters
+
+    async def prepare_context(
+        self, chat_params: ChatParams
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item], list[ThoughtStep]]:
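+        """Run the rewritten search, then assemble the answer prompt and the thought steps taken so far."""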
+        query_messages, query_text, filters = await self.generate_search_query(
+            original_user_query=chat_params.original_user_query,
+            past_messages=chat_params.past_messages,
+            query_response_token_limit=500,
+        )
+
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
-            top=top,
-            enable_vector_search=vector_search,
-            enable_text_search=text_search,
+            top=chat_params.top,
+            enable_vector_search=chat_params.enable_vector_search,
+            enable_text_search=chat_params.enable_text_search,
             filters=filters,
         )
 
         sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
         content = "\n".join(sources_content)
 
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=original_user_query + "\n\nSources:\n" + content,
-            past_messages=past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            system_prompt=chat_params.prompt_template,
+            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
+            past_messages=chat_params.past_messages,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
 
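+        # Record each intermediate step so callers can surface how the answer was produced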
+        thoughts = [
+            ThoughtStep(
+                title="Prompt to generate search arguments",
+                description=[str(message) for message in query_messages],
+                props=(
+                    {"model": self.chat_model, "deployment": self.chat_deployment}
+                    if self.chat_deployment
+                    else {"model": self.chat_model}
+                ),
+            ),
+            ThoughtStep(
+                title="Search using generated search arguments",
+                description=query_text,
+                props={
+                    "top": chat_params.top,
+                    "vector_search": chat_params.enable_vector_search,
+                    "text_search": chat_params.enable_text_search,
+                    "filters": filters,
+                },
+            ),
+            ThoughtStep(
+                title="Search results",
+                description=[result.to_dict() for result in results],
+            ),
+        ]
+        return contextual_messages, results, thoughts
+
+    async def answer(
+        self,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
+    ) -> RetrievalResponse:
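+        """Generate a complete answer in a single (non-streaming) call and attach the full RAG context."""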
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
-        first_choice_message = chat_completion_response.choices[0].message
 
         return RetrievalResponse(
-            message=Message(content=str(first_choice_message.content), role=first_choice_message.role),
+            message=Message(
+                content=str(chat_completion_response.choices[0].message.content), role=AIChatRoles.ASSISTANT
+            ),
             context=RAGContext(
                 data_points={item.id: item.to_dict() for item in results},
-                thoughts=[
+                thoughts=earlier_thoughts
+                + [
                     ThoughtStep(
-                        title="Prompt to generate search arguments",
-                        description=[str(message) for message in query_messages],
+                        title="Prompt to generate answer",
+                        description=[str(message) for message in contextual_messages],
                         props=(
                             {"model": self.chat_model, "deployment": self.chat_deployment}
                             if self.chat_deployment
                             else {"model": self.chat_model}
                         ),
                     ),
-                    ThoughtStep(
-                        title="Search using generated search arguments",
-                        description=query_text,
-                        props={
-                            "top": top,
-                            "vector_search": vector_search,
-                            "text_search": text_search,
-                            "filters": filters,
-                        },
-                    ),
-                    ThoughtStep(
-                        title="Search results",
-                        description=[result.to_dict() for result in results],
-                    ),
+                ],
+            ),
+        )
+
+    async def answer_stream(
+        self,
+        chat_params: ChatParams,
+        contextual_messages: list[ChatCompletionMessageParam],
+        results: list[Item],
+        earlier_thoughts: list[ThoughtStep],
+    ) -> AsyncGenerator[RetrievalResponseDelta, None]:
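+        """Stream the answer as deltas, after first yielding the RAG context."""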
+        chat_completion_async_stream: AsyncStream[
+            ChatCompletionChunk
+        ] = await self.openai_chat_client.chat.completions.create(
+            # Azure OpenAI takes the deployment name as the model name
+            model=self.chat_deployment if self.chat_deployment else self.chat_model,
+            messages=contextual_messages,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
+            n=1,
+            stream=True,
+        )
+
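+        # Emit the context (data points and thought steps) before any answer content arrives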
+        yield RetrievalResponseDelta(
+            context=RAGContext(
+                data_points={item.id: item.to_dict() for item in results},
+                thoughts=earlier_thoughts
+                + [
                     ThoughtStep(
                         title="Prompt to generate answer",
                         description=[str(message) for message in contextual_messages],
@@ -141,3 +194,11 @@ async def run(
                 ],
             ),
         )
+
+        async for response_chunk in chat_completion_async_stream:
+            # first response has empty choices and last response has empty content
+            if response_chunk.choices and response_chunk.choices[0].delta.content:
+                yield RetrievalResponseDelta(
+                    delta=Message(content=str(response_chunk.choices[0].delta.content), role=AIChatRoles.ASSISTANT)
+                )
+        return