11from __future__ import annotations
22
33from itertools import chain
4- from typing import Any , ClassVar , OrderedDict , cast
4+ from typing import Any , ClassVar , Mapping , OrderedDict , cast
55
66from attr import define , evolve
77
1616@define
1717class DiscriminatorDefinition :
1818 """Represents a discriminator that can optionally be specified for a union type.
19-
19+
2020 Normally, a UnionProperty has either zero or one of these. However, a nested union
2121 could have more than one, as we accumulate all the discriminators when we flatten
2222 out the nested schemas. For example:
@@ -36,8 +36,9 @@ class DiscriminatorDefinition:
3636 In this example there are four schemas and two discriminators. The deserializer
3737 logic will check for the mammalType property first, then birdType.
3838 """
39+
3940 property_name : str
40- value_to_model_map : dict [str , PropertyProtocol ]
41+ value_to_model_map : Mapping [str , PropertyProtocol ]
4142 # Every value in the map is really a ModelProperty, but this avoids circular imports
4243
4344
@@ -260,7 +261,7 @@ def _parse_discriminator(
260261 # mapping:
261262 # value_for_a: "#/components/schemas/ModelA"
262263 # value_for_b: ModelB # equivalent to "#/components/schemas/ModelB"
263- #
264+ #
264265 # For any type that isn't specified in the mapping (or if the whole mapping is omitted)
265266 # the default lookup value for each schema is the same as the schema name. So this--
266267 # mapping:
@@ -275,7 +276,7 @@ def _parse_discriminator(
275276 def _get_model_name (model : ModelProperty ) -> str | None :
276277 return get_reference_simple_name (model .ref_path ) if model .ref_path else None
277278
278- model_types_by_name : dict [str , PropertyProtocol ] = {}
279+ model_types_by_name : dict [str , ModelProperty ] = {}
279280 for model in subtypes :
280281 # Note, model here can never be a UnionProperty, because we've already done
281282 # flatten_union_properties() before this point.
@@ -290,7 +291,7 @@ def _get_model_name(model: ModelProperty) -> str | None:
290291 )
291292 model_types_by_name [name ] = model
292293
293- mapping : dict [str , PropertyProtocol ] = OrderedDict () # use ordered dict for test determinacy
294+ mapping : dict [str , ModelProperty ] = OrderedDict () # use ordered dict for test determinacy
294295 unspecified_models = list (model_types_by_name .values ())
295296 if data .mapping :
296297 for discriminator_value , model_ref in data .mapping .items ():
@@ -301,13 +302,13 @@ def _get_model_name(model: ModelProperty) -> str | None:
301302 name = get_reference_simple_name (ref_path )
302303 else :
303304 name = model_ref
304- model = model_types_by_name .get (name )
305- if not model :
305+ mapped_model = model_types_by_name .get (name )
306+ if not mapped_model :
306307 return PropertyError (
307308 detail = f'Discriminator mapping referred to "{ name } " which is not one of the schema variants' ,
308309 )
309- mapping [discriminator_value ] = model
310- unspecified_models .remove (model )
310+ mapping [discriminator_value ] = mapped_model
311+ unspecified_models .remove (mapped_model )
311312 for model in unspecified_models :
312313 if name := _get_model_name (model ):
313314 mapping [name ] = model
0 commit comments