Skip to content

Commit fd04f8c

Browse files
committed
Get caller frame at decoration-time
Here we are more careful about which caller's locals we use to resolve forward type references. We want the callers locals at decoration-time — not at decorator-construction time. Consider: ```py frozen_dataclass = marshmallow_dataclass.dataclass(frozen=True) def f(): @custom_dataclass class A: b: "B" @custom_dataclass class B: x: int ``` The locals we want in this case are the one from where the custom_dataclass decorator is called, not from where marshmallow_dataclass.dataclass is called.
1 parent 6416dc9 commit fd04f8c

File tree

2 files changed

+83
-33
lines changed

2 files changed

+83
-33
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 67 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,51 @@ class User:
8282
MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
8383

8484

85+
def _maybe_get_callers_frame(
86+
cls: type, stacklevel: int = 1
87+
) -> Optional[types.FrameType]:
88+
"""Return the caller's frame, but only if it will help resolve forward type references.
89+
90+
We sometimes need the caller's frame to get access to the caller's
91+
local namespace in order to be able to resolve forward type
92+
references in dataclasses.
93+
94+
Notes
95+
-----
96+
97+
If the caller's locals are the same as the dataclass' module
98+
globals — this is the case for the common case of dataclasses
99+
defined at the module top-level — we don't need the locals.
100+
(Typing.get_type_hints() knows how to check the class module
101+
globals on its own.)
102+
103+
In that case, we don't need the caller's frame. Not holding a
104+
reference to the frame in our our lazy ``.Scheme`` class attribute
105+
is a significant win, memory-wise.
106+
107+
"""
108+
try:
109+
frame = inspect.currentframe()
110+
for _ in range(stacklevel + 1):
111+
if frame is None:
112+
return None
113+
frame = frame.f_back
114+
115+
if frame is None:
116+
return None
117+
118+
globalns = getattr(sys.modules.get(cls.__module__), "__dict__", None)
119+
if frame.f_locals is globalns:
120+
# Locals are the globals
121+
return None
122+
123+
return frame
124+
125+
finally:
126+
# Paranoia, per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
127+
del frame
128+
129+
85130
@overload
86131
def dataclass(
87132
_cls: Type[_U],
@@ -124,6 +169,7 @@ def dataclass(
124169
frozen: bool = False,
125170
base_schema: Optional[Type[marshmallow.Schema]] = None,
126171
cls_frame: Optional[types.FrameType] = None,
172+
stacklevel: int = 1,
127173
) -> Union[Type[_U], Callable[[Type[_U]], Type[_U]]]:
128174
"""
129175
This decorator does the same as dataclasses.dataclass, but also applies :func:`add_schema`.
@@ -150,19 +196,18 @@ def dataclass(
150196
>>> Point.Schema().load({'x':0, 'y':0}) # This line can be statically type checked
151197
Point(x=0.0, y=0.0)
152198
"""
153-
# dataclass's typing doesn't expect it to be called as a function, so ignore type check
154-
dc = dataclasses.dataclass( # type: ignore
155-
_cls, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
199+
dc = dataclasses.dataclass(
200+
repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
156201
)
157-
if not cls_frame:
158-
current_frame = inspect.currentframe()
159-
if current_frame:
160-
cls_frame = current_frame.f_back
161-
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
162-
del current_frame
202+
203+
def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
204+
return add_schema(
205+
dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1
206+
)
207+
163208
if _cls is None:
164-
return lambda cls: add_schema(dc(cls), base_schema, cls_frame=cls_frame)
165-
return add_schema(dc, base_schema, cls_frame=cls_frame)
209+
return decorator
210+
return decorator(_cls, stacklevel=stacklevel + 1)
166211

167212

168213
@overload
@@ -182,11 +227,12 @@ def add_schema(
182227
_cls: Type[_U],
183228
base_schema: Optional[Type[marshmallow.Schema]] = None,
184229
cls_frame: Optional[types.FrameType] = None,
230+
stacklevel: int = 1,
185231
) -> Type[_U]:
186232
...
187233

188234

189-
def add_schema(_cls=None, base_schema=None, cls_frame=None):
235+
def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1):
190236
"""
191237
This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass.
192238
It uses :func:`class_schema` internally.
@@ -208,31 +254,23 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None):
208254
Artist(names=('Martin', 'Ramirez'))
209255
"""
210256

211-
def decorator(clazz: Type[_U]) -> Type[_U]:
212-
cls_frame_ = cls_frame
257+
def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
213258
if cls_frame is not None:
214-
cls_globals = getattr(sys.modules.get(clazz.__module__), "__dict__", None)
215-
if cls_frame.f_locals is cls_globals:
216-
# Memory optimization:
217-
# If the caller's locals are the same as the class
218-
# module globals, we don't need the locals. (This is
219-
# typically the case for dataclasses defined at the
220-
# module top-level.) (Typing.get_type_hints() knows
221-
# how to check the class module globals on its own.)
222-
# Not holding a reference to the frame in our our lazy
223-
# class attribute which is a significant win,
224-
# memory-wise.
225-
cls_frame_ = None
259+
frame = cls_frame
260+
else:
261+
frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel)
226262

227263
# noinspection PyTypeHints
228264
clazz.Schema = lazy_class_attribute( # type: ignore
229-
partial(class_schema, clazz, base_schema, cls_frame_),
265+
partial(class_schema, clazz, base_schema, frame),
230266
"Schema",
231267
clazz.__name__,
232268
)
233269
return clazz
234270

235-
return decorator(_cls) if _cls else decorator
271+
if _cls is None:
272+
return decorator
273+
return decorator(_cls, stacklevel=stacklevel + 1)
236274

237275

238276
def class_schema(
@@ -361,11 +399,7 @@ def class_schema(
361399
if not dataclasses.is_dataclass(clazz):
362400
clazz = dataclasses.dataclass(clazz)
363401
if not clazz_frame:
364-
current_frame = inspect.currentframe()
365-
if current_frame:
366-
clazz_frame = current_frame.f_back
367-
# Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
368-
del current_frame
402+
clazz_frame = _maybe_get_callers_frame(clazz)
369403

370404
with _SchemaContext(clazz_frame):
371405
return _internal_class_schema(clazz, base_schema)

tests/test_forward_references.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,19 @@ class B:
133133
B.Schema().load(dict(a=dict(c=1)))
134134
# marshmallow.exceptions.ValidationError:
135135
# {'a': {'d': ['Missing data for required field.'], 'c': ['Unknown field.']}}
136+
137+
def test_locals_from_decoration_ns(self):
138+
# Test that locals are picked-up at decoration-time rather
139+
# than when the decorator is constructed.
140+
@frozen_dataclass
141+
class A:
142+
b: "B"
143+
144+
@frozen_dataclass
145+
class B:
146+
x: int
147+
148+
assert A.Schema().load({"b": {"x": 42}}) == A(b=B(x=42))
149+
150+
151+
frozen_dataclass = dataclass(frozen=True)

0 commit comments

Comments
 (0)