
Commit 1c93410

feat: Actualize query rewrite in search API
adding query expansion model to vector store config

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Parent: 7093978

14 files changed, +7333 −0 lines (diffs for four source files shown below)

src/llama_stack/core/datatypes.py

Lines changed: 8 additions & 0 deletions
@@ -376,6 +376,14 @@ class VectorStoresConfig(BaseModel):
         default=None,
         description="Default embedding model configuration for vector stores.",
     )
+    default_query_expansion_model: QualifiedModel | None = Field(
+        default=None,
+        description="Default LLM model for query expansion/rewriting in vector search.",
+    )
+    query_expansion_prompt: str = Field(
+        default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
+        description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
+    )
 
 
 class SafetyConfig(BaseModel):
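
For reference, a stack configuration could populate these fields roughly as follows. This is a minimal sketch: it assumes QualifiedModel carries provider_id and model_id (consistent with how vector_store.py formats the model name below), and the provider/model ids are placeholders, not values from this commit.

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig

config = VectorStoresConfig(
    default_query_expansion_model=QualifiedModel(
        provider_id="ollama",     # placeholder provider id
        model_id="llama3.2:3b",   # placeholder model id
    ),
    # query_expansion_prompt falls back to the default template above;
    # any override must keep the {query} placeholder.
)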

src/llama_stack/core/routers/vector_io.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ async def query_chunks(
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
         provider = await self.routing_table.get_provider_impl(vector_store_id)
+
+        # Ensure params dict exists and add vector_stores_config for query rewriting
+        if params is None:
+            params = {}
+        params["vector_stores_config"] = self.vector_stores_config
+
         return await provider.query_chunks(vector_store_id, query, params)
 
     # OpenAI Vector Stores API endpoints
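
Because the router mutates the caller's params dict in place before delegating, the provider sees both the caller's search flags and the stack-level vector stores config. A rough sketch of the call path (the store id and query are hypothetical):

# Hypothetical caller of VectorIORouter.query_chunks.
params = {"max_chunks": 5, "rewrite_query": True}
response = await router.query_chunks(
    vector_store_id="vs_123",  # hypothetical store id
    query="how do I rotate my API keys?",
    params=params,
)
# The router injected the stack config before delegating to the provider:
assert "vector_stores_config" in params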

src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py

Lines changed: 1 addition & 0 deletions
@@ -611,6 +611,7 @@ async def openai_search_vector_store(
             "max_chunks": max_num_results * CHUNK_MULTIPLIER,
             "score_threshold": score_threshold,
             "mode": search_mode,
+            "rewrite_query": rewrite_query,
         }
         # TODO: Add support for ranking_options.ranker
 
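
With the flag threaded through here, the OpenAI-compatible search endpoint can trigger rewriting per request. A hedged example of what a client call might look like; the exact client method surface is an assumption, not part of this diff:

# Assumes an OpenAI-compatible llama-stack client exposing vector store search.
results = await client.vector_stores.search(
    vector_store_id="vs_123",   # hypothetical store id
    query="rotate API keys",
    max_num_results=5,
    rewrite_query=True,         # ends up as params["rewrite_query"] above
)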

src/llama_stack/providers/utils/memory/vector_store.py

Lines changed: 81 additions & 0 deletions
@@ -17,6 +17,7 @@
 from numpy.typing import NDArray
 from pydantic import BaseModel
 
+from llama_stack.core.datatypes import VectorStoresConfig
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -34,6 +35,11 @@
     RAGDocument,
     VectorStore,
 )
+from llama_stack_api.inference import (
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
+from llama_stack_api.models import ModelType
 
 log = get_logger(name=__name__, category="providers::utils")
 
@@ -262,6 +268,7 @@ class VectorStoreWithIndex:
     vector_store: VectorStore
     index: EmbeddingIndex
     inference_api: Api.inference
+    vector_stores_config: VectorStoresConfig | None = None
 
     async def insert_chunks(
         self,
@@ -296,6 +303,11 @@ async def query_chunks(
     ) -> QueryChunksResponse:
         if params is None:
             params = {}
+
+        # Extract configuration if provided by router
+        if "vector_stores_config" in params:
+            self.vector_stores_config = params["vector_stores_config"]
+
         k = params.get("max_chunks", 3)
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
@@ -318,6 +330,11 @@ async def query_chunks(
             reranker_params = {"impact_factor": k_value}
 
         query_string = interleaved_content_as_str(query)
+
+        # Apply query rewriting if enabled
+        if params.get("rewrite_query", False):
+            query_string = await self._rewrite_query_for_search(query_string)
+
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
 
@@ -333,3 +350,67 @@ async def query_chunks(
             )
         else:
             return await self.index.query_vector(query_vector, k, score_threshold)
+
+    async def _rewrite_query_for_search(self, query: str) -> str:
+        """Rewrite the user query to improve vector search performance.
+
+        :param query: The original user query
+        :returns: The rewritten query optimized for vector search
+        """
+        # Check if query expansion model is configured
+        if not self.vector_stores_config or not self.vector_stores_config.default_query_expansion_model:
+            raise ValueError("No default_query_expansion_model configured for query rewriting")
+
+        # Use the configured model
+        expansion_model = self.vector_stores_config.default_query_expansion_model
+        chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
+
+        # Validate that the model is available and is an LLM
+        try:
+            models_response = await self.inference_api.routing_table.list_models()
+        except Exception as e:
+            raise RuntimeError(f"Failed to list available models for validation: {e}") from e
+
+        model_found = False
+        for model in models_response.data:
+            if model.identifier == chat_model:
+                if model.model_type != ModelType.llm:
+                    raise ValueError(
+                        f"Configured query expansion model '{chat_model}' is not an LLM model "
+                        f"(found type: {model.model_type}). Query rewriting requires an LLM model."
+                    )
+                model_found = True
+                break
+
+        if not model_found:
+            available_llm_models = [m.identifier for m in models_response.data if m.model_type == ModelType.llm]
+            raise ValueError(
+                f"Configured query expansion model '{chat_model}' is not available. "
+                f"Available LLM models: {available_llm_models}"
+            )
+
+        # Use the configured prompt (has a default value)
+        rewrite_prompt = self.vector_stores_config.query_expansion_prompt.format(query=query)
+
+        chat_request = OpenAIChatCompletionRequestWithExtraBody(
+            model=chat_model,
+            messages=[
+                OpenAIUserMessageParam(
+                    role="user",
+                    content=rewrite_prompt,
+                )
+            ],
+            max_tokens=100,
+        )
+
+        try:
+            response = await self.inference_api.openai_chat_completion(chat_request)
+        except Exception as e:
+            raise RuntimeError(f"Failed to generate rewritten query: {e}") from e
+
+        if response.choices and len(response.choices) > 0:
+            rewritten_query = response.choices[0].message.content.strip()
+            log.info(f"Query rewritten: '{query}' → '{rewritten_query}'")
+            return rewritten_query
+        else:
+            raise RuntimeError("No response received from LLM model for query rewriting")