11from __future__ import annotations
22
33from itertools import chain
4- from typing import Any , ClassVar , cast
4+ from typing import Any , ClassVar , OrderedDict , cast
55
66from attr import define , evolve
77
1515
1616@define
1717class DiscriminatorDefinition :
18+ """Represents a discriminator that can optionally be specified for a union type.
19+
20+ Normally, a UnionProperty has either zero or one of these. However, a nested union
21+ could have more than one, as we accumulate all the discriminators when we flatten
22+ out the nested schemas. For example:
23+
24+ anyOf:
25+ - anyOf:
26+ - $ref: "#/components/schemas/Cat"
27+ - $ref: "#/components/schemas/Dog"
28+ discriminator:
29+ propertyName: mammalType
30+ - anyOf:
31+ - $ref: "#/components/schemas/Condor"
32+ - $ref: "#/components/schemas/Chicken"
33+ discriminator:
34+ propertyName: birdType
35+
36+ In this example there are four schemas and two discriminators. The deserializer
37+ logic will check for the mammalType property first, then birdType.
38+ """
1839 property_name : str
1940 value_to_model_map : dict [str , PropertyProtocol ]
2041 # Every value in the map is really a ModelProperty, but this avoids circular imports
@@ -75,7 +96,7 @@ def build(
7596 return PropertyError (detail = f"Invalid property in union { name } " , data = sub_prop_data ), schemas
7697 sub_properties .append (sub_prop )
7798
78- sub_properties , discriminators_list = _flatten_union_properties (sub_properties )
99+ sub_properties , discriminators_from_nested_unions = _flatten_union_properties (sub_properties )
79100
80101 prop = UnionProperty (
81102 name = name ,
@@ -92,15 +113,14 @@ def build(
92113 return default_or_error , schemas
93114 prop = evolve (prop , default = default_or_error )
94115
116+ all_discriminators = discriminators_from_nested_unions
95117 if data .discriminator :
96118 discriminator_or_error = _parse_discriminator (data .discriminator , sub_properties , schemas )
97119 if isinstance (discriminator_or_error , PropertyError ):
98120 return discriminator_or_error , schemas
99- discriminators_list = [discriminator_or_error , * discriminators_list ]
100- if discriminators_list :
101- if error := _validate_discriminators (discriminators_list ):
102- return error , schemas
103- prop = evolve (prop , discriminators = discriminators_list )
121+ all_discriminators = [discriminator_or_error , * all_discriminators ]
122+ if all_discriminators :
123+ prop = evolve (prop , discriminators = all_discriminators )
104124
105125 return prop , schemas
106126
@@ -227,15 +247,33 @@ def _parse_discriminator(
227247
228248 # See: https://spec.openapis.org/oas/v3.1.0.html#discriminator-object
229249
230- def _find_top_level_model (matching_model : ModelProperty ) -> ModelProperty | None :
231- # This is needed because, when we built the union list, $refs were changed into a copy of
232- # the type they referred to, without preserving the original name. We need to know that
233- # every type in the discriminator is a $ref to a top-level type and we need its name.
234- for prop in schemas .classes_by_reference .values ():
235- if isinstance (prop , ModelProperty ):
236- if prop .class_info == matching_model .class_info :
237- return prop
238- return None
250+ # Conditions that must be true when there is a discriminator:
251+ # 1. Every type in the anyOf/oneOf list must be a $ref to a named schema, such as
252+ # #/components/schemas/X, rather than an inline schema. This is important because
253+ # we may need to use the schema's simple name (X).
254+ # 2. There must be a propertyName, representing a property that exists in every
255+ # schema in that list (although we can't currently enforce the latter condition,
256+ # because those properties haven't been parsed yet at this point.)
257+ #
258+ # There *may* also be a mapping of lookup values (the possible values of the property)
259+ # to schemas. Schemas can be referenced either by a full path or a name:
260+ # mapping:
261+ # value_for_a: "#/components/schemas/ModelA"
262+ # value_for_b: ModelB # equivalent to "#/components/schemas/ModelB"
263+ #
264+ # For any type that isn't specified in the mapping (or if the whole mapping is omitted)
265+ # the default lookup value for each schema is the same as the schema name. So this--
266+ # mapping:
267+ # value_for_a: "#/components/schemas/ModelA"
268+ # --is exactly equivalent to this:
269+ # discriminator:
270+ # propertyName: modelType
271+ # mapping:
272+ # value_for_a: "#/components/schemas/ModelA"
273+ # ModelB: "#/components/schemas/ModelB"
274+
275+ def _get_model_name (model : ModelProperty ) -> str | None :
276+ return get_reference_simple_name (model .ref_path ) if model .ref_path else None
239277
240278 model_types_by_name : dict [str , PropertyProtocol ] = {}
241279 for model in subtypes :
@@ -245,59 +283,32 @@ def _find_top_level_model(matching_model: ModelProperty) -> ModelProperty | None
245283 return PropertyError (
246284 detail = "All schema variants must be objects when using a discriminator" ,
247285 )
248- top_level_model = _find_top_level_model (model )
249- if not top_level_model :
286+ name = _get_model_name (model )
287+ if not name :
250288 return PropertyError (
251289 detail = "Inline schema declarations are not allowed when using a discriminator" ,
252290 )
253- name = top_level_model .name
254- if name .startswith ("/components/schemas/" ):
255- name = get_reference_simple_name (name )
256- model_types_by_name [name ] = top_level_model
257-
258- # The discriminator can specify an explicit mapping of values to types, but it doesn't
259- # have to; the default behavior is that the value for each type is simply its name.
260- mapping : dict [str , PropertyProtocol ] = model_types_by_name .copy ()
291+ model_types_by_name [name ] = model
292+
293+ mapping : dict [str , PropertyProtocol ] = OrderedDict () # use ordered dict for test determinacy
294+ unspecified_models = list (model_types_by_name .values ())
261295 if data .mapping :
262296 for discriminator_value , model_ref in data .mapping .items ():
263- ref_path = parse_reference_path (
264- model_ref if model_ref .startswith ("#/components/schemas/" ) else f"#/components/schemas/{ model_ref } "
265- )
266- if isinstance (ref_path , ParseError ) or ref_path not in schemas .classes_by_reference :
267- return PropertyError (detail = f'Invalid reference "{ model_ref } " in discriminator mapping' )
268- name = get_reference_simple_name (ref_path )
269- if not (lookup_model := model_types_by_name .get (name )):
297+ if "/" in model_ref :
298+ ref_path = parse_reference_path (model_ref )
299+ if isinstance (ref_path , ParseError ) or ref_path not in schemas .classes_by_reference :
300+ return PropertyError (detail = f'Invalid reference "{ model_ref } " in discriminator mapping' )
301+ name = get_reference_simple_name (ref_path )
302+ else :
303+ name = model_ref
304+ model = model_types_by_name .get (name )
305+ if not model :
270306 return PropertyError (
271- detail = f'Discriminator mapping referred to "{ model_ref } " which is not one of the schema variants' ,
307+ detail = f'Discriminator mapping referred to "{ name } " which is not one of the schema variants' ,
272308 )
273- for original_value in (name for name , m in model_types_by_name .items () if m == lookup_model ):
274- mapping .pop (original_value )
275- mapping [discriminator_value ] = lookup_model
276- else :
277- mapping = model_types_by_name
278-
309+ mapping [discriminator_value ] = model
310+ unspecified_models .remove (model )
311+ for model in unspecified_models :
312+ if name := _get_model_name (model ):
313+ mapping [name ] = model
279314 return DiscriminatorDefinition (property_name = data .propertyName , value_to_model_map = mapping )
280-
281-
282- def _validate_discriminators (
283- discriminators : list [DiscriminatorDefinition ],
284- ) -> PropertyError | None :
285- from .model_property import ModelProperty
286-
287- prop_names_values_classes = [
288- (discriminator .property_name , key , cast (ModelProperty , model ).class_info .name )
289- for discriminator in discriminators
290- for key , model in discriminator .value_to_model_map .items ()
291- ]
292- for p , v in {(p , v ) for p , v , _ in prop_names_values_classes }:
293- if len ({c for p1 , v1 , c in prop_names_values_classes if (p1 , v1 ) == (p , v )}) > 1 :
294- return PropertyError (f'Discriminator property "{ p } " had more than one schema for value "{ v } "' )
295- return None
296-
297- # TODO: We should also validate that property_name refers to a property that 1. exists,
298- # 2. is required, 3. is a string (in all of these models). However, currently we can't
299- # do that because, at the time this function is called, the ModelProperties within the
300- # union haven't yet been post-processed and so we don't have full information about
301- # their properties. To fix this, we may need to generalize the post-processing phase so
302- # that any Property type, not just ModelProperty, can say it needs post-processing; then
303- # we can defer _validate_discriminators till that phase.
0 commit comments