26 changes: 26 additions & 0 deletions src/llama_stack/core/datatypes.py
@@ -18,6 +18,7 @@
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT
from llama_stack_api import (
Api,
Benchmark,
@@ -365,6 +366,27 @@ class QualifiedModel(BaseModel):
model_id: str


class RewriteQueryParams(BaseModel):
"""Parameters for query rewriting/expansion."""

model: QualifiedModel | None = Field(
default=None,
description="LLM model for query rewriting/expansion in vector search.",
)
prompt: str = Field(
default=DEFAULT_QUERY_REWRITE_PROMPT,
description="Prompt template for query rewriting. Use {query} as placeholder for the original query.",
)
max_tokens: int = Field(
default=100,
description="Maximum number of tokens for query expansion responses.",
)
temperature: float = Field(
default=0.3,
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
)


class VectorStoresConfig(BaseModel):
"""Configuration for vector stores in the stack."""

@@ -376,6 +398,10 @@ class VectorStoresConfig(BaseModel):
default=None,
description="Default embedding model configuration for vector stores.",
)
rewrite_query_params: RewriteQueryParams | None = Field(
default=None,
description="Parameters for query rewriting/expansion. None disables query rewriting.",
)


class SafetyConfig(BaseModel):
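For orientation, a minimal sketch of the new config surface in Python; the provider and model ids are hypothetical placeholders, and all other VectorStoresConfig fields are assumed to keep their defaults:

from llama_stack.core.datatypes import QualifiedModel, RewriteQueryParams, VectorStoresConfig

# Hypothetical example: enable query rewriting with a small LLM. Because only
# the model is set, the built-in prompt, max_tokens=100, and temperature=0.3 apply.
vector_stores = VectorStoresConfig(
    default_embedding_model=QualifiedModel(
        provider_id="sentence-transformers", model_id="all-MiniLM-L6-v2"
    ),
    rewrite_query_params=RewriteQueryParams(
        model=QualifiedModel(provider_id="ollama", model_id="llama3.2:3b"),
    ),
)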
67 changes: 53 additions & 14 deletions src/llama_stack/core/stack.py
@@ -14,7 +14,7 @@
import yaml

from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.datatypes import Provider, QualifiedModel, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
@@ -144,35 +144,68 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
if vector_stores_config is None:
return

default_embedding_model = vector_stores_config.default_embedding_model
if default_embedding_model is None:
return
# Validate default embedding model
if vector_stores_config.default_embedding_model is not None:
await _validate_embedding_model(vector_stores_config.default_embedding_model, impls)

# Validate default rewrite query model
if vector_stores_config.rewrite_query_params and vector_stores_config.rewrite_query_params.model:
await _validate_rewrite_query_model(vector_stores_config.rewrite_query_params.model, impls)


provider_id = default_embedding_model.provider_id
model_id = default_embedding_model.model_id
default_model_id = f"{provider_id}/{model_id}"
async def _validate_embedding_model(embedding_model: QualifiedModel, impls: dict[Api, Any]) -> None:
"""Validate that an embedding model exists and has required metadata."""
provider_id = embedding_model.provider_id
model_id = embedding_model.model_id
model_identifier = f"{provider_id}/{model_id}"

if Api.models not in impls:
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
raise ValueError(f"Models API is not available but vector_stores config requires model '{model_identifier}'")

models_impl = impls[Api.models]
response = await models_impl.list_models()
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}

default_model = models_list.get(default_model_id)
if default_model is None:
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
model = models_list.get(model_identifier)
if model is None:
raise ValueError(
f"Embedding model '{model_identifier}' not found. Available embedding models: {list(models_list.keys())}"
)

embedding_dimension = default_model.metadata.get("embedding_dimension")
embedding_dimension = model.metadata.get("embedding_dimension")
if embedding_dimension is None:
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
raise ValueError(f"Embedding model '{model_identifier}' is missing 'embedding_dimension' in metadata")

try:
int(embedding_dimension)
except ValueError as err:
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err

logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
logger.debug(f"Validated embedding model: {model_identifier} (dimension: {embedding_dimension})")


async def _validate_rewrite_query_model(rewrite_query_model: QualifiedModel, impls: dict[Api, Any]) -> None:
"""Validate that a rewrite query model exists and is accessible."""
provider_id = rewrite_query_model.provider_id
model_id = rewrite_query_model.model_id
model_identifier = f"{provider_id}/{model_id}"

if Api.models not in impls:
raise ValueError(
f"Models API is not available but vector_stores config requires rewrite query model '{model_identifier}'"
)

models_impl = impls[Api.models]
response = await models_impl.list_models()
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}

model = llm_models_list.get(model_identifier)
if model is None:
raise ValueError(
f"Rewrite query model '{model_identifier}' not found. Available LLM models: {list(llm_models_list.keys())}"
)

logger.debug(f"Validated rewrite query model: {model_identifier}")


async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -437,6 +470,12 @@ async def initialize(self):
await refresh_registry_once(impls)
await validate_vector_stores_config(self.run_config.vector_stores, impls)
await validate_safety_config(self.run_config.safety, impls)

# Set global query expansion configuration from stack config
Review comment from @franciscojavierarceo (Collaborator, Author), Nov 21, 2025:
importing here to avoid importing numpy

from llama_stack.providers.utils.memory.rewrite_query_config import set_default_rewrite_query_config

set_default_rewrite_query_config(self.run_config.vector_stores)

self.impls = impls

def create_registry_refresh_task(self):
4 changes: 4 additions & 0 deletions src/llama_stack/providers/utils/memory/__init__.py
@@ -3,3 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .constants import DEFAULT_QUERY_REWRITE_PROMPT

__all__ = ["DEFAULT_QUERY_REWRITE_PROMPT"]
8 changes: 8 additions & 0 deletions src/llama_stack/providers/utils/memory/constants.py
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Default prompt template for query rewriting in vector search
DEFAULT_QUERY_REWRITE_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"
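As a quick illustration (hypothetical query string), the template is filled with str.format, substituting the user query for the {query} placeholder:

# Produces the expansion prompt with the original query embedded between the
# instruction and the trailing "Improved query:" cue.
prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="jwt key rotation")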
@@ -607,11 +607,14 @@ async def openai_search_vector_store(
if ranking_options and ranking_options.score_threshold is not None
else 0.0
)

params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
"rewrite_query": rewrite_query,
}

# TODO: Add support for ranking_options.ranker

response = await self.query_chunks(
38 changes: 38 additions & 0 deletions src/llama_stack/providers/utils/memory/rewrite_query_config.py
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT

# Global configuration for query rewriting - set during stack startup
_DEFAULT_REWRITE_QUERY_MODEL: QualifiedModel | None = None
_DEFAULT_REWRITE_QUERY_MAX_TOKENS: int = 100
_DEFAULT_REWRITE_QUERY_TEMPERATURE: float = 0.3
_REWRITE_QUERY_PROMPT_OVERRIDE: str | None = None


def set_default_rewrite_query_config(vector_stores_config: VectorStoresConfig | None):
"""Set the global default query rewriting configuration from stack config."""
global \
_DEFAULT_REWRITE_QUERY_MODEL, \
_REWRITE_QUERY_PROMPT_OVERRIDE, \
_DEFAULT_REWRITE_QUERY_MAX_TOKENS, \
_DEFAULT_REWRITE_QUERY_TEMPERATURE
if vector_stores_config and vector_stores_config.rewrite_query_params:
params = vector_stores_config.rewrite_query_params
_DEFAULT_REWRITE_QUERY_MODEL = params.model
# Only set override if user provided a custom prompt different from default
if params.prompt != DEFAULT_QUERY_REWRITE_PROMPT:
_REWRITE_QUERY_PROMPT_OVERRIDE = params.prompt
else:
_REWRITE_QUERY_PROMPT_OVERRIDE = None
_DEFAULT_REWRITE_QUERY_MAX_TOKENS = params.max_tokens
_DEFAULT_REWRITE_QUERY_TEMPERATURE = params.temperature
else:
_DEFAULT_REWRITE_QUERY_MODEL = None
_REWRITE_QUERY_PROMPT_OVERRIDE = None
_DEFAULT_REWRITE_QUERY_MAX_TOKENS = 100
_DEFAULT_REWRITE_QUERY_TEMPERATURE = 0.3
44 changes: 44 additions & 0 deletions src/llama_stack/providers/utils/memory/vector_store.py
@@ -29,6 +29,7 @@
Chunk,
ChunkMetadata,
InterleavedContent,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
QueryChunksResponse,
RAGDocument,
@@ -37,6 +38,9 @@

log = get_logger(name=__name__, category="providers::utils")

from llama_stack.providers.utils.memory import rewrite_query_config
from llama_stack.providers.utils.memory.constants import DEFAULT_QUERY_REWRITE_PROMPT


class ChunkForDeletion(BaseModel):
"""Information needed to delete a chunk from a vector store.
@@ -289,13 +293,48 @@ async def insert_chunks(
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
await self.index.add_chunks(chunks, embeddings)

async def _rewrite_query_for_file_search(self, query: str) -> str:
"""Rewrite a search query using the globally configured LLM model for better retrieval results."""
if not rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL:
raise ValueError(
"Query rewriting requested but not configured. Please configure rewrite_query_params.model in vector_stores config."
)

model_id = f"{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.provider_id}/{rewrite_query_config._DEFAULT_REWRITE_QUERY_MODEL.model_id}"

# Use custom prompt from config if provided, otherwise use built-in default
# Users only need to configure the model - prompt is automatic with optional override
if rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE:
# Custom prompt from config - format if it contains {query} placeholder
prompt = (
rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE.format(query=query)
if "{query}" in rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
else rewrite_query_config._REWRITE_QUERY_PROMPT_OVERRIDE
)
else:
# Use built-in default prompt and format with query
prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query=query)

request = OpenAIChatCompletionRequestWithExtraBody(
model=model_id,
messages=[{"role": "user", "content": prompt}],
max_tokens=rewrite_query_config._DEFAULT_REWRITE_QUERY_MAX_TOKENS,
temperature=rewrite_query_config._DEFAULT_REWRITE_QUERY_TEMPERATURE,
)

response = await self.inference_api.openai_chat_completion(request)
rewritten_query = response.choices[0].message.content.strip()
log.debug(f"Query rewritten: '{query}' → '{rewritten_query}'")
return rewritten_query

async def query_chunks(
self,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
if params is None:
params = {}

k = params.get("max_chunks", 3)
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
@@ -318,6 +357,11 @@ async def query_chunks(
reranker_params = {"impact_factor": k_value}

query_string = interleaved_content_as_str(query)

# Apply query rewriting if enabled and model is configured
if params.get("rewrite_query", False):
query_string = await self._rewrite_query_for_file_search(query_string)

if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)

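To close the loop, a minimal caller-side sketch (the index object and query are hypothetical), mirroring the params assembled in openai_search_vector_store: with rewrite_query set to True, query_chunks first expands the query via _rewrite_query_for_file_search, which requires rewrite_query_params.model to have been configured at startup.

# Hypothetical caller; raises ValueError if no rewrite model was configured.
response = await index.query_chunks(
    query="how do I rotate API keys?",
    params={
        "max_chunks": 10,
        "score_threshold": 0.0,
        "rewrite_query": True,
    },
)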