Skip to content

Commit 32a7ac9

Browse files
authored
Change union for best-match in deserialization (#13)
1 parent b5b1076 commit 32a7ac9

File tree

4 files changed

+203
-61
lines changed

4 files changed

+203
-61
lines changed

packages/catalystwan-core/src/catalystwan/core/models/deserialize.py

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from collections import deque
22
from copy import deepcopy
3-
from dataclasses import fields, is_dataclass
3+
from dataclasses import dataclass, fields, is_dataclass
44
from functools import reduce
55
from inspect import isclass, unwrap
6-
from typing import Any, Dict, List, Literal, Protocol, Tuple, Type, TypeVar, Union
6+
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Type, TypeVar, Union, cast
77

88
from catalystwan.core.exceptions import (
99
CatalystwanModelInputException,
1010
CatalystwanModelValidationError,
1111
)
12+
from catalystwan.core.models.utils import count_matching_keys
1213
from catalystwan.core.types import MODEL_TYPES, AliasPath, DataclassInstance
1314
from typing_extensions import Annotated, get_args, get_origin, get_type_hints
1415

@@ -19,6 +20,13 @@ class ValueExtractorCallable(Protocol):
1920
def __call__(self, field_value: Any) -> Any: ...
2021

2122

23+
@dataclass
24+
class ExtractedValue:
25+
value: Any
26+
exact_match: bool
27+
matched_keys: Optional[int] = None
28+
29+
2230
class ModelDeserializer:
2331
def __init__(self, model: Type[T]) -> None:
2432
self.model = model
@@ -57,67 +65,91 @@ def __check_errors(self):
5765
message += f"{exc}\n"
5866
raise CatalystwanModelValidationError(message)
5967

60-
def __is_optional(self, t: Any) -> bool:
61-
if get_origin(t) is Union and type(None) in get_args(t):
62-
return True
63-
return False
64-
65-
def __extract_type(self, field_type: Any, field_value: Any, field_name: str) -> Any:
68+
def __extract_type(self, field_type: Any, field_value: Any, field_name: str) -> ExtractedValue:
6669
origin = get_origin(field_type)
6770
# check for simple types and classes
6871
if origin is None:
69-
if field_type is Any:
70-
return field_value
71-
if isinstance(field_value, field_type):
72-
return field_value
72+
if field_type is Any or isinstance(field_value, field_type):
73+
return ExtractedValue(value=field_value, exact_match=True)
74+
# Do not cast bool values
75+
elif field_type is bool:
76+
...
77+
# False/Empty values (like empty string or list) can match to None
78+
elif field_type is type(None):
79+
if not field_value:
80+
return ExtractedValue(value=None, exact_match=False)
7381
elif is_dataclass(field_type):
74-
assert isinstance(field_type, type)
75-
return deserialize(field_type, **field_value)
82+
model_instance = deserialize(
83+
cast(Type[DataclassInstance], field_type), **field_value
84+
)
85+
return ExtractedValue(
86+
value=model_instance,
87+
exact_match=False,
88+
matched_keys=count_matching_keys(model_instance, field_value),
89+
)
7690
elif isclass(unwrap(field_type)):
7791
if isinstance(field_value, dict):
78-
return field_type(**field_value)
92+
return ExtractedValue(value=field_type(**field_value), exact_match=False)
7993
else:
8094
try:
81-
return field_type(field_value)
95+
return ExtractedValue(value=field_type(field_value), exact_match=False)
8296
except ValueError:
8397
raise CatalystwanModelInputException(
8498
f"Unable to match or cast input value for {field_name} [expected_type={unwrap(field_type)}, input={field_value}, input_type={type(field_value)}]"
8599
)
100+
# List is an exact match only if all of its elements are
86101
elif origin is list:
87102
if isinstance(field_value, list):
88-
return [
89-
self.__extract_type(get_args(field_type)[0], value, field_name)
90-
for value in field_value
91-
]
92-
elif self.__is_optional(field_type):
93-
if field_value is None:
94-
return None
95-
else:
96-
try:
97-
return self.__extract_type(get_args(field_type)[0], field_value, field_name)
98-
except CatalystwanModelInputException as e:
99-
if not field_value:
100-
return None
101-
raise e
103+
values = []
104+
exact_match = True
105+
for value in field_value:
106+
extracted_value = self.__extract_type(
107+
get_args(field_type)[0], value, field_name
108+
)
109+
values.append(extracted_value.value)
110+
if not extracted_value.exact_match:
111+
exact_match = False
112+
return ExtractedValue(value=values, exact_match=exact_match)
102113
elif origin is Literal:
103114
for arg in get_args(field_type):
104115
try:
105116
if type(arg)(field_value) == arg:
106-
return type(arg)(field_value)
117+
return ExtractedValue(
118+
value=type(arg)(field_value), exact_match=type(arg) is type(field_value)
119+
)
107120
except Exception:
108121
continue
109122
elif origin is Annotated:
110123
validator, caster = field_type.__metadata__
111124
if validator(field_value):
112-
return field_value
113-
return caster(field_value)
114-
# TODO: Currently, casting is done left-to-right. Searching deeper for a better match may be the way to go.
125+
return ExtractedValue(value=field_value, exact_match=True)
126+
return ExtractedValue(value=caster(field_value), exact_match=False)
127+
# When parsing Unions, try to find the best match. Currently, it involves:
128+
# 1. Finding the exact match
129+
# 2. If not found, favors dataclasses - sorted by number of matched keys, then None values
130+
# 3. If no dataclasses are present, return the leftmost matched argument
115131
elif origin is Union:
132+
matches: List[ExtractedValue] = []
116133
for arg in get_args(field_type):
117134
try:
118-
return self.__extract_type(arg, field_value, field_name)
135+
extracted_value = self.__extract_type(arg, field_value, field_name)
136+
# exact match, return
137+
if extracted_value.exact_match:
138+
return extracted_value
139+
else:
140+
matches.append(extracted_value)
119141
except Exception:
120142
continue
143+
# Only one element matched, return
144+
if len(matches) == 1:
145+
return matches[0]
146+
# Only non-exact matches left, sort and return first element
147+
elif len(matches) > 1:
148+
matches.sort(
149+
key=lambda x: (x.matched_keys is not None, x.matched_keys, x.value is None),
150+
reverse=True,
151+
)
152+
return matches[0]
121153
# Correct type not found, add exception
122154
raise CatalystwanModelInputException(
123155
f"Unable to match or cast input value for {field_name} [expected_type={unwrap(field_type)}, input={field_value}, input_type={type(field_value)}]"
@@ -130,7 +162,7 @@ def __transform_model_input(
130162
kwargs_copy = deepcopy(kwargs)
131163
new_args = []
132164
new_kwargs = {}
133-
field_types = get_type_hints(cls)
165+
field_types = get_type_hints(cls, include_extras=True)
134166
for field in fields(cls):
135167
if not field.init:
136168
continue
@@ -140,7 +172,9 @@ def __transform_model_input(
140172
field_value = args_copy.popleft()
141173
try:
142174
new_args.append(
143-
self.__extract_type(field_type, value_extractor(field_value), field.name)
175+
self.__extract_type(
176+
field_type, value_extractor(field_value), field.name
177+
).value
144178
)
145179
except (
146180
CatalystwanModelInputException,
@@ -164,7 +198,7 @@ def __transform_model_input(
164198
try:
165199
new_kwargs[field.name] = self.__extract_type(
166200
field_type, value_extractor(field_value), field.name
167-
)
201+
).value
168202
except (
169203
CatalystwanModelInputException,
170204
CatalystwanModelValidationError,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from dataclasses import is_dataclass
2+
from typing import TypeVar, cast
3+
4+
from catalystwan.core.types import DataclassInstance
5+
6+
DataclassType = TypeVar("DataclassType", bound=DataclassInstance)
7+
8+
9+
def count_matching_keys(model: DataclassType, model_payload: dict):
10+
matched_keys = 0
11+
for key, value in model_payload.items():
12+
try:
13+
model_value = getattr(model, key)
14+
matched_keys += 1
15+
if is_dataclass(model_value) and isinstance(value, dict):
16+
matched_keys += count_matching_keys(cast(DataclassType, model_value), value)
17+
elif (
18+
isinstance(model_value, list)
19+
and all([is_dataclass(element) for element in model_value])
20+
and isinstance(value, list)
21+
):
22+
for model_v, input_v in zip(model_value, value):
23+
matched_keys += count_matching_keys(model_v, input_v)
24+
except AttributeError:
25+
continue
26+
27+
return matched_keys

packages/catalystwan-core/src/catalystwan/core/request_adapter.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from catalystwan.core.models.deserialize import deserialize
1616
from catalystwan.core.models.serialize import serialize
17+
from catalystwan.core.models.utils import count_matching_keys
1718
from catalystwan.core.types import DataclassInstance
1819
from typing_extensions import get_args, get_origin
1920

@@ -213,34 +214,12 @@ class ModelReturn:
213214
# return model that matches best with the input
214215
valid_models.sort(
215216
key=lambda x: (
216-
self.__count_matching_keys(x.model, cast(dict, x.payload.data)),
217+
count_matching_keys(x.model, cast(dict, x.payload.data)),
217218
x.payload.priority,
218219
),
219220
reverse=True,
220221
)
221222
return valid_models[0].model
222223

223-
def __count_matching_keys(self, model: DataclassType, model_payload: dict):
224-
matched_keys = 0
225-
for key, value in model_payload.items():
226-
try:
227-
model_value = getattr(model, key)
228-
matched_keys += 1
229-
if is_dataclass(model_value) and isinstance(value, dict):
230-
matched_keys += self.__count_matching_keys(
231-
cast(DataclassType, model_value), value
232-
)
233-
elif (
234-
isinstance(model_value, list)
235-
and all([is_dataclass(element) for element in model_value])
236-
and isinstance(value, list)
237-
):
238-
for model_v, input_v in zip(model_value, value):
239-
matched_keys += self.__count_matching_keys(model_v, input_v)
240-
except AttributeError:
241-
continue
242-
243-
return matched_keys
244-
245224
def __copy__(self) -> RequestAdapter:
246225
return RequestAdapter(session=copy(self.session), logger=self.logger)

packages/catalystwan-core/tests/test_model_deserialize.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from dataclasses import dataclass
22
from ipaddress import IPv4Address, IPv6Address
33
from typing import List, Literal, Optional, Union
4+
from uuid import UUID
45

6+
import pytest
57
from catalystwan.core.models.deserialize import deserialize
8+
from catalystwan.core.types import Variable
69

710

811
def test_simple_deserialize():
@@ -131,3 +134,102 @@ class Model:
131134
assert m.union_field == IPv4Address("10.0.0.1")
132135
assert m.submodel_field.int_field == 1
133136
assert isinstance(m.submodel_field, Submodel)
137+
138+
139+
@pytest.mark.parametrize(
140+
"value",
141+
[
142+
("1"),
143+
(1),
144+
(1.2),
145+
("True"),
146+
("3a56601d-6132-4aea-98d0-605fa966ad48"),
147+
(UUID("3a56601d-6132-4aea-98d0-605fa966ad48")),
148+
],
149+
)
150+
def test_union_match_identity(value):
151+
@dataclass
152+
class Model:
153+
union_field: Union[str, int, bool, float, UUID]
154+
155+
m = deserialize(Model, union_field=value)
156+
assert m.union_field == value
157+
158+
159+
def test_union_match_optional():
160+
@dataclass
161+
class Model:
162+
union_field: Optional[Union[str, int, bool, float, UUID]] = None
163+
164+
m1 = deserialize(Model)
165+
m2 = deserialize(Model, union_field=None)
166+
m3 = deserialize(Model, union_field=[])
167+
168+
assert m1.union_field is None
169+
assert m2.union_field is None
170+
assert m3.union_field is None
171+
172+
173+
@pytest.mark.parametrize(
174+
"value",
175+
[
176+
("1"),
177+
(1),
178+
("True"),
179+
("3a56601d-6132-4aea-98d0-605fa966ad48"),
180+
(UUID("3a56601d-6132-4aea-98d0-605fa966ad48")),
181+
([1, "2", 3]),
182+
([1.2, True, 1.3]),
183+
],
184+
)
185+
def test_union_match_nested_identity(value):
186+
@dataclass
187+
class Model:
188+
union_field: Union[
189+
str, int, Union[UUID, Union[List[Union[str, int]], List[Union[float, bool]]]]
190+
]
191+
192+
m = deserialize(Model, union_field=value)
193+
194+
assert m.union_field == value
195+
196+
197+
def test_union_match_models():
198+
@dataclass
199+
class Submodel1:
200+
f1: int
201+
202+
@dataclass
203+
class Submodel2:
204+
f1: int
205+
f2: int
206+
207+
@dataclass
208+
class Model:
209+
union_field: Union[str, Submodel1, Submodel2]
210+
211+
m1 = deserialize(Model, **{"union_field": {"f1": 1}})
212+
m2 = deserialize(Model, **{"union_field": {"f1": 1, "f2": 2}})
213+
m3 = deserialize(Model, **{"union_field": {"f1": 1, "f2": 2, "irrelevant_key": 0}})
214+
215+
assert m1.union_field == Submodel1(1)
216+
assert m2.union_field == Submodel2(1, 2)
217+
assert m3.union_field == Submodel2(1, 2)
218+
219+
220+
@pytest.mark.parametrize(
221+
"model_input,expected_value",
222+
[
223+
("1", 1),
224+
("3a56601d-6132-4aea-98d0-605fa966ad48", UUID("3a56601d-6132-4aea-98d0-605fa966ad48")),
225+
("some_string", "{{some_string}}"),
226+
],
227+
)
228+
def test_match_union_cast(model_input, expected_value):
229+
@dataclass
230+
class Model:
231+
union_field: Optional[Union[int, bool, UUID, Variable]]
232+
233+
m = deserialize(Model, union_field=model_input)
234+
235+
assert m.union_field == expected_value

0 commit comments

Comments
 (0)