Skip to content

Commit 376f16e

Browse files
added query expansion model to extra_body
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 21f2085 commit 376f16e

File tree

4 files changed

+130
-12
lines changed

4 files changed

+130
-12
lines changed

src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,11 @@ async def openai_create_vector_store(
379379
f"Using embedding config from extra_body: model='{embedding_model}', dimension={embedding_dimension}"
380380
)
381381

382+
# Extract query expansion model from extra_body if provided
383+
query_expansion_model = extra_body.get("query_expansion_model")
384+
if query_expansion_model:
385+
logger.debug(f"Using per-store query expansion model: {query_expansion_model}")
386+
382387
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
383388
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
384389
# Derive the canonical vector_store_id (allow override, else generate)
@@ -402,6 +407,7 @@ async def openai_create_vector_store(
402407
provider_id=provider_id,
403408
provider_resource_id=vector_store_id,
404409
vector_store_name=params.name,
410+
query_expansion_model=query_expansion_model,
405411
)
406412
await self.register_vector_store(vector_store)
407413

@@ -607,12 +613,14 @@ async def openai_search_vector_store(
607613
if ranking_options and ranking_options.score_threshold is not None
608614
else 0.0
609615
)
616+
610617
params = {
611618
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
612619
"score_threshold": score_threshold,
613620
"mode": search_mode,
614621
"rewrite_query": rewrite_query,
615622
}
623+
616624
# Add vector_stores_config if available (for query rewriting)
617625
if hasattr(self, "vector_stores_config"):
618626
params["vector_stores_config"] = self.vector_stores_config

src/llama_stack/providers/utils/memory/vector_store.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from numpy.typing import NDArray
1818
from pydantic import BaseModel
1919

