Skip to content

Commit 849eadd

Browse files
committed
Support custom sqltype with MutablePydanticBaseModel
1 parent b6c7bfa commit 849eadd

File tree

3 files changed

+65
-4
lines changed

3 files changed

+65
-4
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"cSpell.words": [
3-
"autouse"
3+
"autouse",
4+
"sqltype"
45
]
56
}

sqlalchemy_nested_mutable/mutable.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,20 @@ class PydanticType(sa.types.TypeDecorator, TypeEngine[_P]):
4545
"""
4646
Inspired by https://gist.github.com/imankulov/4051b7805ad737ace7d8de3d3f934d6b
4747
"""
48+
cache_ok = True
4849
impl = sa.types.JSON
4950

50-
def __init__(self, pydantic_type: type[_P]):
51+
def __init__(self, pydantic_type: type[_P], sqltype: TypeEngine[_T] = None):
5152
super().__init__()
5253
self.pydantic_type = pydantic_type
54+
self.sqltype = sqltype
5355

5456
def load_dialect_impl(self, dialect):
5557
from sqlalchemy.dialects.postgresql import JSONB
5658

59+
if self.sqltype is not None:
60+
return dialect.type_descriptor(self.sqltype)
61+
5762
if dialect.name == "postgresql":
5863
return dialect.type_descriptor(JSONB())
5964
return dialect.type_descriptor(sa.JSON())
@@ -79,8 +84,8 @@ def dict(self, *args, **kwargs):
7984
return res
8085

8186
@classmethod
82-
def as_mutable(cls, /) -> TypeEngine[Self]:
83-
return super().as_mutable(PydanticType(cls))
87+
def as_mutable(cls, sqltype: TypeEngine[_T] = None) -> TypeEngine[Self]:
88+
return super().as_mutable(PydanticType(cls, sqltype))
8489
elif not TYPE_CHECKING:
8590
class PydanticType:
8691
def __new__(cls, *a, **k):

tests/test_custom_sqltype.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Optional, List
2+
3+
import pytest
4+
from sqlalchemy.dialects.postgresql import JSON, JSONB
5+
from sqlalchemy_nested_mutable._compat import pydantic
6+
import sqlalchemy as sa
7+
from sqlalchemy.orm import (
8+
DeclarativeBase,
9+
Mapped,
10+
mapped_column,
11+
)
12+
13+
from sqlalchemy_nested_mutable import MutablePydanticBaseModel
14+
15+
16+
class Base(DeclarativeBase):
17+
pass
18+
19+
20+
class Addresses(MutablePydanticBaseModel):
21+
class AddressItem(pydantic.BaseModel):
22+
street: str
23+
city: str
24+
area: Optional[str]
25+
26+
work: List[AddressItem] = []
27+
home: List[AddressItem] = []
28+
29+
30+
class User(Base):
31+
__tablename__ = "user_account"
32+
33+
id: Mapped[int] = mapped_column(primary_key=True)
34+
name: Mapped[str] = mapped_column(sa.String(30))
35+
addresses_default: Mapped[Optional[Addresses]] = mapped_column(Addresses.as_mutable())
36+
addresses_json: Mapped[Optional[Addresses]] = mapped_column(Addresses.as_mutable(JSON))
37+
addresses_jsonb: Mapped[Optional[Addresses]] = mapped_column(Addresses.as_mutable(JSONB))
38+
39+
40+
@pytest.fixture(scope="module", autouse=True)
41+
def _with_tables(session):
42+
Base.metadata.create_all(session.bind)
43+
yield
44+
session.execute(sa.text("""
45+
DROP TABLE user_account CASCADE;
46+
"""))
47+
session.commit()
48+
49+
50+
def test_mutable_pydantic_type(session):
51+
session.add(User(name="foo"))
52+
session.commit()
53+
assert session.scalar(sa.select(sa.func.pg_typeof(User.addresses_default))) == "jsonb"
54+
assert session.scalar(sa.select(sa.func.pg_typeof(User.addresses_json))) == "json"
55+
assert session.scalar(sa.select(sa.func.pg_typeof(User.addresses_jsonb))) == "jsonb"

0 commit comments

Comments
 (0)