Skip to content

Commit f346921

Browse files
bryanforbesilevkivskyi
authored andcommitted
Fix MetaData.bind typing (#43)
1 parent 1577a20 commit f346921

File tree

4 files changed

+32
-6
lines changed

4 files changed

+32
-6
lines changed

sqlalchemy-stubs/engine/__init__.pyi

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Any
12
from . import default as default
23

34
from .interfaces import (
@@ -7,7 +8,7 @@ from .interfaces import (
78
ExecutionContext as ExecutionContext,
89
ExceptionContext as ExceptionContext,
910
Compiled as Compiled,
10-
TypeCompiler as TypeCompiler
11+
TypeCompiler as TypeCompiler,
1112
)
1213

1314
from .base import (
@@ -29,5 +30,5 @@ from .result import (
2930
RowProxy as RowProxy,
3031
)
3132

32-
def create_engine(*args, **kwargs): ...
33-
def engine_from_config(configuration, prefix: str = ..., **kwargs): ...
33+
def create_engine(*args: Any, **kwargs: Any) -> Engine: ...
34+
def engine_from_config(configuration: Any, prefix: str = ..., **kwargs: Any) -> Engine: ...

sqlalchemy-stubs/sql/schema.pyi

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ from .selectable import TableClause as TableClause
99
from .type_api import TypeEngine
1010
from .. import util
1111
from ..engine import Engine, Connection, Connectable
12+
from ..engine.url import URL
1213
from .compiler import DDLCompiler
1314
import threading
1415

@@ -298,18 +299,25 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
298299

299300
DEFAULT_NAMING_CONVENTION: util.immutabledict[str, str] = ...
300301

302+
class _MetaDataBind:
303+
@overload
304+
def __get__(self, instance: None, owner: Any) -> None: ...
305+
@overload
306+
def __get__(self, instance: MetaData, owner: Any) -> Optional[Union[Engine, Connection]]: ...
307+
def __set__(self, instance: Any, value: Optional[Union[Engine, Connection, str, URL]]) -> None: ...
308+
301309
class MetaData(SchemaItem):
302310
__visit_name__: str = ...
303311
tables: util.immutabledict[str, Table] = ...
304312
schema: Optional[str] = ...
305313
naming_convention: Mapping[Union[str, Index, Constraint], str] = ...
306314
info: Optional[Mapping[str, Any]] = ...
315+
bind: _MetaDataBind = ...
307316
def __init__(self, bind: Optional[Union[Engine, Connection]] = ..., reflect: bool = ..., schema: Optional[str] = ...,
308317
quote_schema: Optional[bool] = ..., naming_convention: Mapping[Union[str, Index, Constraint], str] = ...,
309318
info: Optional[Mapping[str, Any]] = ...) -> None: ...
310319
def __contains__(self, table_or_key: Union[str, Table]) -> bool: ...
311320
def is_bound(self) -> bool: ...
312-
def bind(self) -> Optional[Union[Engine, Connection]]: ...
313321
def clear(self) -> None: ...
314322
def remove(self, table: Table) -> None: ...
315323
@property
@@ -326,7 +334,7 @@ class MetaData(SchemaItem):
326334
class ThreadLocalMetaData(MetaData):
327335
__visit_name__: str = ...
328336
context: threading.local = ...
337+
bind: _MetaDataBind = ...
329338
def __init__(self) -> None: ...
330-
def bind(self) -> Optional[Union[Engine, Connection]]: ...
331339
def is_bound(self) -> bool: ...
332340
def dispose(self) -> None: ...
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[case testSchemaMetaData]
2+
from sqlalchemy import MetaData, create_engine
3+
from sqlalchemy.engine import Connection
4+
from sqlalchemy.engine.url import make_url
5+
6+
m = MetaData()
7+
e = create_engine('postgresql://foo')
8+
c = Connection(e)
9+
m.bind = 'postgresql://foo'
10+
m.bind = make_url('postgresql://foo')
11+
reveal_type(m.bind) # E: Revealed type is 'Union[sqlalchemy.engine.base.Engine, sqlalchemy.engine.base.Connection, None]'
12+
m.bind = e
13+
reveal_type(m.bind) # E: Revealed type is 'sqlalchemy.engine.base.Engine'
14+
m.bind = c
15+
reveal_type(m.bind) # E: Revealed type is 'sqlalchemy.engine.base.Connection'
16+
[out]

test/testsql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class SQLDataSuite(DataSuite):
2323
files = ['sqlalchemy-basics.test',
2424
'sqlalchemy-sql-elements.test',
2525
'sqlalchemy-sql-sqltypes.test',
26-
'sqlalchemy-sql-selectable.test']
26+
'sqlalchemy-sql-selectable.test',
27+
'sqlalchemy-sql-schema.test']
2728
data_prefix = test_data_prefix
2829

2930
def run_case(self, testcase: DataDrivenTestCase) -> None:

0 commit comments

Comments
 (0)