Skip to content

Commit 08a1c4e

Browse files
refactor to configure the model only at build time
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 376f16e commit 08a1c4e

File tree

22 files changed

+255
-301
lines changed

22 files changed

+255
-301
lines changed

src/llama_stack/core/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
# Default prompt template for query expansion in vector search
8+
DEFAULT_QUERY_EXPANSION_PROMPT = "Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:"

src/llama_stack/core/datatypes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic import BaseModel, Field, field_validator, model_validator
1313

1414
from llama_stack.core.access_control.datatypes import AccessRule
15+
from llama_stack.core.constants import DEFAULT_QUERY_EXPANSION_PROMPT
1516
from llama_stack.core.storage.datatypes import (
1617
KVStoreReference,
1718
StorageBackendType,
@@ -381,9 +382,17 @@ class VectorStoresConfig(BaseModel):
381382
description="Default LLM model for query expansion/rewriting in vector search.",
382383
)
383384
query_expansion_prompt: str = Field(
384-
default="Expand this query with relevant synonyms and related terms. Return only the improved query, no explanations:\n\n{query}\n\nImproved query:",
385+
default=DEFAULT_QUERY_EXPANSION_PROMPT,
385386
description="Prompt template for query expansion. Use {query} as placeholder for the original query.",
386387
)
388+
query_expansion_max_tokens: int = Field(
389+
default=100,
390+
description="Maximum number of tokens for query expansion responses.",
391+
)
392+
query_expansion_temperature: float = Field(
393+
default=0.3,
394+
description="Temperature for query expansion model (0.0 = deterministic, 1.0 = creative).",
395+
)
387396

388397

389398
class SafetyConfig(BaseModel):

