Skip to content

Commit 43108bf

Browse files
committed
Minimal change to allow attribute_keyed_dict + test
1 parent 2bfcad1 commit 43108bf

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

sqlmodel/_compat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def get_relationship_to(
178178
# If a list, then also get the real field
179179
elif origin is list:
180180
use_annotation = get_args(annotation)[0]
181+
# If a dict, then use the value type
182+
elif origin is dict:
183+
use_annotation = get_args(annotation)[1]
181184

182185
return get_relationship_to(
183186
name=name, rel_info=rel_info, annotation=use_annotation

tests/test_attribute_keyed_dict.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from enum import StrEnum
2+
3+
from sqlalchemy.orm.collections import attribute_keyed_dict
4+
from sqlmodel import Field, Index, Relationship, Session, SQLModel, create_engine
5+
6+
7+
def test_attribute_keyed_dict_works(clear_sqlmodel):
8+
class Color(StrEnum):
9+
Orange = "Orange"
10+
Blue = "Blue"
11+
12+
class Child(SQLModel, table=True):
13+
__tablename__ = "children"
14+
__table_args__ = (
15+
Index("ix_children_parent_id_color", "parent_id", "color", unique=True),
16+
)
17+
18+
id: int | None = Field(primary_key=True, default=None)
19+
parent_id: int = Field(foreign_key="parents.id")
20+
color: Color
21+
value: int
22+
23+
class Parent(SQLModel, table=True):
24+
__tablename__ = "parents"
25+
26+
id: int | None = Field(primary_key=True, default=None)
27+
children_by_color: dict[Color, Child] = Relationship(
28+
sa_relationship_kwargs={"collection_class": attribute_keyed_dict("color")}
29+
)
30+
31+
engine = create_engine("sqlite://")
32+
SQLModel.metadata.create_all(engine)
33+
with Session(engine) as session:
34+
parent = Parent()
35+
session.add(parent)
36+
session.commit()
37+
session.refresh(parent)
38+
session.add(Child(parent_id=parent.id, color=Color.Orange, value=1))
39+
session.add(Child(parent_id=parent.id, color=Color.Blue, value=2))
40+
session.commit()
41+
session.refresh(parent)
42+
assert parent.children_by_color[Color.Orange].parent_id == parent.id
43+
assert parent.children_by_color[Color.Orange].color == Color.Orange
44+
assert parent.children_by_color[Color.Orange].value == 1
45+
assert parent.children_by_color[Color.Blue].parent_id == parent.id
46+
assert parent.children_by_color[Color.Blue].color == Color.Blue
47+
assert parent.children_by_color[Color.Blue].value == 2

0 commit comments

Comments
 (0)