11from collections import deque
22from copy import deepcopy
3- from dataclasses import fields , is_dataclass
3+ from dataclasses import dataclass , fields , is_dataclass
44from functools import reduce
55from inspect import isclass , unwrap
6- from typing import Any , Dict , List , Literal , Protocol , Tuple , Type , TypeVar , Union
6+ from typing import Any , Dict , List , Literal , Optional , Protocol , Tuple , Type , TypeVar , Union , cast
77
88from catalystwan .core .exceptions import (
99 CatalystwanModelInputException ,
1010 CatalystwanModelValidationError ,
1111)
12+ from catalystwan .core .models .utils import count_matching_keys
1213from catalystwan .core .types import MODEL_TYPES , AliasPath , DataclassInstance
1314from typing_extensions import Annotated , get_args , get_origin , get_type_hints
1415
@@ -19,6 +20,13 @@ class ValueExtractorCallable(Protocol):
1920 def __call__ (self , field_value : Any ) -> Any : ...
2021
2122
23+ @dataclass
24+ class ExtractedValue :
25+ value : Any
26+ exact_match : bool
27+ matched_keys : Optional [int ] = None
28+
29+
2230class ModelDeserializer :
2331 def __init__ (self , model : Type [T ]) -> None :
2432 self .model = model
@@ -57,67 +65,91 @@ def __check_errors(self):
5765 message += f"{ exc } \n "
5866 raise CatalystwanModelValidationError (message )
5967
60- def __is_optional (self , t : Any ) -> bool :
61- if get_origin (t ) is Union and type (None ) in get_args (t ):
62- return True
63- return False
64-
65- def __extract_type (self , field_type : Any , field_value : Any , field_name : str ) -> Any :
68+ def __extract_type (self , field_type : Any , field_value : Any , field_name : str ) -> ExtractedValue :
6669 origin = get_origin (field_type )
6770 # check for simple types and classes
6871 if origin is None :
69- if field_type is Any :
70- return field_value
71- if isinstance (field_value , field_type ):
72- return field_value
72+ if field_type is Any or isinstance (field_value , field_type ):
73+ return ExtractedValue (value = field_value , exact_match = True )
74+ # Do not cast bool values
75+ elif field_type is bool :
76+ ...
77+ # False/Empty values (like empty string or list) can match to None
78+ elif field_type is type (None ):
79+ if not field_value :
80+ return ExtractedValue (value = None , exact_match = False )
7381 elif is_dataclass (field_type ):
74- assert isinstance (field_type , type )
75- return deserialize (field_type , ** field_value )
82+ model_instance = deserialize (
83+ cast (Type [DataclassInstance ], field_type ), ** field_value
84+ )
85+ return ExtractedValue (
86+ value = model_instance ,
87+ exact_match = False ,
88+ matched_keys = count_matching_keys (model_instance , field_value ),
89+ )
7690 elif isclass (unwrap (field_type )):
7791 if isinstance (field_value , dict ):
78- return field_type (** field_value )
92+ return ExtractedValue ( value = field_type (** field_value ), exact_match = False )
7993 else :
8094 try :
81- return field_type (field_value )
95+ return ExtractedValue ( value = field_type (field_value ), exact_match = False )
8296 except ValueError :
8397 raise CatalystwanModelInputException (
8498 f"Unable to match or cast input value for { field_name } [expected_type={ unwrap (field_type )} , input={ field_value } , input_type={ type (field_value )} ]"
8599 )
100+ # List is an exact match only if all of its elements are
86101 elif origin is list :
87102 if isinstance (field_value , list ):
88- return [
89- self .__extract_type (get_args (field_type )[0 ], value , field_name )
90- for value in field_value
91- ]
92- elif self .__is_optional (field_type ):
93- if field_value is None :
94- return None
95- else :
96- try :
97- return self .__extract_type (get_args (field_type )[0 ], field_value , field_name )
98- except CatalystwanModelInputException as e :
99- if not field_value :
100- return None
101- raise e
103+ values = []
104+ exact_match = True
105+ for value in field_value :
106+ extracted_value = self .__extract_type (
107+ get_args (field_type )[0 ], value , field_name
108+ )
109+ values .append (extracted_value .value )
110+ if not extracted_value .exact_match :
111+ exact_match = False
112+ return ExtractedValue (value = values , exact_match = exact_match )
102113 elif origin is Literal :
103114 for arg in get_args (field_type ):
104115 try :
105116 if type (arg )(field_value ) == arg :
106- return type (arg )(field_value )
117+ return ExtractedValue (
118+ value = type (arg )(field_value ), exact_match = type (arg ) is type (field_value )
119+ )
107120 except Exception :
108121 continue
109122 elif origin is Annotated :
110123 validator , caster = field_type .__metadata__
111124 if validator (field_value ):
112- return field_value
113- return caster (field_value )
114- # TODO: Currently, casting is done left-to-right. Searching deeper for a better match may be the way to go.
125+ return ExtractedValue (value = field_value , exact_match = True )
126+ return ExtractedValue (value = caster (field_value ), exact_match = False )
127+ # When parsing Unions, try to find the best match. Currently, it involves:
128+ # 1. Finding the exact match
129+ # 2. If not found, favors dataclasses - sorted by number of matched keys, then None values
130+ # 3. If no dataclasses are present, return the leftmost matched argument
115131 elif origin is Union :
132+ matches : List [ExtractedValue ] = []
116133 for arg in get_args (field_type ):
117134 try :
118- return self .__extract_type (arg , field_value , field_name )
135+ extracted_value = self .__extract_type (arg , field_value , field_name )
136+ # exact match, return
137+ if extracted_value .exact_match :
138+ return extracted_value
139+ else :
140+ matches .append (extracted_value )
119141 except Exception :
120142 continue
143+ # Only one element matched, return
144+ if len (matches ) == 1 :
145+ return matches [0 ]
146+ # Only non-exact matches left, sort and return first element
147+ elif len (matches ) > 1 :
148+ matches .sort (
149+ key = lambda x : (x .matched_keys is not None , x .matched_keys , x .value is None ),
150+ reverse = True ,
151+ )
152+ return matches [0 ]
121153 # Correct type not found, add exception
122154 raise CatalystwanModelInputException (
123155 f"Unable to match or cast input value for { field_name } [expected_type={ unwrap (field_type )} , input={ field_value } , input_type={ type (field_value )} ]"
@@ -130,7 +162,7 @@ def __transform_model_input(
130162 kwargs_copy = deepcopy (kwargs )
131163 new_args = []
132164 new_kwargs = {}
133- field_types = get_type_hints (cls )
165+ field_types = get_type_hints (cls , include_extras = True )
134166 for field in fields (cls ):
135167 if not field .init :
136168 continue
@@ -140,7 +172,9 @@ def __transform_model_input(
140172 field_value = args_copy .popleft ()
141173 try :
142174 new_args .append (
143- self .__extract_type (field_type , value_extractor (field_value ), field .name )
175+ self .__extract_type (
176+ field_type , value_extractor (field_value ), field .name
177+ ).value
144178 )
145179 except (
146180 CatalystwanModelInputException ,
@@ -164,7 +198,7 @@ def __transform_model_input(
164198 try :
165199 new_kwargs [field .name ] = self .__extract_type (
166200 field_type , value_extractor (field_value ), field .name
167- )
201+ ). value
168202 except (
169203 CatalystwanModelInputException ,
170204 CatalystwanModelValidationError ,
0 commit comments