src/llama_stack/core/resolver.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -374,13 +374,6 @@ async def instantiate_provider(
374374
method = "get_adapter_impl"
375375
args = [config, deps]
376376

377-
# Add vector_stores_config for vector_io providers
378-
if (
379-
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
380-
and provider_spec.api == Api.vector_io
381-
):
382-
args.append(run_config.vector_stores)
383-
384377
elif isinstance(provider_spec, AutoRoutedProviderSpec):
385378
method = "get_auto_router_impl"
386379

@@ -401,11 +394,6 @@ async def instantiate_provider(
401394
args.append(policy)
402395
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
403396
args.append(run_config.telemetry.enabled)
404-
if (
405-
"vector_stores_config" in inspect.signature(getattr(module, method)).parameters
406-
and provider_spec.api == Api.vector_io
407-
):
408-
args.append(run_config.vector_stores)
409397

410398
fn = getattr(module, method)
411399
impl = await fn(*args)

src/llama_stack/core/stack.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from llama_stack.core.store.registry import create_dist_registry
3535
from llama_stack.core.utils.dynamic import instantiate_class_type
3636
from llama_stack.log import get_logger
37+
from llama_stack.providers.utils.memory.vector_store import set_default_query_expansion_config
3738
from llama_stack_api import (
3839
Agents,
3940
Api,
@@ -144,35 +145,62 @@ async def validate_vector_stores_config(vector_stores_config: VectorStoresConfig
144145
if vector_stores_config is None:
145146
return
146147

148+
# Validate default embedding model
147149
default_embedding_model = vector_stores_config.default_embedding_model
148-
if default_embedding_model is None:
149-
return
150+
if default_embedding_model is not None:
151+
provider_id = default_embedding_model.provider_id
152+
model_id = default_embedding_model.model_id
153+
default_model_id = f"{provider_id}/{model_id}"
150154

151-
provider_id = default_embedding_model.provider_id
152-
model_id = default_embedding_model.model_id
153-
default_model_id = f"{provider_id}/{model_id}"
155+
if Api.models not in impls:
156+
raise ValueError(
157+
f"Models API is not available but vector_stores config requires model '{default_model_id}'"
158+
)
154159

155-
if Api.models not in impls:
156-
raise ValueError(f"Models API is not available but vector_stores config requires model '{default_model_id}'")
160+
models_impl = impls[Api.models]
161+
response = await models_impl.list_models()
162+
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
157163

158-
models_impl = impls[Api.models]
159-
response = await models_impl.list_models()
160-
models_list = {m.identifier: m for m in response.data if m.model_type == "embedding"}
164+
default_model = models_list.get(default_model_id)
165+
if default_model is None:
166+
raise ValueError(
167+
f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}"
168+
)
161169

162-
default_model = models_list.get(default_model_id)
163-
if default_model is None:
164-
raise ValueError(f"Embedding model '{default_model_id}' not found. Available embedding models: {models_list}")
170+
embedding_dimension = default_model.metadata.get("embedding_dimension")
171+
if embedding_dimension is None:
172+
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
165173

166-
embedding_dimension = default_model.metadata.get("embedding_dimension")
167-
if embedding_dimension is None:
168-
raise ValueError(f"Embedding model '{default_model_id}' is missing 'embedding_dimension' in metadata")
174+
try:
175+
int(embedding_dimension)
176+
except ValueError as err:
177+
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
169178

170-
try:
171-
int(embedding_dimension)
172-
except ValueError as err:
173-
raise ValueError(f"Embedding dimension '{embedding_dimension}' cannot be converted to an integer") from err
179+
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
174180

175-
logger.debug(f"Validated default embedding model: {default_model_id} (dimension: {embedding_dimension})")
181+
# Validate default query expansion model
182+
default_query_expansion_model = vector_stores_config.default_query_expansion_model
183+
if default_query_expansion_model is not None:
184+
provider_id = default_query_expansion_model.provider_id
185+
model_id = default_query_expansion_model.model_id
186+
query_model_id = f"{provider_id}/{model_id}"
187+
188+
if Api.models not in impls:
189+
raise ValueError(
190+
f"Models API is not available but vector_stores config requires query expansion model '{query_model_id}'"
191+
)
192+
193+
models_impl = impls[Api.models]
194+
response = await models_impl.list_models()
195+
llm_models_list = {m.identifier: m for m in response.data if m.model_type == "llm"}
196+
197+
query_expansion_model = llm_models_list.get(query_model_id)
198+
if query_expansion_model is None:
199+
raise ValueError(
200+
f"Query expansion model '{query_model_id}' not found. Available LLM models: {list(llm_models_list.keys())}"
201+
)
202+
203+
logger.debug(f"Validated default query expansion model: {query_model_id}")
176204

177205

178206
async def validate_safety_config(safety_config: SafetyConfig | None, impls: dict[Api, Any]):
@@ -437,6 +465,10 @@ async def initialize(self):
437465
await refresh_registry_once(impls)
438466
await validate_vector_stores_config(self.run_config.vector_stores, impls)
439467
await validate_safety_config(self.run_config.safety, impls)
468+
469+
# Set global query expansion configuration
470+
set_default_query_expansion_config(self.run_config.vector_stores)
471+
440472
self.impls = impls
441473

442474
def create_registry_refresh_task(self):

src/llama_stack/providers/inline/vector_io/faiss/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,16 @@
66

77
from typing import Any
88

9-
from llama_stack.core.datatypes import VectorStoresConfig
109
from llama_stack_api import Api
1110

1211
from .config import FaissVectorIOConfig
1312

1413

15-
async def get_provider_impl(
16-
config: FaissVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
17-
):
14+
async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
1815
from .faiss import FaissVectorIOAdapter
1916

2017
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
2118

22-
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
19+
impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
2320
await impl.initialize()
2421
return impl

src/llama_stack/providers/inline/vector_io/faiss/faiss.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515
from numpy.typing import NDArray
1616

17-
from llama_stack.core.datatypes import VectorStoresConfig
1817
from llama_stack.core.storage.kvstore import kvstore_impl
1918
from llama_stack.log import get_logger
2019
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -190,12 +189,10 @@ def __init__(
190189
config: FaissVectorIOConfig,
191190
inference_api: Inference,
192191
files_api: Files | None,
193-
vector_stores_config: VectorStoresConfig | None = None,
194192
) -> None:
195193
super().__init__(files_api=files_api, kvstore=None)
196194
self.config = config
197195
self.inference_api = inference_api
198-
self.vector_stores_config = vector_stores_config
199196
self.cache: dict[str, VectorStoreWithIndex] = {}
200197

201198
async def initialize(self) -> None:
@@ -211,7 +208,6 @@ async def initialize(self) -> None:
211208
vector_store,
212209
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
213210
self.inference_api,
214-
self.vector_stores_config,
215211
)
216212
self.cache[vector_store.identifier] = index
217213

@@ -250,7 +246,6 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
250246
vector_store=vector_store,
251247
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
252248
inference_api=self.inference_api,
253-
vector_stores_config=self.vector_stores_config,
254249
)
255250

256251
async def list_vector_stores(self) -> list[VectorStore]:
@@ -284,7 +279,6 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
284279
vector_store=vector_store,
285280
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
286281
inference_api=self.inference_api,
287-
vector_stores_config=self.vector_stores_config,
288282
)
289283
self.cache[vector_store_id] = index
290284
return index

src/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@
66

77
from typing import Any
88

9-
from llama_stack.core.datatypes import VectorStoresConfig
109
from llama_stack_api import Api
1110

