Skip to content

Commit fcf649b

Browse files
authored
feat(storage): share sql/kv instances and add upsert support (#4140)
A few changes to the storage layer to reduce unnecessary contention arising from our design choices (and to let the database layer do its job):

- SQL stores now share a single `SqlAlchemySqlStoreImpl` per backend, and `kvstore_impl` caches instances per `(backend, namespace)`. This avoids opening multiple SQLite connections to the same file, which reduces lock contention and aligns the caching story across all backends.
- Added an async upsert API (using SQLite/Postgres dialect inserts), routed it through `AuthorizedSqlStore`, and switched conversations and responses to call it. Using native `ON CONFLICT DO UPDATE` eliminates the insert-then-update retry window that previously caused long WAL lock retries.

### Test Plan

Existing tests, plus a new unit test for `upsert()`.
1 parent 492f79c commit fcf649b

File tree

8 files changed

+172
-51
lines changed

8 files changed

+172
-51
lines changed

src/llama_stack/core/conversations/conversations.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,16 +203,11 @@ async def add_items(self, conversation_id: str, items: list[ConversationItem]) -
203203
"item_data": item_dict,
204204
}
205205

206-
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
207-
try:
208-
await self.sql_store.insert(table="conversation_items", data=item_record)
209-
except Exception:
210-
# If insert fails due to ID conflict, update existing record
211-
await self.sql_store.update(
212-
table="conversation_items",
213-
data={"created_at": created_at, "item_data": item_dict},
214-
where={"id": item_id},
215-
)
206+
await self.sql_store.upsert(
207+
table="conversation_items",
208+
data=item_record,
209+
conflict_columns=["id"],
210+
)
216211

217212
created_items.append(item_dict)
218213

src/llama_stack/providers/utils/kvstore/kvstore.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
from __future__ import annotations
1313

14+
import asyncio
15+
from collections import defaultdict
16+
1417
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
1518

1619
from .api import KVStore
@@ -53,45 +56,63 @@ async def delete(self, key: str) -> None:
5356

5457

5558
_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
59+
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
60+
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)
5661

5762

5863
def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Replace the registered KV store backends with ``backends``.

    Cached store instances and their per-key creation locks are cleared as
    well, so nothing built from the previous registration is kept around.
    """
    # The registries are only mutated in place (never rebound), so no
    # ``global`` declaration is needed.
    for registry in (_KVSTORE_BACKENDS, _KVSTORE_INSTANCES, _KVSTORE_LOCKS):
        registry.clear()
    _KVSTORE_BACKENDS.update(backends)
6574

6675

6776
async def kvstore_impl(reference: KVStoreReference) -> KVStore:
    """Return the shared, initialized KV store for ``reference``.

    Stores are cached per ``(backend, namespace)``: the first caller builds
    and initializes the store, later callers get the same instance back.

    Raises:
        ValueError: If the referenced backend is not registered, or its
            configured type is not one of the known KV store types.
    """
    backend_name = reference.backend
    cache_key = (backend_name, reference.namespace)

    # Fast path without taking the lock. Compare against None explicitly so a
    # cache hit never depends on the truthiness of the store object.
    existing = _KVSTORE_INSTANCES.get(cache_key)
    if existing is not None:
        return existing

    backend_config = _KVSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

    async with _KVSTORE_LOCKS[cache_key]:
        # Re-check under the lock: another task may have finished creating
        # the store while this one was waiting.
        existing = _KVSTORE_INSTANCES.get(cache_key)
        if existing is not None:
            return existing

        # Work on a copy so the registered backend config keeps its own
        # namespace value.
        config = backend_config.model_copy()
        config.namespace = reference.namespace

        # Each implementation module is imported only when its type is selected.
        if config.type == StorageBackendType.KV_REDIS.value:
            from .redis import RedisKVStoreImpl

            impl = RedisKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_SQLITE.value:
            from .sqlite import SqliteKVStoreImpl

            impl = SqliteKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_POSTGRES.value:
            from .postgres import PostgresKVStoreImpl

            impl = PostgresKVStoreImpl(config)
        elif config.type == StorageBackendType.KV_MONGODB.value:
            from .mongodb import MongoDBKVStoreImpl

            impl = MongoDBKVStoreImpl(config)
        else:
            raise ValueError(f"Unknown kvstore type {config.type}")

        # Only publish the instance once initialization has succeeded.
        await impl.initialize()
        _KVSTORE_INSTANCES[cache_key] = impl
        return impl

src/llama_stack/providers/utils/responses/responses_store.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -252,19 +252,12 @@ async def store_conversation_messages(self, conversation_id: str, messages: list
252252
# Serialize messages to dict format for JSON storage
253253
messages_data = [msg.model_dump() for msg in messages]
254254

255-
# Upsert: try insert first, update if exists
256-
try:
257-
await self.sql_store.insert(
258-
table="conversation_messages",
259-
data={"conversation_id": conversation_id, "messages": messages_data},
260-
)
261-
except Exception:
262-
# If insert fails due to ID conflict, update existing record
263-
await self.sql_store.update(
264-
table="conversation_messages",
265-
data={"messages": messages_data},
266-
where={"conversation_id": conversation_id},
267-
)
255+
await self.sql_store.upsert(
256+
table="conversation_messages",
257+
data={"conversation_id": conversation_id, "messages": messages_data},
258+
conflict_columns=["conversation_id"],
259+
update_columns=["messages"],
260+
)
268261

269262
logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")
270263

src/llama_stack/providers/utils/sqlstore/api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[st
4747
"""
4848
pass
4949

