Skip to content

Commit b660369

Browse files
authored
fix postponed annotation when field is overwritten (#228)
1 parent 8c80ae5 commit b660369

File tree

5 files changed

+44
-10
lines changed

5 files changed

+44
-10
lines changed

simple_parsing/annotation_utils/get_field_annotations.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import typing
66
from contextlib import contextmanager
77
from dataclasses import InitVar
8+
from itertools import dropwhile
89
from logging import getLogger as get_logger
910
from typing import Any, Dict, Iterator, Optional, get_type_hints
1011

@@ -190,13 +191,14 @@ def get_field_type_from_annotations(some_class: type, field_name: str) -> type:
190191
if frame is not None:
191192
local_ns.update(frame.f_locals)
192193

193-
# Get the global_ns in the module starting from the deepest base until the module where the field_name is defined.
194+
# Get the global_ns in the module starting from the deepest base until the module with the field_name last definition.
194195
global_ns = {}
195-
for base_cls in reversed(some_class.mro()):
196+
classes_to_iterate = list(dropwhile(
197+
lambda cls: field_name not in getattr(cls, "__annotations__", {}),
198+
some_class.mro()
199+
))
200+
for base_cls in reversed(classes_to_iterate):
196201
global_ns.update(sys.modules[base_cls.__module__].__dict__)
197-
annotations = getattr(base_cls, "__annotations__", None)
198-
if annotations and field_name in annotations:
199-
break
200202

201203
try:
202204
with _initvar_patcher():
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
from ..test_utils import TestSetup
6+
from .overwrite_base import Base, ParamCls
7+
8+
9+
@dataclass
10+
class ParamClsSubclass(ParamCls):
11+
v: bool
12+
13+
14+
@dataclass
15+
class Subclass(Base, TestSetup):
16+
attribute: ParamClsSubclass

test/postponed_annotations/overwrite_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55

66
@dataclass
7-
class Foo:
7+
class ParamCls:
88
...
99

1010

1111
@dataclass
1212
class Base:
13-
a: Foo
13+
attribute: ParamCls

test/postponed_annotations/overwrite_subclass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88

99
@dataclass
10-
class Foo:
10+
class ParamCls:
1111
something_else: bool = True
1212

1313

1414
@dataclass
1515
class Subclass(Base, TestSetup):
16-
other_attribute: Foo
16+
other_attribute: ParamCls

test/postponed_annotations/test_postponed_annotations.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,27 @@ def test_postponed_annotations_with_multi_depth_inherits_2():
5757

5858

5959
def test_overwrite_base():
60+
"""Test that postponed annotations don't break types with the same name in multiple files."""
6061
import test.postponed_annotations.overwrite_base as overwrite_base
6162
import test.postponed_annotations.overwrite_subclass as overwrite_subclass
6263

6364
assert overwrite_subclass.Subclass.setup(
6465
"--something_else False"
6566
) == overwrite_subclass.Subclass(
66-
a=overwrite_base.Foo(), other_attribute=overwrite_subclass.Foo(False)
67+
attribute=overwrite_base.ParamCls(),
68+
other_attribute=overwrite_subclass.ParamCls(False),
6769
)
70+
71+
72+
def test_overwrite_field():
73+
"""Test that postponed annotations don't break attribute overwriting in multiple files."""
74+
import test.postponed_annotations.overwrite_base as overwrite_base
75+
import test.postponed_annotations.overwrite_attribute as overwrite_attribute
76+
77+
instance = overwrite_attribute.Subclass.setup("--v True")
78+
assert type(instance.attribute) != overwrite_base.ParamCls, (
79+
"attribute type from Base class correctly ignored"
80+
)
81+
assert instance == overwrite_attribute.Subclass(
82+
attribute=overwrite_attribute.ParamClsSubclass(True)
83+
), "parsed attribute value is correct"

0 commit comments

Comments
 (0)