Skip to content

Commit daf6637

Browse files
authored
A bunch or minor fixes and temporary things (#51)
Fixes #50 The goal is mostly to simplify the sync with our internal code. There are two changes that are going to be reverted soon: * Type engine parameter for `String` should actually be `Text`, it is set to `str` temporarily for short transition period. It will be reverted back to `Text` before the release. * The two lenient overloads in `Column.__init__()` are temporarily added to suppress some errors. They will be also removed before the release.
1 parent c0d5592 commit daf6637

File tree

7 files changed

+118
-13
lines changed

7 files changed

+118
-13
lines changed

sqlalchemy-stubs/engine/interfaces.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ from .result import ResultProxy
44
from ..sql.compiler import Compiled as Compiled, TypeCompiler as TypeCompiler
55

66
class Dialect(object):
7+
@property
8+
def name(self) -> str: ...
79
def create_connect_args(self, url): ...
810
@classmethod
911
def type_descriptor(cls, typeobj): ...

sqlalchemy-stubs/sql/schema.pyi

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from .. import util
1111
from ..engine import Engine, Connection, Connectable
1212
from ..engine.url import URL
1313
from .compiler import DDLCompiler
14+
from .expression import FunctionElement
1415
import threading
1516

1617
_T = TypeVar('_T')
@@ -90,25 +91,38 @@ class Column(SchemaItem, ColumnClause[_T]):
9091
def __init__(self, name: str, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ...,
9192
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
9293
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
93-
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
94+
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
9495
system: bool = ..., comment: str = ...) -> None: ...
9596
@overload
9697
def __init__(self, type_: Type[TypeEngine[_T]], *args: Any, autoincrement: Union[bool, str] = ...,
9798
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
9899
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
99-
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
100+
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
100101
system: bool = ..., comment: str = ...) -> None: ...
101102
@overload
102103
def __init__(self, name: str, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ...,
103104
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
104105
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
105-
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
106+
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
106107
system: bool = ..., comment: str = ...) -> None: ...
107108
@overload
108109
def __init__(self, type_: TypeEngine[_T], *args: Any, autoincrement: Union[bool, str] = ...,
109110
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
110111
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
111-
server_onupdate: FetchedValue = ..., quote: Optional[bool] = ..., unique: bool = ...,
112+
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
113+
system: bool = ..., comment: str = ...) -> None: ...
114+
# The two overloads below exist to make annotation more like a cast. This is a temporary measure.
115+
@overload
116+
def __init__(self, name: str, type_: Any, *args: Any, autoincrement: Union[bool, str] = ...,
117+
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
118+
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
119+
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
120+
system: bool = ..., comment: str = ...) -> None: ...
121+
@overload
122+
def __init__(self, type_: Any, *args: Any, autoincrement: Union[bool, str] = ...,
123+
default: Any = ..., doc: str = ..., key: str = ..., index: bool = ..., info: Mapping[str, Any] = ...,
124+
nullable: bool = ..., onupdate: Any = ..., primary_key: bool = ..., server_default: Any = ...,
125+
server_onupdate: Union[FetchedValue, FunctionElement] = ..., quote: Optional[bool] = ..., unique: bool = ...,
112126
system: bool = ..., comment: str = ...) -> None: ...
113127
def references(self, column: Column[Any]) -> bool: ...
114128
def append_foreign_key(self, fk: ForeignKey) -> None: ...

sqlalchemy-stubs/sql/sqltypes.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Indexable(object):
2222

2323
# Docs say that String is unicode when DBAPI supports it
2424
# but it should be all major DBAPIs now.
25-
class String(Concatenable, TypeEngine[typing_Text]):
25+
class String(Concatenable, TypeEngine[str]): # XXX: should be typing_Text
2626
__visit_name__: str = ...
2727
length: Optional[int] = ...
2828
collation: Optional[str] = ...
@@ -39,7 +39,7 @@ class String(Concatenable, TypeEngine[typing_Text]):
3939
def bind_processor(self, dialect: Dialect) -> Optional[Callable[[str], str]]: ...
4040
def result_processor(self, dialect: Dialect, coltype: Any) -> Optional[Callable[[Optional[Any]], Optional[str]]]: ...
4141
@property
42-
def python_type(self) -> Type[typing_Text]: ...
42+
def python_type(self) -> Type[str]: ...
4343
def get_dbapi_type(self, dbapi: Any) -> Any: ...
4444

4545
class Text(String):

sqlalchemy-stubs/sql/type_api.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload
1+
from typing import Any, Optional, Union, TypeVar, Generic, Type, Callable, ClassVar, Tuple, Mapping, overload, Text as typing_Text
22
from .. import util
33
from .visitors import Visitable as Visitable, VisitableType as VisitableType
44
from .base import SchemaEventTarget as SchemaEventTarget
@@ -91,7 +91,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine[_T]):
9191
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ...
9292
def __getattr__(self, key: str) -> Any: ...
9393
def process_literal_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ...
94-
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[str]: ...
94+
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]: ...
9595
def process_result_value(self, value: Optional[Any], dialect: Dialect) -> Optional[_T]: ...
9696
def literal_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ...
9797
def bind_processor(self, dialect: Dialect) -> Callable[[Optional[_T]], Optional[str]]: ...

