Skip to content

Commit 685eb05

Browse files
Add vector_stores_config to the vector_io providers so that it can be used properly
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 2b9066c commit 685eb05

File tree

18 files changed

+166
-30
lines changed

18 files changed

+166
-30
lines changed

src/llama_stack/core/resolver.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,13 @@ async def instantiate_provider(
374374
method = "get_adapter_impl"
375375
args = [config, deps]
376376

377+
# Add vector_stores_config for vector_io providers
378+
if (
379+
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
380+
and provider_spec.api == Api.vector_io
381+
):
382+
args.append(run_config.vector_stores)
383+
377384
elif isinstance(provider_spec, AutoRoutedProviderSpec):
378385
method = "get_auto_router_impl"
379386

@@ -394,6 +401,11 @@ async def instantiate_provider(
394401
args.append(policy)
395402
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
396403
args.append(run_config.telemetry.enabled)
404+
if (
405+
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
406+
and provider_spec.api == Api.vector_io
407+
):
408+
args.append(run_config.vector_stores)
397409

398410
fn = getattr(module, method)
399411
impl = await fn(*args)

src/llama_stack/core/routers/vector_io.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ async def query_chunks(
103103
# Ensure params dict exists and add vector_stores_config for query rewriting
104104
if params is None:
105105
params = {}
106+
107+
logger.debug(f"Router vector_stores_config: {self.vector_stores_config}")
108+
if self.vector_stores_config and hasattr(self.vector_stores_config, "default_query_expansion_model"):
109+
logger.debug(
110+
f"Router default_query_expansion_model: {self.vector_stores_config.default_query_expansion_model}"
111+
)
112+
106113
params["vector_stores_config"] = self.vector_stores_config
107114

108115
return await provider.query_chunks(vector_store_id, query, params)

src/llama_stack/providers/inline/vector_io/faiss/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@
66

77
from typing import Any
88

9+
from llama_stack.core.datatypes import VectorStoresConfig
910
from llama_stack_api import Api
1011

1112
from .config import FaissVectorIOConfig
1213

1314

14-
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
15+
async def get_provider_impl(
16+
config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
17+
):
1518
from .faiss import FaissVectorIOAdapter
1619

1720
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
1821

19-
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
22+
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
2023
await impl.initialize()
2124
return impl

src/llama_stack/providers/inline/vector_io/faiss/faiss.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
from numpy.typing import NDArray
1616

17+
from llama_stack.core.datatypes import VectorStoresConfig
1718
from llama_stack.log import get_logger
1819
from llama_stack.providers.utils.kvstore import kvstore_impl
1920
from llama_stack.providers.utils.kvstore.api import KVStore
@@ -184,10 +185,17 @@ async def query_hybrid(
184185

185186

186187
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
187-
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
188+
def __init__(
189+
self,
190+
config: FaissVectorIOConfig,
191+
inference_api: Inference,
192+
files_api: Files | None,
193+
vector_stores_config: VectorStoresConfig | None = None,
194+
) -> None:
188195
super().__init__(files_api=files_api, kvstore=None)
189196
self.config = config
190197
self.inference_api = inference_api
198+
self.vector_stores_config = vector_stores_config
191199
self.cache: dict[str, VectorStoreWithIndex] = {}
192200

193201
async def initialize(self) -> None:
@@ -203,6 +211,7 @@ async def initialize(self) -> None:
203211
vector_store,
204212
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
205213
self.inference_api,
214+
self.vector_stores_config,
206215
)
207216
self.cache[vector_store.identifier] = index
208217

@@ -241,6 +250,7 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
241250
vector_store=vector_store,
242251
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
243252
inference_api=self.inference_api,
253+
vector_stores_config=self.vector_stores_config,
244254
)
245255

246256
async def list_vector_stores(self) -> list[VectorStore]:
@@ -274,6 +284,7 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
274284
vector_store=vector_store,
275285
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
276286
inference_api=self.inference_api,
287+
vector_stores_config=self.vector_stores_config,
277288
)
278289
self.cache[vector_store_id] = index
279290
return index

src/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66

77
from typing import Any
88

9+
from llama_stack.core.datatypes import VectorStoresConfig
910
from llama_stack_api import Api
1011

1112
from .config import SQLiteVectorIOConfig
1213

1314

14-
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
15+
async def get_provider_impl(
16+
config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
17+
):
1518
from .sqlite_vec import SQLiteVecVectorIOAdapter
1619

1720
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
18-
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
21+
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
1922
await impl.initialize()
2023
return impl

src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import sqlite_vec # type: ignore[import-untyped]
1515
from numpy.typing import NDArray
1616

17+
from llama_stack.core.datatypes import VectorStoresConfig
1718
from llama_stack.log import get_logger
1819
from llama_stack.providers.utils.kvstore import kvstore_impl
1920
from llama_stack.providers.utils.kvstore.api import KVStore
@@ -385,10 +386,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresPro
385386
and creates a cache of VectorStoreWithIndex instances (each wrapping a SQLiteVecIndex).
386387
"""
387388

388-
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
389+
def __init__(
390+
self,
391+
config,
392+
inference_api: Inference,
393+
files_api: Files | None,
394+
vector_stores_config: VectorStoresConfig | None = None,
395+
) -> None:
389396
super().__init__(files_api=files_api, kvstore=None)
390397
self.config = config
391398
self.inference_api = inference_api
399+
self.vector_stores_config = vector_stores_config
392400
self.cache: dict[str, VectorStoreWithIndex] = {}
393401
self.vector_store_table = None
394402

@@ -403,7 +411,9 @@ async def initialize(self) -> None:
403411
index = await SQLiteVecIndex.create(
404412
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
405413
)
406-
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
414+
self.cache[vector_store.identifier] = VectorStoreWithIndex(
415+
vector_store, index, self.inference_api, self.vector_stores_config
416+
)
407417

408418
# Load existing OpenAI vector stores into the in-memory cache
409419
await self.initialize_openai_vector_stores()
@@ -427,7 +437,9 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
427437
index = await SQLiteVecIndex.create(
428438
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
429439
)
430-
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
440+
self.cache[vector_store.identifier] = VectorStoreWithIndex(
441+
vector_store, index, self.inference_api, self.vector_stores_config
442+
)
431443

432444
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
433445
if vector_store_id in self.cache:
@@ -452,6 +464,7 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
452464
kvstore=self.kvstore,
453465
),
454466
inference_api=self.inference_api,
467+
vector_stores_config=self.vector_stores_config,
455468
)
456469
self.cache[vector_store_id] = index
457470
return index

src/llama_stack/providers/remote/vector_io/chroma/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
from llama_stack.core.datatypes import VectorStoresConfig
78
from llama_stack_api import Api, ProviderSpec
89

910
from .config import ChromaVectorIOConfig
1011

1112

12-
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
13+
async def get_adapter_impl(
14+
config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
15+
):
1316
from .chroma import ChromaVectorIOAdapter
1417

15-
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
18+
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
1619
await impl.initialize()
1720
return impl

src/llama_stack/providers/remote/vector_io/chroma/chroma.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import chromadb
1212
from numpy.typing import NDArray
1313

14+
from llama_stack.core.datatypes import VectorStoresConfig
1415
from llama_stack.log import get_logger
1516
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
1617
from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -125,11 +126,13 @@ def __init__(
125126
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
126127
inference_api: Inference,
127128
files_api: Files | None,
129+
vector_stores_config: VectorStoresConfig | None = None,
128130
) -> None:
129131
super().__init__(files_api=files_api, kvstore=None)
130132
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
131133
self.config = config
132134
self.inference_api = inference_api
135+
self.vector_stores_config = vector_stores_config
133136
self.client = None
134137
self.cache = {}
135138
self.vector_store_table = None
@@ -162,7 +165,7 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
162165
)
163166
)
164167
self.cache[vector_store.identifier] = VectorStoreWithIndex(
165-
vector_store, ChromaIndex(self.client, collection), self.inference_api
168+
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
166169
)
167170

168171
async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -207,7 +210,9 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
207210
collection = await maybe_await(self.client.get_collection(vector_store_id))
208211
if not collection:
209212
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
210-
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
213+
index = VectorStoreWithIndex(
214+
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
215+
)
211216
self.cache[vector_store_id] = index
212217
return index
213218

src/llama_stack/providers/remote/vector_io/milvus/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
from llama_stack.core.datatypes import VectorStoresConfig
78
from llama_stack_api import Api, ProviderSpec
89

910
from .config import MilvusVectorIOConfig
1011

1112

12-
async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec]):
13+
async def get_adapter_impl(
14+
config: MilvusVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
15+
):
1316
from .milvus import MilvusVectorIOAdapter
1417

1518
assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"
16-
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
19+
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
1720
await impl.initialize()
1821
return impl

src/llama_stack/providers/remote/vector_io/milvus/milvus.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from numpy.typing import NDArray
1212
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
1313

14+
from llama_stack.core.datatypes import VectorStoresConfig
1415
from llama_stack.log import get_logger
1516
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
1617
from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -272,12 +273,14 @@ def __init__(
272273
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
273274
inference_api: Inference,
274275
files_api: Files | None,
276+
vector_stores_config: VectorStoresConfig | None = None,
275277
) -> None:
276278
super().__init__(files_api=files_api, kvstore=None)
277279
self.config = config
278280
self.cache = {}
279281
self.client = None
280282
self.inference_api = inference_api
283+
self.vector_stores_config = vector_stores_config
281284
self.vector_store_table = None
282285
self.metadata_collection_name = "openai_vector_stores_metadata"
283286

@@ -298,6 +301,7 @@ async def initialize(self) -> None:
298301
kvstore=self.kvstore,
299302
),
300303
inference_api=self.inference_api,
304+
vector_stores_config=self.vector_stores_config,
301305
)
302306
self.cache[vector_store.identifier] = index
303307
if isinstance(self.config, RemoteMilvusVectorIOConfig):
@@ -325,6 +329,7 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
325329
vector_store=vector_store,
326330
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
327331
inference_api=self.inference_api,
332+
vector_stores_config=self.vector_stores_config,
328333
)
329334

330335
self.cache[vector_store.identifier] = index
@@ -347,6 +352,7 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
347352
vector_store=vector_store,
348353
index=MilvusIndex(client=self.client, collection_name=vector_store.identifier, kvstore=self.kvstore),
349354
inference_api=self.inference_api,
355+
vector_stores_config=self.vector_stores_config,
350356
)
351357
self.cache[vector_store_id] = index
352358
return index

0 commit comments

Comments
 (0)