1- import pathlib
2- from abc import ABC , abstractmethod
31from collections .abc import AsyncGenerator
42from typing import Any
53
64from openai import AsyncAzureOpenAI , AsyncOpenAI , AsyncStream
75from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessageParam
86from openai_messages_token_helper import build_messages , get_token_limit
9- from pydantic import BaseModel
107
118from fastapi_app .api_models import (
129 AIChatRoles ,
1815)
1916from fastapi_app .postgres_models import Item
2017from fastapi_app .postgres_searcher import PostgresSearcher
21-
22-
class ChatParams(BaseModel):
    """Validated parameters for a single RAG chat turn.

    Combines client-supplied overrides with defaults; produced by
    ``RAGChatBase.get_params`` and consumed by the concrete chat flows.
    """

    # Maximum number of search results to retrieve.
    top: int = 3
    # Sampling temperature for the chat completion call.
    temperature: float = 0.3
    # Token budget for the generated answer.
    response_token_limit: int = 1024
    # Which retrieval modes are active for this request.
    enable_text_search: bool
    enable_vector_search: bool
    # The latest user message, separated from the prior conversation.
    original_user_query: str
    past_messages: list[ChatCompletionMessageParam]
    # System prompt template used to generate the answer.
    prompt_template: str
33-
class RAGChatBase(ABC):
    """Abstract base class for RAG chat flows.

    Loads the shared prompt templates once at class-definition time and
    normalizes per-request override dicts into a :class:`ChatParams`.
    Concrete subclasses implement retrieval and the (streaming and
    non-streaming) chat completion calls.
    """

    current_dir = pathlib.Path(__file__).parent
    # read_text() opens and closes the file in one call; the previous
    # open(...).read() pattern leaked the file handle.
    query_prompt_template = (current_dir / "prompts/query.txt").read_text()
    answer_prompt_template = (current_dir / "prompts/answer.txt").read_text()

    def get_params(self, messages: list[ChatCompletionMessageParam], overrides: dict[str, Any]) -> ChatParams:
        """Build a ChatParams object from the chat history and request overrides.

        Args:
            messages: Full chat history; the last entry must be the user's
                current query with plain-string content.
            overrides: Per-request settings (``top``, ``temperature``,
                ``response_token_limit``, ``retrieval_mode``, ``prompt_template``).

        Returns:
            A populated ChatParams.

        Raises:
            ValueError: If the most recent message content is not a string.
        """
        top: int = overrides.get("top", 3)
        temperature: float = overrides.get("temperature", 0.3)
        # Previously hard-coded to 1024; now overridable while keeping the
        # same default, so existing callers are unaffected.
        response_token_limit: int = overrides.get("response_token_limit", 1024)
        prompt_template = overrides.get("prompt_template") or self.answer_prompt_template

        # A missing retrieval_mode (None) enables both search modes.
        enable_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
        enable_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]

        original_user_query = messages[-1]["content"]
        if not isinstance(original_user_query, str):
            raise ValueError("The most recent message content must be a string.")
        past_messages = messages[:-1]

        return ChatParams(
            top=top,
            temperature=temperature,
            response_token_limit=response_token_limit,
            prompt_template=prompt_template,
            enable_text_search=enable_text_search,
            enable_vector_search=enable_vector_search,
            original_user_query=original_user_query,
            past_messages=past_messages,
        )

    @abstractmethod
    async def run(
        self,
        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any] = {},  # noqa: B006 - abstract; body never mutates it
    ) -> RetrievalResponse:
        """Run one non-streaming chat turn and return the full response."""
        raise NotImplementedError

    @abstractmethod
    async def retrieve_and_build_context(
        self,
        chat_params: ChatParams,
        *args,
        **kwargs,
    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
        """Retrieve matching items and build the contextual message list."""
        raise NotImplementedError

    @abstractmethod
    async def run_stream(
        self,
        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any] = {},  # noqa: B006 - abstract; body never mutates it
    ) -> AsyncGenerator[RetrievalResponseDelta, None]:
        """Run one streaming chat turn, yielding response deltas."""
        raise NotImplementedError
        # Unreachable: the bare yield marks this abstract method as an async
        # generator so type checkers accept the AsyncGenerator return type.
        if False:
            yield 0
18+ from fastapi_app .rag_base import ChatParams , RAGChatBase
9019
9120
9221class SimpleRAGChat (RAGChatBase ):
@@ -104,7 +33,7 @@ def __init__(
10433 self .chat_deployment = chat_deployment
10534 self .chat_token_limit = get_token_limit (chat_model , default_to_minimum = True )
10635
107- async def retreive_and_build_context (
36+ async def retrieve_and_build_context (
10837 self , chat_params : ChatParams
10938 ) -> tuple [list [ChatCompletionMessageParam ], list [Item ]]:
11039 """Retrieve relevant items from the database and build a context for the chat model."""
@@ -138,9 +67,7 @@ async def run(
13867 ) -> RetrievalResponse :
13968 chat_params = self .get_params (messages , overrides )
14069
141- # Retrieve relevant items from the database
142- # Generate a contextual and content specific answer using the search results and chat history
143- contextual_messages , results = await self .retreive_and_build_context (chat_params = chat_params )
70+ contextual_messages , results = await self .retrieve_and_build_context (chat_params = chat_params )
14471
14572 chat_completion_response : ChatCompletion = await self .openai_chat_client .chat .completions .create (
14673 # Azure OpenAI takes the deployment name as the model name
@@ -192,9 +119,7 @@ async def run_stream(
192119 ) -> AsyncGenerator [RetrievalResponseDelta , None ]:
193120 chat_params = self .get_params (messages , overrides )
194121
195- # Retrieve relevant items from the database
196- # Generate a contextual and content specific answer using the search results and chat history
197- contextual_messages , results = await self .retreive_and_build_context (chat_params = chat_params )
122+ contextual_messages , results = await self .retrieve_and_build_context (chat_params = chat_params )
198123
199124 chat_completion_async_stream : AsyncStream [
200125 ChatCompletionChunk
0 commit comments