11from collections .abc import AsyncGenerator
2- from typing import Any
2+ from typing import Any , Final
33
44from openai import AsyncAzureOpenAI , AsyncOpenAI , AsyncStream
55from openai .types .chat import ChatCompletion , ChatCompletionChunk , ChatCompletionMessageParam
@@ -38,12 +38,19 @@ async def generate_search_query(
3838 self , original_user_query : str , past_messages : list [ChatCompletionMessageParam ], query_response_token_limit : int
3939 ) -> tuple [list [ChatCompletionMessageParam ], Any | str | None , list ]:
4040 """Generate an optimized keyword search query based on the chat history and the last question"""
41+
42+ tools = build_search_function ()
43+ tool_choice : Final = "auto"
44+
4145 query_messages : list [ChatCompletionMessageParam ] = build_messages (
4246 model = self .chat_model ,
4347 system_prompt = self .query_prompt_template ,
48+ few_shots = self .query_fewshots ,
4449 new_user_content = original_user_query ,
4550 past_messages = past_messages ,
46- max_tokens = self .chat_token_limit - query_response_token_limit , # TODO: count functions
51+ max_tokens = self .chat_token_limit - query_response_token_limit ,
52+ tools = tools ,
53+ tool_choice = tool_choice ,
4754 fallback_to_default = True ,
4855 )
4956
@@ -54,8 +61,8 @@ async def generate_search_query(
5461 temperature = 0.0 , # Minimize creativity for search query generation
5562 max_tokens = query_response_token_limit , # Setting too low risks malformed JSON, too high risks performance
5663 n = 1 ,
57- tools = build_search_function () ,
58- tool_choice = "auto" ,
64+ tools = tools ,
65+ tool_choice = tool_choice ,
5966 )
6067
6168 query_text , filters = extract_search_arguments (original_user_query , chat_completion )
0 commit comments