sqlmypy.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from mypy.plugins.common import add_method
33
from mypy.nodes import(
44
NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF,
5-
Argument, Var, ARG_STAR2
5+
Argument, Var, ARG_STAR2, MDEF
66
)
77
from mypy.types import (
88
UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType
@@ -67,11 +67,23 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
6767
sym = self.lookup_fully_qualified(fullname)
6868
if sym and isinstance(sym.node, TypeInfo):
6969
if is_declarative(sym.node):
70-
return add_init_hook
70+
return add_model_init_hook
7171
return None
7272

7373

74-
def add_init_hook(ctx: ClassDefContext) -> None:
74+
def add_var_to_class(name: str, typ: Type, info: TypeInfo) -> None:
75+
"""Add a variable with given name and type to the symbol table of a class.
76+
77+
This also takes care about setting necessary attributes on the variable node.
78+
"""
79+
var = Var(name)
80+
var.info = info
81+
var._fullname = info.fullname() + '.' + name
82+
var.type = typ
83+
info.names[name] = SymbolTableNode(MDEF, var)
84+
85+
86+
def add_model_init_hook(ctx: ClassDefContext) -> None:
7587
"""Add a dummy __init__() to a model and record it is generated.
7688
7789
Instantiation will be checked more precisely when we inferred types
@@ -86,6 +98,26 @@ def add_init_hook(ctx: ClassDefContext) -> None:
8698
add_method(ctx, '__init__', [kw_arg], NoneTyp())
8799
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True
88100

101+
# Also add a selection of auto-generated attributes.
102+
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table')
103+
if sym:
104+
assert isinstance(sym.node, TypeInfo)
105+
typ = Instance(sym.node, []) # type: Type
106+
else:
107+
typ = AnyType(TypeOfAny.special_form)
108+
add_var_to_class('__table__', typ, ctx.cls.info)
109+
110+
111+
def add_metadata_var(ctx: ClassDefContext, info: TypeInfo) -> None:
112+
"""Add .metadata attribute to a declarative base."""
113+
sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.MetaData')
114+
if sym:
115+
assert isinstance(sym.node, TypeInfo)
116+
typ = Instance(sym.node, []) # type: Type
117+
else:
118+
typ = AnyType(TypeOfAny.special_form)
119+
add_var_to_class('metadata', typ, info)
120+
89121

90122
def decl_deco_hook(ctx: ClassDefContext) -> None:
91123
"""Support declaring base class as declarative with a decorator.
@@ -98,6 +130,7 @@ class Base:
98130
...
99131
"""
100132
set_declarative(ctx.cls.info)
133+
add_metadata_var(ctx, ctx.cls.info)
101134

102135

103136
def decl_info_hook(ctx):
@@ -119,6 +152,9 @@ def decl_info_hook(ctx):
119152
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
120153
set_declarative(info)
121154

155+
# TODO: check what else is added.
156+
add_metadata_var(ctx, info)
157+
122158

123159
def model_hook(ctx: FunctionContext) -> Type:
124160
"""More precise model instantiation check.
@@ -146,6 +182,10 @@ def model_hook(ctx: FunctionContext) -> Type:
146182
assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__
147183
assert len(ctx.arg_types) == 1
148184
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
185+
if actual_name is None:
186+
# We can't check kwargs reliably.
187+
# TODO: support TypedDict?
188+
continue
149189
if actual_name not in expected_types:
150190
ctx.api.fail('Unexpected column "{}" for model "{}"'.format(actual_name, model.name()),
151191
ctx.context)

test/test-data/sqlalchemy-basics.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class User(Base):
102102

103103
user = User()
104104
reveal_type(user.id) # E: Revealed type is 'builtins.int*'
105-
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.unicode*, None]]'
105+
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]'
106106
[out]
107107

108108
[case testColumnFieldsInferredInstance_python2]
@@ -118,5 +118,5 @@ class User(Base):
118118

119119
user = User()
120120
reveal_type(user.id) # E: Revealed type is 'builtins.int*'
121-
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]'
121+
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.str*]'
122122
[out]

test/test-data/sqlalchemy-plugin-features.test

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,52 @@ class User(Base):
170170
user: User
171171
reveal_type(User.name) # E: Revealed type is 'sqlalchemy.sql.schema.Column[Union[builtins.str*, None]]'
172172
[out]
173+
174+
[case testAddedAttributesDeclared]
175+
from sqlalchemy.ext.declarative import declarative_base
176+
from sqlalchemy import Column, Integer, String
177+
178+
Base = declarative_base()
179+
180+
class User(Base):
181+
__tablename__ = 'users'
182+
id = Column(Integer(), primary_key=True)
183+
name = Column(String(), default='John Doe', nullable=True)
184+
185+
user: User
186+
reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData'
187+
reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table'
188+
[out]
189+
190+
[case testAddedAttributedDecorated]
191+
from sqlalchemy.ext.declarative import as_declarative
192+
from sqlalchemy import Column, Integer, String
193+
194+
@as_declarative()
195+
class Base:
196+
...
197+
198+
class User(Base):
199+
__tablename__ = 'users'
200+
id = Column(Integer(), primary_key=True)
201+
name = Column(String(), default='John Doe', nullable=True)
202+
203+
user: User
204+
reveal_type(User.metadata) # E: Revealed type is 'sqlalchemy.sql.schema.MetaData'
205+
reveal_type(User.__table__) # E: Revealed type is 'sqlalchemy.sql.schema.Table'
206+
[out]
207+
208+
[case testKwArgsModelOK]
209+
from sqlalchemy import Column, Integer, String
210+
from sqlalchemy.ext.declarative import declarative_base
211+
212+
Base = declarative_base()
213+
214+
class User(Base):
215+
__tablename__ = 'users'
216+
id = Column(Integer, primary_key=True)
217+
name = Column(String)
218+
219+
record = {'name': 'John Doe'}
220+
User(**record) # OK
221+
[out]

0 commit comments

Comments
 (0)