20-
from llama_stack.core.datatypes import VectorStoresConfig
20+
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
2121
from llama_stack.log import get_logger
2222
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
2323
from llama_stack.providers.utils.inference.prompt_adapter import (
@@ -366,18 +366,33 @@ async def _rewrite_query_for_search(self, query: str) -> str:
366366
:param query: The original user query
367367
:returns: The rewritten query optimized for vector search
368368
"""
369-
# Check if query expansion model is configured
370-
if not self.vector_stores_config:
371-
raise ValueError(
372-
f"No vector_stores_config found! self.vector_stores_config is: {self.vector_stores_config}"
373-
)
374-
if not self.vector_stores_config.default_query_expansion_model:
375-
raise ValueError(
376-
f"No default_query_expansion_model configured! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
377-
)
369+
expansion_model = None
370+
371+
# Check for per-store query expansion model first
372+
if self.vector_store.query_expansion_model:
373+
# Parse the model string into provider_id and model_id
374+
model_parts = self.vector_store.query_expansion_model.split("/", 1)
375+
if len(model_parts) == 2:
376+
expansion_model = QualifiedModel(provider_id=model_parts[0], model_id=model_parts[1])
377+
log.debug(f"Using per-store query expansion model: {expansion_model}")
378+
else:
379+
log.warning(
380+
f"Invalid query_expansion_model format: {self.vector_store.query_expansion_model}. Expected 'provider_id/model_id'"
381+
)
382+
383+
# Fall back to global default if no per-store model
384+
if not expansion_model:
385+
if not self.vector_stores_config:
386+
raise ValueError(
387+
f"No vector_stores_config found and no per-store query_expansion_model! self.vector_stores_config is: {self.vector_stores_config}"
388+
)
389+
if not self.vector_stores_config.default_query_expansion_model:
390+
raise ValueError(
391+
f"No default_query_expansion_model configured and no per-store query_expansion_model! vector_stores_config: {self.vector_stores_config}, default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
392+
)
393+
expansion_model = self.vector_stores_config.default_query_expansion_model
394+
log.debug(f"Using global default query expansion model: {expansion_model}")
378395

379-
# Use the configured model
380-
expansion_model = self.vector_stores_config.default_query_expansion_model
381396
chat_model = f"{expansion_model.provider_id}/{expansion_model.model_id}"
382397

383398
# Validate that the model is available and is an LLM

src/llama_stack_api/vector_stores.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class VectorStore(Resource):
2525
embedding_model: str
2626
embedding_dimension: int
2727
vector_store_name: str | None = None
28+
query_expansion_model: str | None = None
2829

2930
@property
3031
def vector_store_id(self) -> str:

tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,3 +1230,97 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
12301230

12311231
with pytest.raises(ValueError, match="embedding_model is required"):
12321232
await vector_io_adapter.openai_create_vector_store(params)
1233+
1234+
1235+
async def test_query_expansion_functionality(vector_io_adapter):
1236+
"""Test query expansion with per-store models, global defaults, and error validation."""
1237+
from unittest.mock import MagicMock
1238+
1239+
from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
1240+
from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
1241+
from llama_stack_api.models import Model, ModelType
1242+
1243+
vector_io_adapter.register_vector_store = AsyncMock()
1244+
vector_io_adapter.__provider_id__ = "test_provider"
1245+
1246+
# Test 1: Per-store model usage
1247+
params = OpenAICreateVectorStoreRequestWithExtraBody(
1248+
name="test_store",
1249+
metadata={},
1250+
**{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
1251+
)
1252+
await vector_io_adapter.openai_create_vector_store(params)
1253+
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
1254+
assert call_args.query_expansion_model == "test/llama-model"
1255+
1256+
# Test 2: Global default fallback
1257+
vector_io_adapter.register_vector_store.reset_mock()
1258+
params_no_model = OpenAICreateVectorStoreRequestWithExtraBody(
1259+
name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
1260+
)
1261+
await vector_io_adapter.openai_create_vector_store(params_no_model)
1262+
call_args2 = vector_io_adapter.register_vector_store.call_args[0][0]
1263+
assert call_args2.query_expansion_model is None
1264+
1265+
# Test query rewriting scenarios
1266+
mock_inference_api = MagicMock()
1267+
1268+
# Per-store model scenario
1269+
mock_vector_store = MagicMock()
1270+
mock_vector_store.query_expansion_model = "test/llama-model"
1271+
mock_inference_api.routing_table.list_models = AsyncMock(
1272+
return_value=MagicMock(
1273+
data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
1274+
)
1275+
)
1276+
mock_inference_api.openai_chat_completion = AsyncMock(
1277+
return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
1278+
)
1279+
1280+
vector_store_with_index = VectorStoreWithIndex(
1281+
vector_store=mock_vector_store,
1282+
index=MagicMock(),
1283+
inference_api=mock_inference_api,
1284+
vector_stores_config=VectorStoresConfig(
1285+
default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
1286+
),
1287+
)
1288+
1289+
result = await vector_store_with_index._rewrite_query_for_search("test")
1290+
assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "test/llama-model"
1291+
assert result == "per-store expanded"
1292+
1293+
# Global default fallback scenario
1294+
mock_inference_api.reset_mock()
1295+
mock_vector_store.query_expansion_model = None
1296+
mock_inference_api.routing_table.list_models = AsyncMock(
1297+
return_value=MagicMock(
1298+
data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
1299+
)
1300+
)
1301+
mock_inference_api.openai_chat_completion = AsyncMock(
1302+
return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
1303+
)
1304+
1305+
result = await vector_store_with_index._rewrite_query_for_search("test")
1306+
assert mock_inference_api.openai_chat_completion.call_args[0][0].model == "global/default"
1307+
assert result == "global expanded"
1308+
1309+
# Test 3: Error cases
1310+
# Model not found
1311+
mock_vector_store.query_expansion_model = "missing/model"
1312+
mock_inference_api.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))
1313+
1314+
with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
1315+
await vector_store_with_index._rewrite_query_for_search("test")
1316+
1317+
# Non-LLM model
1318+
mock_vector_store.query_expansion_model = "test/embedding-model"
1319+
mock_inference_api.routing_table.list_models = AsyncMock(
1320+
return_value=MagicMock(
1321+
data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
1322+
)
1323+
)
1324+
1325+
with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
1326+
await vector_store_with_index._rewrite_query_for_search("test")

0 commit comments

Comments
 (0)