50+
async def upsert(
    self,
    table: str,
    data: Mapping[str, Any],
    conflict_columns: list[str],
    update_columns: list[str] | None = None,
) -> None:
    """
    Insert a single row, or update the existing row when it conflicts.

    :param table: Name of the table to write to.
    :param data: Mapping of column name to value for the row.
    :param conflict_columns: Columns whose conflict turns the insert into an update.
    :param update_columns: Columns to overwrite on conflict. What happens when
        this is omitted is decided by the implementing store.
    """
    pass
61+
5062
async def fetch_all(
5163
self,
5264
table: str,

src/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,23 @@ async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[st
129129
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
130130
await self.sql_store.insert(table, enhanced_data)
131131

132+
async def upsert(
    self,
    table: str,
    data: Mapping[str, Any],
    conflict_columns: list[str],
    update_columns: list[str] | None = None,
) -> None:
    """Upsert one row, stamping it with the caller's access control attributes.

    The row is enhanced using the currently authenticated user and then
    handed to the wrapped store's ``upsert``; every other argument is passed
    through unchanged.
    """
    # NOTE(review): no access check against an already-existing row is visible
    # here before delegating -- confirm that a conflicting row owned by a
    # different user cannot be overwritten through this path.
    row = _enhance_item_with_access_control(data, get_authenticated_user())
    await self.sql_store.upsert(
        table=table,
        data=row,
        conflict_columns=conflict_columns,
        update_columns=update_columns,
    )
148+
132149
async def fetch_all(
133150
self,
134151
table: str,

src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
7272
class SqlAlchemySqlStoreImpl(SqlStore):
7373
def __init__(self, config: SqlAlchemySqlStoreConfig):
    """Set up the session factory and table metadata for ``config``.

    Whether the backend is SQLite is worked out once from the engine URL and
    stored on the instance so later code can branch on it.
    """
    self.config = config
    # Match the URL scheme (e.g. "sqlite+aiosqlite://...") instead of searching
    # the whole string, so a non-SQLite URL whose database name or credentials
    # happen to contain "sqlite" is not misclassified.
    self._is_sqlite_backend = self.config.engine_str.startswith("sqlite")
    # create_engine() reads the flag above, so it must be set before this call.
    self.async_session = async_sessionmaker(self.create_engine())
    self.metadata = MetaData()
7778

7879
def create_engine(self) -> AsyncEngine:
7980
# Configure connection args for better concurrency support
8081
connect_args = {}
81-
if "sqlite" in self.config.engine_str:
82+
if self._is_sqlite_backend:
8283
# SQLite-specific optimizations for concurrent access
8384
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
8485
connect_args["timeout"] = 5.0
@@ -91,7 +92,7 @@ def create_engine(self) -> AsyncEngine:
9192
)
9293

9394
# Enable WAL mode for SQLite to support concurrent readers and writers
94-
if "sqlite" in self.config.engine_str:
95+
if self._is_sqlite_backend:
9596

9697
@event.listens_for(engine.sync_engine, "connect")
9798
def set_sqlite_pragma(dbapi_conn, connection_record):
@@ -151,6 +152,29 @@ async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[st
151152
await session.execute(self.metadata.tables[table].insert(), data)
152153
await session.commit()
153154

155+
async def upsert(
    self,
    table: str,
    data: Mapping[str, Any],
    conflict_columns: list[str],
    update_columns: list[str] | None = None,
) -> None:
    """Insert ``data`` into ``table``, updating the existing row on conflict.

    Args:
        table: Name of a table present in this store's metadata.
        data: Mapping of column name to value for the row.
        conflict_columns: Columns that form the conflict target.
        update_columns: Columns to overwrite when a conflict occurs. When
            ``None``, every column in ``data`` that is not a conflict column
            is overwritten. When there is nothing to overwrite, a conflicting
            row is left untouched.
    """
    table_obj = self.metadata.tables[table]
    insert_stmt = self._get_dialect_insert(table_obj).values(**data)

    if update_columns is None:
        update_columns = [col for col in data if col not in conflict_columns]

    conflict_cols = [table_obj.c[col] for col in conflict_columns]

    if update_columns:
        # Indexed access is used on ``excluded`` because attribute access would
        # resolve to a collection method for column names such as "values",
        # "keys" or "items".
        update_mapping = {col: insert_stmt.excluded[col] for col in update_columns}
        stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
    else:
        # ON CONFLICT DO UPDATE needs at least one assignment; with nothing to
        # update, keep the existing row as it is instead of failing.
        stmt = insert_stmt.on_conflict_do_nothing(index_elements=conflict_cols)

    async with self.async_session() as session:
        await session.execute(stmt)
        await session.commit()
177+
154178
async def fetch_all(
155179
self,
156180
table: str,
@@ -333,9 +357,18 @@ def check_column_exists(sync_conn):
333357
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
334358

335359
await conn.execute(add_column_sql)
336-
337360
except Exception as e:
338361
# If any error occurs during migration, log it but don't fail
339362
# The table creation will handle adding the column
340363
logger.error(f"Error adding column {column_name} to table {table}: {e}")
341364
pass
365+
366+
def _get_dialect_insert(self, table: Table):
    """Return a dialect-specific ``INSERT`` construct for ``table``.

    SQLite backends get the SQLite insert; every other backend gets the
    PostgreSQL insert.
    """
    # The dialect module is imported only on the branch that uses it.
    if not self._is_sqlite_backend:
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        return pg_insert(table)

    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    return sqlite_insert(table)

src/llama_stack/providers/utils/sqlstore/sqlstore.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7+
from threading import Lock
78
from typing import Annotated, cast
89

910
from pydantic import Field
@@ -21,6 +22,8 @@
2122
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
2223

2324
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
25+
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
26+
_SQLSTORE_LOCKS: dict[str, Lock] = {}
2427

2528

2629
SqlStoreConfig = Annotated[
@@ -52,19 +55,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
5255
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
5356
)
5457

55-
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
56-
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
58+
existing = _SQLSTORE_INSTANCES.get(backend_name)
59+
if existing:
60+
return existing
5761

58-
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
59-
return SqlAlchemySqlStoreImpl(config)
60-
else:
61-
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
62+
lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
63+
with lock:
64+
existing = _SQLSTORE_INSTANCES.get(backend_name)
65+
if existing:
66+
return existing
67+
68+
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
69+
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
70+
71+
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
72+
instance = SqlAlchemySqlStoreImpl(config)
73+
_SQLSTORE_INSTANCES[backend_name] = instance
74+
return instance
75+
else:
76+
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
6277

6378

6479
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Replace the registered SQL store backends with ``backends``.

    Cached store instances and their creation locks are cleared as well, so
    nothing built from the previous registration is kept around.
    """
    # The registries are only mutated in place (never rebound), so no
    # ``global`` declaration is needed.
    for registry in (_SQLSTORE_BACKENDS, _SQLSTORE_INSTANCES, _SQLSTORE_LOCKS):
        registry.clear()
    _SQLSTORE_BACKENDS.update(backends)

tests/unit/utils/sqlstore/test_sqlstore.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111

12-
from llama_stack.providers.utils.sqlstore.api import ColumnType
12+
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
1313
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
1414
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
1515

@@ -65,6 +65,38 @@ async def test_sqlite_sqlstore():
6565
assert result.has_more is False
6666

6767

68+
async def test_sqlstore_upsert_support():
    """Upsert inserts new rows and overwrites existing ones on a key conflict."""
    with TemporaryDirectory() as tmp_dir:
        db_path = tmp_dir + "/upsert.db"
        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))

        await store.create_table(
            "items",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "value": ColumnType.STRING,
                "updated_at": ColumnType.INTEGER,
            },
        )

        # No existing row: upsert behaves like a plain insert.
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "first", "updated_at": 1},
            conflict_columns=["id"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "first", "updated_at": 1}

        # Conflict with an explicit list of columns to overwrite.
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "second", "updated_at": 2},
            conflict_columns=["id"],
            update_columns=["value", "updated_at"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "second", "updated_at": 2}

        # Conflict with update_columns omitted: every non-conflict column in
        # the data is overwritten.
        await store.upsert(
            table="items",
            data={"id": "item_1", "value": "third", "updated_at": 3},
            conflict_columns=["id"],
        )
        row = await store.fetch_one("items", {"id": "item_1"})
        assert row == {"id": "item_1", "value": "third", "updated_at": 3}
98+
99+
68100
async def test_sqlstore_pagination_basic():
69101
"""Test basic pagination functionality at the SQL store level."""
70102
with TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)