import pathlib
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any

from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from fastapi_app.api_models import ChatRequestOverrides, RetrievalResponse, RetrievalResponseDelta
from fastapi_app.postgres_models import Item
149
1510
class ChatParams(ChatRequestOverrides):
    """Fully-resolved parameters for one chat turn.

    Extends the request-level overrides (top, temperature, retrieval_mode,
    use_advanced_flow, prompt_template) with values derived from the message
    history and server-side defaults.
    """

    # Resolved prompt template (override or the class default answer prompt).
    prompt_template: str
    # Maximum tokens allowed for the model's answer.
    response_token_limit: int = 1024
    # Whether full-text / vector search are enabled for this turn.
    enable_text_search: bool
    enable_vector_search: bool
    # The latest user message content, guaranteed to be a string.
    original_user_query: str
    # All messages before the latest one.
    past_messages: list[ChatCompletionMessageParam]
2619
class RAGChatBase(ABC):
    """Abstract base for RAG chat flows: loads shared prompts and defines the run interface."""

    current_dir = pathlib.Path(__file__).parent
    # read_text() opens and closes the file; bare open(...).read() leaked the handle.
    query_prompt_template = (current_dir / "prompts/query.txt").read_text()
    answer_prompt_template = (current_dir / "prompts/answer.txt").read_text()
3124
32- def get_params (self , messages : list [ChatCompletionMessageParam ], overrides : dict [str , Any ]) -> ChatParams :
33- top : int = overrides .get ("top" , 3 )
34- temperature : float = overrides .get ("temperature" , 0.3 )
25+ def get_params (self , messages : list [ChatCompletionMessageParam ], overrides : ChatRequestOverrides ) -> ChatParams :
3526 response_token_limit = 1024
36- prompt_template = overrides .get ( " prompt_template" ) or self .answer_prompt_template
27+ prompt_template = overrides .prompt_template or self .answer_prompt_template
3728
38- enable_text_search = overrides .get ( " retrieval_mode" ) in ["text" , "hybrid" , None ]
39- enable_vector_search = overrides .get ( " retrieval_mode" ) in ["vectors" , "hybrid" , None ]
29+ enable_text_search = overrides .retrieval_mode in ["text" , "hybrid" , None ]
30+ enable_vector_search = overrides .retrieval_mode in ["vectors" , "hybrid" , None ]
4031
4132 original_user_query = messages [- 1 ]["content" ]
4233 if not isinstance (original_user_query , str ):
4334 raise ValueError ("The most recent message content must be a string." )
4435 past_messages = messages [:- 1 ]
4536
4637 return ChatParams (
47- top = top ,
48- temperature = temperature ,
38+ top = overrides .top ,
39+ temperature = overrides .temperature ,
40+ retrieval_mode = overrides .retrieval_mode ,
41+ use_advanced_flow = overrides .use_advanced_flow ,
4942 response_token_limit = response_token_limit ,
5043 prompt_template = prompt_template ,
5144 enable_text_search = enable_text_search ,
@@ -67,15 +60,15 @@ async def retrieve_and_build_context(
6760 async def run (
6861 self ,
6962 messages : list [ChatCompletionMessageParam ],
70- overrides : dict [ str , Any ] = {} ,
63+ overrides : ChatRequestOverrides ,
7164 ) -> RetrievalResponse :
7265 raise NotImplementedError
7366
7467 @abstractmethod
7568 async def run_stream (
7669 self ,
7770 messages : list [ChatCompletionMessageParam ],
78- overrides : dict [ str , Any ] = {} ,
71+ overrides : ChatRequestOverrides ,
7972 ) -> AsyncGenerator [RetrievalResponseDelta , None ]:
8073 raise NotImplementedError
# NOTE(review): removed non-code scrape residue from the original paste
# ("if False :" with no body and a "0 commit comments" page fragment).