Skip to content

Commit f4b68dc

Browse files
authored
fix issue with _AsyncSessionProtocol (#157)
1 parent 2606071 commit f4b68dc

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

sqlalchemy-stubs/ext/asyncio/session.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ class _AsyncSessionProtocol(
134134
async def close_all(cls) -> None: ... # NOTE: Deprecated.
135135

136136
class _AsyncSessionTypingCommon(
137-
_SessionNoIoTypingCommon, _SessionClassMethodNoIoTypingCommon
137+
_SessionNoIoTypingCommon[Union[AsyncConnection, AsyncEngine]],
138+
_SessionClassMethodNoIoTypingCommon,
138139
):
139140
bind: Any = ...
140141
def begin(self, **kw: Any) -> AsyncSessionTransaction: ...

sqlalchemy-stubs/orm/session.pyi

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class SessionTransaction:
290290
def __enter__(self: _TSessionTransaction) -> _TSessionTransaction: ...
291291
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: ...
292292

293-
class _SessionNoIoTypingCommon:
293+
class _SessionNoIoTypingCommon(Generic[_T]):
294294
@property
295295
def dirty(self) -> util.IdentitySet[Any]: ...
296296
@property
@@ -321,15 +321,17 @@ class _SessionNoIoTypingCommon:
321321
self,
322322
mapper: Optional[Any] = ...,
323323
clause: Optional[ClauseElement] = ...,
324-
bind: Optional[Union[Connection, Engine]] = ...,
324+
bind: Optional[_T] = ...,
325325
_sa_skip_events: Optional[Any] = ...,
326326
_sa_skip_for_implicit_returning: bool = ...,
327-
) -> Union[Connection, Engine]: ...
327+
) -> _T: ...
328328
def is_modified(
329329
self, instance: Any, include_collections: bool = ...
330330
) -> bool: ...
331331

332-
class _SessionTypingCommon(_SessionNoIoTypingCommon):
332+
class _SessionTypingCommon(
333+
_SessionNoIoTypingCommon[Union[Connection, Engine]]
334+
):
333335
bind: Optional[Union[Connection, Engine]]
334336
autocommit: bool
335337
def begin(

test/files/async_stuff.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from sqlalchemy.ext.asyncio import AsyncSession
2+
from sqlalchemy.ext.asyncio import create_async_engine
3+
from sqlalchemy.orm import sessionmaker
4+
from sqlalchemy.orm import Session
5+
6+
engine = create_async_engine(...)
7+
async_session = sessionmaker(engine, class_=AsyncSession)

0 commit comments

Comments
 (0)