|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import typing as t |
4 | | -from contextlib import AsyncExitStack |
5 | 4 | from time import perf_counter |
6 | 5 |
|
7 | | -from fastapi import HTTPException, Request, Response |
8 | | -from fastapi.dependencies.utils import solve_dependencies |
| 6 | +from fastapi import Request, Response |
9 | 7 | from sqlalchemy import event |
10 | 8 | from sqlalchemy.engine import Connection, Engine, ExecutionContext |
11 | 9 | from sqlalchemy.exc import UnboundExecutionError |
12 | 10 | from sqlalchemy.ext.asyncio import AsyncSession |
13 | 11 | from sqlalchemy.orm import Session |
14 | 12 |
|
| 13 | +from debug_toolbar.dependencies import get_dependencies |
15 | 14 | from debug_toolbar.panels.sql import SQLPanel |
16 | 15 |
|
17 | 16 |
|
@@ -50,32 +49,22 @@ def add_bind(self, bind: Connection | Engine): |
50 | 49 | else: |
51 | 50 | self.engines.add(bind) |
52 | 51 |
|
53 | | - async def add_engines(self, request: Request): # noqa: C901 |
54 | | - route = request["route"] |
55 | | - |
56 | | - if hasattr(route, "dependant"): |
57 | | - try: |
58 | | - solved_result = await solve_dependencies( |
59 | | - request=request, |
60 | | - dependant=route.dependant, |
61 | | - dependency_overrides_provider=route.dependency_overrides_provider, |
62 | | - async_exit_stack=AsyncExitStack(), |
63 | | - ) |
64 | | - except HTTPException: |
65 | | - pass |
66 | | - else: |
67 | | - for value in solved_result[0].values(): |
68 | | - if isinstance(value, AsyncSession): |
69 | | - value = value.sync_session |
70 | | - |
71 | | - if isinstance(value, Session): |
72 | | - try: |
73 | | - bind = value.get_bind() |
74 | | - except UnboundExecutionError: |
75 | | - for bind in value._Session__binds.values(): # type: ignore[attr-defined] |
76 | | - self.add_bind(bind) |
77 | | - else: |
| 52 | + async def add_engines(self, request: Request): |
| 53 | + dependencies = await get_dependencies(request) |
| 54 | + |
| 55 | + if dependencies is not None: |
| 56 | + for value in dependencies.values(): |
| 57 | + if isinstance(value, AsyncSession): |
| 58 | + value = value.sync_session |
| 59 | + |
| 60 | + if isinstance(value, Session): |
| 61 | + try: |
| 62 | + bind = value.get_bind() |
| 63 | + except UnboundExecutionError: |
| 64 | + for bind in value._Session__binds.values(): # type: ignore[attr-defined] |
78 | 65 | self.add_bind(bind) |
| 66 | + else: |
| 67 | + self.add_bind(bind) |
79 | 68 |
|
80 | 69 | async def process_request(self, request: Request) -> Response: |
81 | 70 | await self.add_engines(request) |
|
0 commit comments