
Commit c0447dd

feat: Actualize query rewrite in search API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 7093978 commit c0447dd

12 files changed (+7300, -0 lines)

src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
Lines changed: 1 addition & 0 deletions

@@ -611,6 +611,7 @@ async def openai_search_vector_store(
             "max_chunks": max_num_results * CHUNK_MULTIPLIER,
             "score_threshold": score_threshold,
             "mode": search_mode,
+            "rewrite_query": rewrite_query,
         }
         # TODO: Add support for ranking_options.ranker
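For context, here is a minimal client-side sketch of how the new flag could be exercised once this lands, assuming a Llama Stack server that exposes the OpenAI-compatible vector store search route at /v1/vector_stores/{vector_store_id}/search and accepts query, max_num_results, and rewrite_query in the request body; the path, port, store ID, and field names are assumptions for illustration, not taken from this commit.

# Hypothetical usage sketch; endpoint path and request fields are assumed.
import asyncio
import httpx

async def search_with_rewrite(base_url: str, vector_store_id: str, query: str) -> dict:
    # rewrite_query=True asks the server to expand the query with synonyms
    # and related terms before running keyword/vector retrieval.
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            f"/v1/vector_stores/{vector_store_id}/search",
            json={"query": query, "max_num_results": 5, "rewrite_query": True},
        )
        resp.raise_for_status()
        return resp.json()

# Example invocation (server URL and store ID are placeholders):
# results = asyncio.run(search_with_rewrite("http://localhost:8321", "vs_123", "laptop battery drains fast"))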

src/llama_stack/providers/utils/memory/vector_store.py
Lines changed: 62 additions & 0 deletions
@@ -34,6 +34,11 @@
     RAGDocument,
     VectorStore,
 )
+from llama_stack_api.inference import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
+from llama_stack_api.models import ModelType
 
 log = get_logger(name=__name__, category="providers::utils")

@@ -318,6 +323,11 @@ async def query_chunks(
             reranker_params = {"impact_factor": k_value}
 
         query_string = interleaved_content_as_str(query)
+
+        # Apply query rewriting if enabled
+        if params.get("rewrite_query", False):
+            query_string = await self._rewrite_query_for_search(query_string)
+
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
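To connect the two files: the params mapping assembled in openai_search_vector_store (first diff) is what query_chunks reads here, so the flag travels as a plain dict entry and rewriting stays off unless a caller opts in. A tiny illustration with made-up values follows; CHUNK_MULTIPLIER and the numbers are assumptions, not taken from this diff.

# Illustrative only: shows the round trip of the new dict entry.
CHUNK_MULTIPLIER = 5          # assumed value for illustration
max_num_results = 4

params = {
    "max_chunks": max_num_results * CHUNK_MULTIPLIER,
    "score_threshold": 0.0,
    "mode": "vector",
    "rewrite_query": True,    # the field added in openai_vector_store_mixin.py
}

# Mirrors the guard added above: rewriting runs only when the flag is truthy,
# so existing callers that never set it keep the old behavior.
assert params.get("rewrite_query", False)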

@@ -333,3 +343,55 @@ async def query_chunks(
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
+
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite the user query to improve vector search performance.
+
+        :param query: The original user query
+        :returns: The rewritten query optimized for vector search
+        """
+        # Get available models and find a suitable chat model
+        try:
+            models_response = await self.inference_api.routing_table.list_models()
+        except Exception as e:
+            raise RuntimeError(f"Failed to list available models for query rewriting: {e}") from e
+
+        chat_model = None
+        # Look for an LLM model (for chat completion)
+        for model in models_response.data:
+            if model.model_type == ModelType.llm:
+                chat_model = model.identifier
+                break
+
+        # If no suitable model found, raise an error
+        if not chat_model:
+            raise ValueError("No LLM model available for query rewriting")
+
+        rewrite_prompt = f"""Rewrite this search query to improve vector search results by expanding it with relevant synonyms and related terms while maintaining the original intent:
+
+{query}
+
+Rewritten query:"""
+
+        chat_request = OpenAIChatCompletionRequestWithExtraBody(
+            model=chat_model,
+            messages=[
+                OpenAIUserMessageParam(
+                    role="user",
+                    content=rewrite_prompt,
+                )
+            ],
+            max_tokens=100,
+        )
+
+        try:
+            response = await self.inference_api.openai_chat_completion(chat_request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
+
+        if response.choices and len(response.choices) > 0:
+            rewritten_query = response.choices[0].message.content.strip()
+            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
+            return rewritten_query
+        else:
+            raise RuntimeError("No response received from LLM model for query rewriting")
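The control flow of the new helper can also be read in isolation. Below is a hypothetical, self-contained re-creation of the same flow, with a stub standing in for the inference API and plain dicts standing in for the typed request/message classes; none of these fake names exist in the repository.

# Standalone sketch of the rewrite flow (all classes here are illustrative fakes).
import asyncio
from dataclasses import dataclass

@dataclass
class FakeMessage:
    content: str

@dataclass
class FakeChoice:
    message: FakeMessage

@dataclass
class FakeChatResponse:
    choices: list

class FakeInferenceAPI:
    """Pretends to be the inference API; always returns a canned expansion."""
    async def openai_chat_completion(self, request: dict) -> FakeChatResponse:
        expanded = "laptop battery draining quickly, fast power drain, short battery life"
        return FakeChatResponse(choices=[FakeChoice(FakeMessage(expanded))])

async def rewrite_query_for_search(inference_api, chat_model: str, query: str) -> str:
    # Same shape as the new method: one user message asking the model to expand
    # the query, then return the first choice's content (or fail loudly).
    prompt = (
        "Rewrite this search query to improve vector search results by expanding it "
        "with relevant synonyms and related terms while maintaining the original intent:\n\n"
        f"{query}\n\nRewritten query:"
    )
    request = {"model": chat_model, "messages": [{"role": "user", "content": prompt}], "max_tokens": 100}
    response = await inference_api.openai_chat_completion(request)
    if response.choices:
        return response.choices[0].message.content.strip()
    raise RuntimeError("No response received from LLM model for query rewriting")

if __name__ == "__main__":
    print(asyncio.run(rewrite_query_for_search(FakeInferenceAPI(), "fake-llm", "laptop battery drains fast")))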
