@@ -8,8 +8,8 @@ use std::borrow::Cow;
88use crate :: build_tools:: py_schema_err;
99use crate :: common:: union:: { Discriminator , SMALL_UNION_THRESHOLD } ;
1010use crate :: definitions:: DefinitionsBuilder ;
11+ use crate :: serializers:: PydanticSerializationUnexpectedValue ;
1112use crate :: tools:: { truncate_safe_repr, SchemaDict } ;
12- use crate :: PydanticSerializationUnexpectedValue ;
1313
1414use super :: {
1515 infer_json_key, infer_serialize, infer_to_python, BuildSerializer , CombinedSerializer , Extra , SerCheck ,
@@ -70,22 +70,23 @@ impl UnionSerializer {
7070
7171impl_py_gc_traverse ! ( UnionSerializer { choices } ) ;
7272
73- fn to_python (
74- value : & Bound < ' _ , PyAny > ,
75- include : Option < & Bound < ' _ , PyAny > > ,
76- exclude : Option < & Bound < ' _ , PyAny > > ,
73+ fn union_serialize < S > (
74+ // if this returns `Ok(Some(v))`, we picked a union variant to serialize,
75+ // Or `Ok(None)` if we couldn't find a suitable variant to serialize
76+ // Finally, `Err(err)` if we encountered errors while trying to serialize
77+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
7778 extra : & Extra ,
7879 choices : & [ CombinedSerializer ] ,
7980 retry_with_lax_check : bool ,
80- ) -> PyResult < PyObject > {
81+ ) -> PyResult < Option < S > > {
8182 // try the serializers in left to right order with error_on fallback=true
8283 let mut new_extra = extra. clone ( ) ;
8384 new_extra. check = SerCheck :: Strict ;
8485 let mut errors: SmallVec < [ PyErr ; SMALL_UNION_THRESHOLD ] > = SmallVec :: new ( ) ;
8586
8687 for comb_serializer in choices {
87- match comb_serializer . to_python ( value , include , exclude , & new_extra) {
88- Ok ( v) => return Ok ( v ) ,
88+ match selector ( comb_serializer , & new_extra) {
89+ Ok ( v) => return Ok ( Some ( v ) ) ,
8990 Err ( err) => errors. push ( err) ,
9091 }
9192 }
@@ -94,8 +95,8 @@ fn to_python(
9495 if extra. check != SerCheck :: Strict && retry_with_lax_check {
9596 new_extra. check = SerCheck :: Lax ;
9697 for comb_serializer in choices {
97- if let Ok ( v) = comb_serializer . to_python ( value , include , exclude , & new_extra) {
98- return Ok ( v ) ;
98+ if let Ok ( v) = selector ( comb_serializer , & new_extra) {
99+ return Ok ( Some ( v ) ) ;
99100 }
100101 }
101102 }
@@ -113,94 +114,45 @@ fn to_python(
113114 return Err ( PydanticSerializationUnexpectedValue :: new_err ( Some ( message) ) ) ;
114115 }
115116
116- infer_to_python ( value , include , exclude , extra )
117+ Ok ( None )
117118}
118119
119- fn json_key < ' a > (
120- key : & ' a Bound < ' _ , PyAny > ,
120+ fn tagged_union_serialize < S > (
121+ discriminator_value : Option < Py < PyAny > > ,
122+ lookup : & HashMap < String , usize > ,
123+ // if this returns `Ok(v)`, we picked a union variant to serialize, where
124+ // `S` is intermediate state which can be passed on to the finalizer
125+ mut selector : impl FnMut ( & CombinedSerializer , & Extra ) -> PyResult < S > ,
121126 extra : & Extra ,
122127 choices : & [ CombinedSerializer ] ,
123128 retry_with_lax_check : bool ,
124- ) -> PyResult < Cow < ' a , str > > {
129+ ) -> PyResult < Option < S > > {
125130 let mut new_extra = extra. clone ( ) ;
126131 new_extra. check = SerCheck :: Strict ;
127- let mut errors: SmallVec < [ PyErr ; SMALL_UNION_THRESHOLD ] > = SmallVec :: new ( ) ;
128-
129- for comb_serializer in choices {
130- match comb_serializer. json_key ( key, & new_extra) {
131- Ok ( v) => return Ok ( v) ,
132- Err ( err) => errors. push ( err) ,
133- }
134- }
135132
136- // If extra.check is SerCheck::Strict, we're in a nested union
137- if extra. check != SerCheck :: Strict && retry_with_lax_check {
138- new_extra. check = SerCheck :: Lax ;
139- for comb_serializer in choices {
140- if let Ok ( v) = comb_serializer. json_key ( key, & new_extra) {
141- return Ok ( v) ;
142- }
143- }
144- }
145-
146- // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
147- if extra. check == SerCheck :: None {
148- for err in & errors {
149- extra. warnings . custom_warning ( err. to_string ( ) ) ;
150- }
151- }
152- // Otherwise, if we've encountered errors, return them to the parent union, which should take
153- // care of the formatting for us
154- else if !errors. is_empty ( ) {
155- let message = errors. iter ( ) . map ( ToString :: to_string) . collect :: < Vec < _ > > ( ) . join ( "\n " ) ;
156- return Err ( PydanticSerializationUnexpectedValue :: new_err ( Some ( message) ) ) ;
157- }
158- infer_json_key ( key, extra)
159- }
160-
161- #[ allow( clippy:: too_many_arguments) ]
162- fn serde_serialize < S : serde:: ser:: Serializer > (
163- value : & Bound < ' _ , PyAny > ,
164- serializer : S ,
165- include : Option < & Bound < ' _ , PyAny > > ,
166- exclude : Option < & Bound < ' _ , PyAny > > ,
167- extra : & Extra ,
168- choices : & [ CombinedSerializer ] ,
169- retry_with_lax_check : bool ,
170- ) -> Result < S :: Ok , S :: Error > {
171- let py = value. py ( ) ;
172- let mut new_extra = extra. clone ( ) ;
173- new_extra. check = SerCheck :: Strict ;
174- let mut errors: SmallVec < [ PyErr ; SMALL_UNION_THRESHOLD ] > = SmallVec :: new ( ) ;
175-
176- for comb_serializer in choices {
177- match comb_serializer. to_python ( value, include, exclude, & new_extra) {
178- Ok ( v) => return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ,
179- Err ( err) => errors. push ( err) ,
180- }
181- }
182-
183- // If extra.check is SerCheck::Strict, we're in a nested union
184- if extra. check != SerCheck :: Strict && retry_with_lax_check {
185- new_extra. check = SerCheck :: Lax ;
186- for comb_serializer in choices {
187- if let Ok ( v) = comb_serializer. to_python ( value, include, exclude, & new_extra) {
188- return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ;
133+ if let Some ( tag) = discriminator_value {
134+ let tag_str = tag. to_string ( ) ;
135+ if let Some ( & serializer_index) = lookup. get ( & tag_str) {
136+ let selected_serializer = & choices[ serializer_index] ;
137+
138+ match selector ( selected_serializer, & new_extra) {
139+ Ok ( v) => return Ok ( Some ( v) ) ,
140+ Err ( _) => {
141+ if retry_with_lax_check {
142+ new_extra. check = SerCheck :: Lax ;
143+ if let Ok ( v) = selector ( selected_serializer, & new_extra) {
144+ return Ok ( Some ( v) ) ;
145+ }
146+ }
147+ }
189148 }
190149 }
191150 }
192151
193- // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise the warnings
194- if extra. check == SerCheck :: None {
195- for err in & errors {
196- extra. warnings . custom_warning ( err. to_string ( ) ) ;
197- }
198- } else {
199- // NOTE: if this function becomes recursive at some point, an `Err(_)` containing the errors
200- // will have to be returned here
201- }
202-
203- infer_serialize ( value, serializer, include, exclude, extra)
152+ // if we haven't returned at this point, we should fallback to the union serializer
153+ // which preserves the historical expectation that we do our best with serialization
154+ // even if that means we resort to inference
155+ union_serialize ( selector, extra, choices, retry_with_lax_check)
204156}
205157
206158impl TypeSerializer for UnionSerializer {
@@ -211,18 +163,23 @@ impl TypeSerializer for UnionSerializer {
211163 exclude : Option < & Bound < ' _ , PyAny > > ,
212164 extra : & Extra ,
213165 ) -> PyResult < PyObject > {
214- to_python (
215- value,
216- include,
217- exclude,
166+ union_serialize (
167+ |comb_serializer, new_extra| comb_serializer. to_python ( value, include, exclude, new_extra) ,
218168 extra,
219169 & self . choices ,
220170 self . retry_with_lax_check ( ) ,
221- )
171+ ) ?
172+ . map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok )
222173 }
223174
224175 fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
225- json_key ( key, extra, & self . choices , self . retry_with_lax_check ( ) )
176+ union_serialize (
177+ |comb_serializer, new_extra| comb_serializer. json_key ( key, new_extra) ,
178+ extra,
179+ & self . choices ,
180+ self . retry_with_lax_check ( ) ,
181+ ) ?
182+ . map_or_else ( || infer_json_key ( key, extra) , Ok )
226183 }
227184
228185 fn serde_serialize < S : serde:: ser:: Serializer > (
@@ -233,15 +190,16 @@ impl TypeSerializer for UnionSerializer {
233190 exclude : Option < & Bound < ' _ , PyAny > > ,
234191 extra : & Extra ,
235192 ) -> Result < S :: Ok , S :: Error > {
236- serde_serialize (
237- value,
238- serializer,
239- include,
240- exclude,
193+ match union_serialize (
194+ |comb_serializer, new_extra| comb_serializer. to_python ( value, include, exclude, new_extra) ,
241195 extra,
242196 & self . choices ,
243197 self . retry_with_lax_check ( ) ,
244- )
198+ ) {
199+ Ok ( Some ( v) ) => return infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
200+ Ok ( None ) => infer_serialize ( value, serializer, include, exclude, extra) ,
201+ Err ( err) => Err ( serde:: ser:: Error :: custom ( err. to_string ( ) ) ) ,
202+ }
245203 }
246204
247205 fn get_name ( & self ) -> & str {
@@ -309,62 +267,29 @@ impl TypeSerializer for TaggedUnionSerializer {
309267 exclude : Option < & Bound < ' _ , PyAny > > ,
310268 extra : & Extra ,
311269 ) -> PyResult < PyObject > {
312- let mut new_extra = extra. clone ( ) ;
313- new_extra. check = SerCheck :: Strict ;
314-
315- if let Some ( tag) = self . get_discriminator_value ( value, extra) {
316- let tag_str = tag. to_string ( ) ;
317- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
318- let serializer = & self . choices [ serializer_index] ;
319-
320- match serializer. to_python ( value, include, exclude, & new_extra) {
321- Ok ( v) => return Ok ( v) ,
322- Err ( _) => {
323- if self . retry_with_lax_check ( ) {
324- new_extra. check = SerCheck :: Lax ;
325- if let Ok ( v) = serializer. to_python ( value, include, exclude, & new_extra) {
326- return Ok ( v) ;
327- }
328- }
329- }
330- }
331- }
332- }
333-
334- to_python (
335- value,
336- include,
337- exclude,
270+ tagged_union_serialize (
271+ self . get_discriminator_value ( value, extra) ,
272+ & self . lookup ,
273+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
274+ comb_serializer. to_python ( value, include, exclude, new_extra)
275+ } ,
338276 extra,
339277 & self . choices ,
340278 self . retry_with_lax_check ( ) ,
341- )
279+ ) ?
280+ . map_or_else ( || infer_to_python ( value, include, exclude, extra) , Ok )
342281 }
343282
344283 fn json_key < ' a > ( & self , key : & ' a Bound < ' _ , PyAny > , extra : & Extra ) -> PyResult < Cow < ' a , str > > {
345- let mut new_extra = extra. clone ( ) ;
346- new_extra. check = SerCheck :: Strict ;
347-
348- if let Some ( tag) = self . get_discriminator_value ( key, extra) {
349- let tag_str = tag. to_string ( ) ;
350- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
351- let serializer = & self . choices [ serializer_index] ;
352-
353- match serializer. json_key ( key, & new_extra) {
354- Ok ( v) => return Ok ( v) ,
355- Err ( _) => {
356- if self . retry_with_lax_check ( ) {
357- new_extra. check = SerCheck :: Lax ;
358- if let Ok ( v) = serializer. json_key ( key, & new_extra) {
359- return Ok ( v) ;
360- }
361- }
362- }
363- }
364- }
365- }
366-
367- json_key ( key, extra, & self . choices , self . retry_with_lax_check ( ) )
284+ tagged_union_serialize (
285+ self . get_discriminator_value ( key, extra) ,
286+ & self . lookup ,
287+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | comb_serializer. json_key ( key, new_extra) ,
288+ extra,
289+ & self . choices ,
290+ self . retry_with_lax_check ( ) ,
291+ ) ?
292+ . map_or_else ( || infer_json_key ( key, extra) , Ok )
368293 }
369294
370295 fn serde_serialize < S : serde:: ser:: Serializer > (
@@ -375,38 +300,20 @@ impl TypeSerializer for TaggedUnionSerializer {
375300 exclude : Option < & Bound < ' _ , PyAny > > ,
376301 extra : & Extra ,
377302 ) -> Result < S :: Ok , S :: Error > {
378- let py = value. py ( ) ;
379- let mut new_extra = extra. clone ( ) ;
380- new_extra. check = SerCheck :: Strict ;
381-
382- if let Some ( tag) = self . get_discriminator_value ( value, extra) {
383- let tag_str = tag. to_string ( ) ;
384- if let Some ( & serializer_index) = self . lookup . get ( & tag_str) {
385- let selected_serializer = & self . choices [ serializer_index] ;
386-
387- match selected_serializer. to_python ( value, include, exclude, & new_extra) {
388- Ok ( v) => return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ,
389- Err ( _) => {
390- if self . retry_with_lax_check ( ) {
391- new_extra. check = SerCheck :: Lax ;
392- if let Ok ( v) = selected_serializer. to_python ( value, include, exclude, & new_extra) {
393- return infer_serialize ( v. bind ( py) , serializer, None , None , extra) ;
394- }
395- }
396- }
397- }
398- }
399- }
400-
401- serde_serialize (
402- value,
403- serializer,
404- include,
405- exclude,
303+ match tagged_union_serialize (
304+ None ,
305+ & self . lookup ,
306+ |comb_serializer : & CombinedSerializer , new_extra : & Extra | {
307+ comb_serializer. to_python ( value, include, exclude, new_extra)
308+ } ,
406309 extra,
407310 & self . choices ,
408311 self . retry_with_lax_check ( ) ,
409- )
312+ ) {
313+ Ok ( Some ( v) ) => return infer_serialize ( v. bind ( value. py ( ) ) , serializer, None , None , extra) ,
314+ Ok ( None ) => infer_serialize ( value, serializer, include, exclude, extra) ,
315+ Err ( err) => Err ( serde:: ser:: Error :: custom ( err. to_string ( ) ) ) ,
316+ }
410317 }
411318
412319 fn get_name ( & self ) -> & str {
0 commit comments