Skip to content

Commit 5349c33

Browse files
feat: Implement query rewriting in the search API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 7093978 commit 5349c33

12 files changed

+7323
-0
lines changed

src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ async def openai_search_vector_store(
611611
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
612612
"score_threshold": score_threshold,
613613
"mode": search_mode,
614+
"rewrite_query": rewrite_query,
614615
}
615616
# TODO: Add support for ranking_options.ranker
616617

src/llama_stack/providers/utils/memory/vector_store.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
RAGDocument,
3535
VectorStore,
3636
)
37+
from llama_stack_api.inference import (
38+
OpenAIChatCompletionRequestWithExtraBody,
39+
OpenAIUserMessageParam,
40+
)
41+
from llama_stack_api.models import ModelType
3742

3843
log = get_logger(name=__name__, category="providers::utils")
3944

@@ -318,6 +323,11 @@ async def query_chunks(
318323
reranker_params = {"impact_factor": k_value}
319324

320325
query_string = interleaved_content_as_str(query)
326+
327+
# Apply query rewriting if enabled
328+
if params.get("rewrite_query", False):
329+
query_string = await self._rewrite_query_for_search(query_string)
330+
321331
if mode == "keyword":
322332
return await self.index.query_keyword(query_string, k, score_threshold)
323333

@@ -333,3 +343,78 @@ async def query_chunks(
333343
)
334344
else:
335345
return await self.index.query_vector(query_vector, k, score_threshold)
346+
347+
async def _rewrite_query_for_search(self, query: str) -> str:
    """Rewrite the user query to improve vector search performance.

    Asks an available chat-capable LLM to expand the query with synonyms
    and related terms before it is used for retrieval.

    :param query: The original user query
    :returns: The rewritten query optimized for vector search
    :raises RuntimeError: If listing models fails, the chat completion fails,
        or the model returns no usable content.
    :raises ValueError: If no LLM model is available for query rewriting.
    """
    # Get available models and find a suitable chat model
    try:
        models_response = await self.inference_api.routing_table.list_models()
    except Exception as e:
        raise RuntimeError(f"Failed to list available models for query rewriting: {e}") from e

    # Candidate chat models: registered LLMs, minus models whose identifiers
    # suggest they are actually embedding models misclassified as LLMs.
    embedding_model_patterns = ("minilm", "embed", "embedding", "nomic-embed")
    llm_models = [
        m
        for m in models_response.data
        if m.model_type == ModelType.llm
        and not any(pattern in m.identifier.lower() for pattern in embedding_model_patterns)
    ]

    # Prefer local providers first (no credential issues), then cloud ones.
    # NOTE: identifiers like "ollama/..." also contain the bare substring
    # "ollama", so a plain substring match covers both forms — the original
    # special-case for "ollama/" was redundant with its own elif branch.
    provider_priority = ("ollama", "openai", "gemini", "bedrock")
    chat_model = next(
        (
            model.identifier
            for provider in provider_priority
            for model in llm_models
            if provider in model.identifier.lower()
        ),
        None,
    )

    # Fallback: first available LLM model if no preferred provider matched.
    if not chat_model and llm_models:
        chat_model = llm_models[0].identifier

    if not chat_model:
        raise ValueError("No LLM model available for query rewriting")

    rewrite_prompt = f"""Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:

{query}

Improved query:"""

    chat_request = OpenAIChatCompletionRequestWithExtraBody(
        model=chat_model,
        messages=[
            OpenAIUserMessageParam(
                role="user",
                content=rewrite_prompt,
            )
        ],
        max_tokens=100,
    )

    try:
        response = await self.inference_api.openai_chat_completion(chat_request)
    except Exception as e:
        raise RuntimeError(f"Failed to generate rewritten query: {e}") from e

    if not response.choices:
        raise RuntimeError("No response received from LLM model for query rewriting")

    # Bug fix: the completion's `content` may be None (OpenAI-style responses
    # allow it), and calling .strip() on None raised AttributeError. An empty
    # rewrite would also silently break downstream search, so treat both as
    # failures rather than returning a useless query.
    content = response.choices[0].message.content
    rewritten_query = content.strip() if content else ""
    if not rewritten_query:
        raise RuntimeError("Empty response received from LLM model for query rewriting")

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    log.info("Query rewritten: '%s' → '%s'", query, rewritten_query)
    return rewritten_query

0 commit comments

Comments
 (0)