Skip to content

Commit 3f11b80

Browse files
committed
Make dict get_relationship_to more robust
1 parent c22321e commit 3f11b80

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

sqlmodel/_compat.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,16 @@ def get_relationship_to(
178178
# If a list, then also get the real field
179179
elif origin is list:
180180
use_annotation = get_args(annotation)[0]
181-
# If a dict, then use the value type
182-
elif origin is dict:
183-
use_annotation = get_args(annotation)[1]
181+
# If a dict or Mapping, then use the value (second) type argument
182+
elif origin is dict or origin is Mapping:
183+
args = get_args(annotation)
184+
if len(args) >= 2:
185+
use_annotation = args[1]
186+
else:
187+
raise ValueError(
188+
f"Dict/Mapping relationship field '{name}' must have both "
189+
"key and value type arguments (e.g., dict[str, Model])"
190+
)
184191

185192
return get_relationship_to(
186193
name=name, rel_info=rel_info, annotation=use_annotation

tests/test_attribute_keyed_dict.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import re
12
from enum import Enum
23
from typing import Dict, Optional
34

5+
import pytest
46
from sqlalchemy.orm.collections import attribute_keyed_dict
57
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine
68

@@ -43,3 +45,46 @@ class Parent(SQLModel, table=True):
4345
assert parent.children_by_color[Color.Blue].parent_id == parent.id
4446
assert parent.children_by_color[Color.Blue].color == Color.Blue
4547
assert parent.children_by_color[Color.Blue].value == 2
48+
49+
50+
def test_dict_relationship_throws_on_missing_annotation_arg(clear_sqlmodel):
51+
class Color(str, Enum):
52+
Orange = "Orange"
53+
Blue = "Blue"
54+
55+
class Child(SQLModel, table=True):
56+
__tablename__ = "children"
57+
58+
id: Optional[int] = Field(primary_key=True, default=None)
59+
parent_id: int = Field(foreign_key="parents.id")
60+
color: Color
61+
value: int
62+
63+
error_msg_re = re.escape(
64+
"Dict/Mapping relationship field 'children_by_color' must have both key and value type arguments (e.g., dict[str, Model])"
65+
)
66+
# No type args
67+
with pytest.raises(ValueError, match=error_msg_re):
68+
69+
class Parent(SQLModel, table=True):
70+
__tablename__ = "parents"
71+
72+
id: Optional[int] = Field(primary_key=True, default=None)
73+
children_by_color: dict[()] = Relationship(
74+
sa_relationship_kwargs={
75+
"collection_class": attribute_keyed_dict("color")
76+
}
77+
)
78+
79+
# One type arg
80+
with pytest.raises(ValueError, match=error_msg_re):
81+
82+
class Parent(SQLModel, table=True):
83+
__tablename__ = "parents"
84+
85+
id: Optional[int] = Field(primary_key=True, default=None)
86+
children_by_color: dict[Color] = Relationship(
87+
sa_relationship_kwargs={
88+
"collection_class": attribute_keyed_dict("color")
89+
}
90+
)

0 commit comments

Comments
 (0)