@@ -1230,3 +1230,97 @@ async def test_embedding_config_required_model_missing(vector_io_adapter):
12301230
12311231 with pytest .raises (ValueError , match = "embedding_model is required" ):
12321232 await vector_io_adapter .openai_create_vector_store (params )
1233+
1234+
async def test_query_expansion_functionality(vector_io_adapter):
    """Test query expansion with per-store models, global defaults, and error validation."""
    from unittest.mock import MagicMock

    from llama_stack.core.datatypes import QualifiedModel, VectorStoresConfig
    from llama_stack.providers.utils.memory.vector_store import VectorStoreWithIndex
    from llama_stack_api.models import Model, ModelType

    vector_io_adapter.register_vector_store = AsyncMock()
    vector_io_adapter.__provider_id__ = "test_provider"

    # --- Store creation: an explicit per-store expansion model is recorded as-is ---
    create_params = OpenAICreateVectorStoreRequestWithExtraBody(
        name="test_store",
        metadata={},
        **{"embedding_model": "test/embedding", "query_expansion_model": "test/llama-model"},
    )
    await vector_io_adapter.openai_create_vector_store(create_params)
    registered = vector_io_adapter.register_vector_store.call_args[0][0]
    assert registered.query_expansion_model == "test/llama-model"

    # --- Store creation: omitting the model leaves it unset (global default applies later) ---
    vector_io_adapter.register_vector_store.reset_mock()
    create_params_default = OpenAICreateVectorStoreRequestWithExtraBody(
        name="test_store2", metadata={}, **{"embedding_model": "test/embedding"}
    )
    await vector_io_adapter.openai_create_vector_store(create_params_default)
    registered_default = vector_io_adapter.register_vector_store.call_args[0][0]
    assert registered_default.query_expansion_model is None

    # --- Query rewriting: wire up a fake inference API and vector store ---
    inference_stub = MagicMock()

    # Per-store model wins: the store names an available LLM, so rewriting uses it.
    store_stub = MagicMock()
    store_stub.query_expansion_model = "test/llama-model"
    inference_stub.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="test/llama-model", provider_id="test", model_type=ModelType.llm)]
        )
    )
    inference_stub.openai_chat_completion = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="per-store expanded"))])
    )

    store_with_index = VectorStoreWithIndex(
        vector_store=store_stub,
        index=MagicMock(),
        inference_api=inference_stub,
        vector_stores_config=VectorStoresConfig(
            default_query_expansion_model=QualifiedModel(provider_id="global", model_id="default")
        ),
    )

    rewritten = await store_with_index._rewrite_query_for_search("test")
    assert inference_stub.openai_chat_completion.call_args[0][0].model == "test/llama-model"
    assert rewritten == "per-store expanded"

    # Global default fallback: with no per-store model, the configured default is used.
    inference_stub.reset_mock()
    store_stub.query_expansion_model = None
    inference_stub.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="global/default", provider_id="global", model_type=ModelType.llm)]
        )
    )
    inference_stub.openai_chat_completion = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="global expanded"))])
    )

    rewritten = await store_with_index._rewrite_query_for_search("test")
    assert inference_stub.openai_chat_completion.call_args[0][0].model == "global/default"
    assert rewritten == "global expanded"

    # --- Error paths ---
    # The configured model is not present in the routing table at all.
    store_stub.query_expansion_model = "missing/model"
    inference_stub.routing_table.list_models = AsyncMock(return_value=MagicMock(data=[]))

    with pytest.raises(ValueError, match="Configured query expansion model .* is not available"):
        await store_with_index._rewrite_query_for_search("test")

    # The configured model exists but is an embedding model, not an LLM.
    store_stub.query_expansion_model = "test/embedding-model"
    inference_stub.routing_table.list_models = AsyncMock(
        return_value=MagicMock(
            data=[Model(identifier="test/embedding-model", provider_id="test", model_type=ModelType.embedding)]
        )
    )

    with pytest.raises(ValueError, match="is not an LLM model.*Query rewriting requires an LLM model"):
        await store_with_index._rewrite_query_for_search("test")
0 commit comments