From 1e771f3968ededdd4300c438b111924bfe367810 Mon Sep 17 00:00:00 2001 From: Robert Scott Date: Tue, 21 Oct 2025 21:18:00 +0100 Subject: [PATCH] query recording: also record bind_key of query --- docs/record-queries.rst | 2 ++ src/flask_sqlalchemy/extension.py | 4 ++-- src/flask_sqlalchemy/record_queries.py | 18 +++++++++++++++--- tests/test_record_queries.py | 1 + 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/docs/record-queries.rst b/docs/record-queries.rst index afc43481..53e599c1 100644 --- a/docs/record-queries.rst +++ b/docs/record-queries.rst @@ -25,3 +25,5 @@ object has the following attributes: ``location`` A string description of where in your application code the query was executed. This may be unknown in certain cases. +``bind_key`` + The bind key of the engine which issued the query. diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index ccae54b4..ed5d8a5a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -378,8 +378,8 @@ def init_app(self, app: Flask) -> None: if app.config.setdefault("SQLALCHEMY_RECORD_QUERIES", False): from . import record_queries - for engine in engines.values(): - record_queries._listen(engine) + for bind_key, engine in engines.items(): + record_queries._listen(bind_key, engine) if app.config.setdefault("SQLALCHEMY_TRACK_MODIFICATIONS", False): from . import track_modifications diff --git a/src/flask_sqlalchemy/record_queries.py b/src/flask_sqlalchemy/record_queries.py index e8273be9..ef11f97e 100644 --- a/src/flask_sqlalchemy/record_queries.py +++ b/src/flask_sqlalchemy/record_queries.py @@ -3,6 +3,7 @@ import dataclasses import inspect import typing as t +from functools import partial from time import perf_counter import sqlalchemy as sa @@ -65,15 +66,21 @@ class _QueryInfo: start_time: float end_time: float location: str + bind_key: str | None @property def duration(self) -> float: return self.end_time - self.start_time -def _listen(engine: sa.engine.Engine) -> None: +def _listen(bind_key: str | None, engine: sa.engine.Engine) -> None: sa_event.listen(engine, "before_cursor_execute", _record_start, named=True) - sa_event.listen(engine, "after_cursor_execute", _record_end, named=True) + sa_event.listen( + engine, + "after_cursor_execute", + partial(_record_end, bind_key), + named=True, + ) def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: @@ -83,7 +90,11 @@ def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: context._fsa_start_time = perf_counter() # type: ignore[attr-defined] -def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: +def _record_end( + bind_key: str | None, + context: sa.engine.ExecutionContext, + **kwargs: t.Any, +) -> None: if not has_app_context(): return @@ -113,5 +124,6 @@ def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None: start_time=context._fsa_start_time, # type: ignore[attr-defined] end_time=perf_counter(), location=location, + bind_key=bind_key, ) ) diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py index c5cc73a2..22df73d4 100644 --- a/tests/test_record_queries.py +++ b/tests/test_record_queries.py @@ -49,3 +49,4 @@ class Todo(db.Model): # type: ignore[no-redef] assert info.duration == info.end_time - info.start_time assert os.path.join("tests", "test_record_queries.py:") in info.location assert "(test_query_info)" in info.location + assert info.bind_key is None