1211
from .config import SQLiteVectorIOConfig
1312

1413

15-
async def get_provider_impl(
16-
config: SQLiteVectorIOConfig, deps: dict[Api, Any], vector_stores_config: VectorStoresConfig | None = None
17-
):
14+
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
1815
from .sqlite_vec import SQLiteVecVectorIOAdapter
1916

2017
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
21-
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
18+
impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
2219
await impl.initialize()
2320
return impl

src/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import sqlite_vec # type: ignore[import-untyped]
1515
from numpy.typing import NDArray
1616

17-
from llama_stack.core.datatypes import VectorStoresConfig
1817
from llama_stack.core.storage.kvstore import kvstore_impl
1918
from llama_stack.log import get_logger
2019
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@@ -391,12 +390,10 @@ def __init__(
391390
config,
392391
inference_api: Inference,
393392
files_api: Files | None,
394-
vector_stores_config: VectorStoresConfig | None = None,
395393
) -> None:
396394
super().__init__(files_api=files_api, kvstore=None)
397395
self.config = config
398396
self.inference_api = inference_api
399-
self.vector_stores_config = vector_stores_config
400397
self.cache: dict[str, VectorStoreWithIndex] = {}
401398
self.vector_store_table = None
402399

@@ -411,9 +408,7 @@ async def initialize(self) -> None:
411408
index = await SQLiteVecIndex.create(
412409
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
413410
)
414-
self.cache[vector_store.identifier] = VectorStoreWithIndex(
415-
vector_store, index, self.inference_api, self.vector_stores_config
416-
)
411+
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
417412

418413
# Load existing OpenAI vector stores into the in-memory cache
419414
await self.initialize_openai_vector_stores()
@@ -437,9 +432,7 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
437432
index = await SQLiteVecIndex.create(
438433
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
439434
)
440-
self.cache[vector_store.identifier] = VectorStoreWithIndex(
441-
vector_store, index, self.inference_api, self.vector_stores_config
442-
)
435+
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)
443436

444437
async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
445438
if vector_store_id in self.cache:
@@ -464,7 +457,6 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
464457
kvstore=self.kvstore,
465458
),
466459
inference_api=self.inference_api,
467-
vector_stores_config=self.vector_stores_config,
468460
)
469461
self.cache[vector_store_id] = index
470462
return index

src/llama_stack/providers/remote/vector_io/chroma/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,14 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from llama_stack.core.datatypes import VectorStoresConfig
87
from llama_stack_api import Api, ProviderSpec
98

109
from .config import ChromaVectorIOConfig
1110

1211

13-
async def get_adapter_impl(
14-
config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec], vector_stores_config: VectorStoresConfig | None = None
15-
):
12+
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
1613
from .chroma import ChromaVectorIOAdapter
1714

18-
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files), vector_stores_config)
15+
impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
1916
await impl.initialize()
2017
return impl

src/llama_stack/providers/remote/vector_io/chroma/chroma.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import chromadb
1212
from numpy.typing import NDArray
1313

14-
from llama_stack.core.datatypes import VectorStoresConfig
1514
from llama_stack.core.storage.kvstore import kvstore_impl
1615
from llama_stack.log import get_logger
1716
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -126,13 +125,11 @@ def __init__(
126125
config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
127126
inference_api: Inference,
128127
files_api: Files | None,
129-
vector_stores_config: VectorStoresConfig | None = None,
130128
) -> None:
131129
super().__init__(files_api=files_api, kvstore=None)
132130
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
133131
self.config = config
134132
self.inference_api = inference_api
135-
self.vector_stores_config = vector_stores_config
136133
self.client = None
137134
self.cache = {}
138135
self.vector_store_table = None
@@ -165,7 +162,7 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
165162
)
166163
)
167164
self.cache[vector_store.identifier] = VectorStoreWithIndex(
168-
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
165+
vector_store, ChromaIndex(self.client, collection), self.inference_api
169166
)
170167

171168
async def unregister_vector_store(self, vector_store_id: str) -> None:
@@ -210,9 +207,7 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
210207
collection = await maybe_await(self.client.get_collection(vector_store_id))
211208
if not collection:
212209
raise ValueError(f"Vector DB {vector_store_id} not found in Chroma")
213-
index = VectorStoreWithIndex(
214-
vector_store, ChromaIndex(self.client, collection), self.inference_api, self.vector_stores_config
215-
)
210+
index = VectorStoreWithIndex(vector_store, ChromaIndex(self.client, collection), self.inference_api)
216211
self.cache[vector_store_id] = index
217212
return index
218213

0 commit comments

Comments
 (0)