Skip to content

Commit 9717ad6

Browse files
committed
Add get_dependencies function
1 parent dd106e7 commit 9717ad6

File tree

2 files changed

+40
-28
lines changed

2 files changed

+40
-28
lines changed

debug_toolbar/dependencies.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import typing as t
2+
from contextlib import AsyncExitStack
3+
4+
from fastapi import HTTPException, Request
5+
from fastapi.dependencies.utils import solve_dependencies
6+
7+
8+
async def get_dependencies(request: Request) -> dict[str, t.Any] | None:
9+
route = request["route"]
10+
11+
if hasattr(route, "dependant"):
12+
try:
13+
solved_result = await solve_dependencies(
14+
request=request,
15+
dependant=route.dependant,
16+
dependency_overrides_provider=route.dependency_overrides_provider,
17+
async_exit_stack=AsyncExitStack(),
18+
)
19+
except HTTPException:
20+
pass
21+
else:
22+
return solved_result[0]
23+
return None

debug_toolbar/panels/sqlalchemy.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from __future__ import annotations
22

33
import typing as t
4-
from contextlib import AsyncExitStack
54
from time import perf_counter
65

7-
from fastapi import HTTPException, Request, Response
8-
from fastapi.dependencies.utils import solve_dependencies
6+
from fastapi import Request, Response
97
from sqlalchemy import event
108
from sqlalchemy.engine import Connection, Engine, ExecutionContext
119
from sqlalchemy.exc import UnboundExecutionError
1210
from sqlalchemy.ext.asyncio import AsyncSession
1311
from sqlalchemy.orm import Session
1412

13+
from debug_toolbar.dependencies import get_dependencies
1514
from debug_toolbar.panels.sql import SQLPanel
1615

1716

@@ -50,32 +49,22 @@ def add_bind(self, bind: Connection | Engine):
5049
else:
5150
self.engines.add(bind)
5251

53-
async def add_engines(self, request: Request): # noqa: C901
54-
route = request["route"]
55-
56-
if hasattr(route, "dependant"):
57-
try:
58-
solved_result = await solve_dependencies(
59-
request=request,
60-
dependant=route.dependant,
61-
dependency_overrides_provider=route.dependency_overrides_provider,
62-
async_exit_stack=AsyncExitStack(),
63-
)
64-
except HTTPException:
65-
pass
66-
else:
67-
for value in solved_result[0].values():
68-
if isinstance(value, AsyncSession):
69-
value = value.sync_session
70-
71-
if isinstance(value, Session):
72-
try:
73-
bind = value.get_bind()
74-
except UnboundExecutionError:
75-
for bind in value._Session__binds.values(): # type: ignore[attr-defined]
76-
self.add_bind(bind)
77-
else:
52+
async def add_engines(self, request: Request):
53+
dependencies = await get_dependencies(request)
54+
55+
if dependencies is not None:
56+
for value in dependencies.values():
57+
if isinstance(value, AsyncSession):
58+
value = value.sync_session
59+
60+
if isinstance(value, Session):
61+
try:
62+
bind = value.get_bind()
63+
except UnboundExecutionError:
64+
for bind in value._Session__binds.values(): # type: ignore[attr-defined]
7865
self.add_bind(bind)
66+
else:
67+
self.add_bind(bind)
7968

8069
async def process_request(self, request: Request) -> Response:
8170
await self.add_engines(request)

0 commit comments

Comments
 (0)