 from numpy.typing import NDArray
 from pydantic import BaseModel

+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ...
     RAGDocument,
     VectorStore,
 )
+from llama_stack_api.inference import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
+from llama_stack_api.models import ModelType

 log = get_logger(name=__name__, category="providers::utils")

@@ -262,6 +268,7 @@ class VectorStoreWithIndex:
     vector_store: VectorStore
     index: EmbeddingIndex
     inference_api: Api.inference
+    vector_stores_config: VectorStoresConfig | None = None

     async def insert_chunks(
         self,
@@ -296,6 +303,11 @@ async def query_chunks(
     ) -> QueryChunksResponse:
         if params is None:
             params = {}
+
+        # Extract configuration if provided by router
+        if "vector_stores_config" in params:
+            self.vector_stores_config = params["vector_stores_config"]
+
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -318,6 +330,11 @@ async def query_chunks(
             reranker_params = {"impact_factor": k_value}

         query_string = interleaved_content_as_str(query)
+
+        # Apply query rewriting if enabled
+        if params.get("rewrite_query", False):
+            query_string = await self._rewrite_query_for_search(query_string)
+
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)

@@ -333,3 +350,67 @@ async def query_chunks(
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
+
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite the user query to improve vector search performance.
+
+        :param query: The original user query
+        :returns: The rewritten query, optimized for vector search
+        """
+        # Check if a query expansion model is configured
+        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError("No default_query_expansion_model configured for query rewriting")
+
+        # Use the configured model
+        expansion_model = self.vector_stores_config.default_query_expansion_model
+        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
+
+        # Validate that the model is available and is an LLM
+        try:
+            models_response = await self.inference_api.routing_table.list_models()
+        except Exception as e:
+            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
+
+        model_found = False
+        for model in models_response.data:
+            if model.identifier == chat_model:
+                if model.model_type != ModelType.llm:
+                    raise ValueError(
+                        f"Configured query expansion model '{chat_model}' is not an LLM "
+                        f"(found type: {model.model_type}); query rewriting requires an LLM."
+                    )
+                model_found = True
+                break
+
+        if not model_found:
+            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
+            raise ValueError(
+                f"Configured query expansion model '{chat_model}' is not available. "
+                f"Available LLM models: {available_llm_models}"
+            )
+
+        # Use the configured prompt (has a default value)
+        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
+
+        chat_request = OpenAIChatCompletionRequestWithExtraBody(
+            model=chat_model,
+            messages=[
+                OpenAIUserMessageParam(
+                    role="user",
+                    content=rewrite_prompt,
+                )
+            ],
+            max_tokens=100,
+        )
+
+        try:
+            response = await self.inference_api.openai_chat_completion(chat_request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
+
+        # Guard against an empty choice list or None content before calling .strip()
+        if response.choices and response.choices[0].message.content:
+            rewritten_query = response.choices[0].message.content.strip()
+            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
+            return rewritten_query
+        else:
+            raise RuntimeError("No response received from LLM for query rewriting")
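
For context, a minimal caller-side sketch of the new path. This is hypothetical, not part of the diff: the `VectorStoresConfig` construction and its field values are illustrative, assuming the nested model reference exposes the `provider_id`/`model_id` attributes read in `_rewrite_query_for_search` and that a dict is accepted for that pydantic field.

# Hypothetical usage sketch; exercises the rewrite_query flag and the
# router-injected vector_stores_config param from the diff above.
from llama_stack.core.datatypes import VectorStoresConfig

async def search_with_rewrite(store_with_index) -> None:
    # Field names follow the diff; the model reference must resolve to a
    # registered LLM identified as "<provider_id>/<model_id>", and the
    # prompt must contain a "{query}" placeholder for .format(query=query).
    config = VectorStoresConfig(
        default_query_expansion_model={"provider_id": "ollama", "model_id": "llama3.2:3b"},
        query_expansion_prompt=(
            "Rewrite this query to improve vector search recall. "
            "Return only the rewritten query.\nQuery: {query}"
        ),
    )
    response = await store_with_index.query_chunks(
        "how do i rotate my api keys?",
        {
            "vector_stores_config": config,  # normally injected by the router
            "rewrite_query": True,           # opt in to _rewrite_query_for_search()
            "max_chunks": 5,
            "mode": "vector",
        },
    )
    for chunk in response.chunks:
        print(chunk)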