2525
2626log = logging .getLogger (__name__ )
2727
28+
2829@attr .s
29- class StringFormat (object ):
30- format = attr .ib ()
30+ class Format (object ):
31+ unmarshal = attr .ib ()
3132 validate = attr .ib ()
3233
3334
@@ -41,10 +42,10 @@ class Schema(object):
4142 }
4243
4344 STRING_FORMAT_CALLABLE_GETTER = {
44- SchemaFormat .NONE : StringFormat (text_type , TypeValidator (text_type )),
45- SchemaFormat .DATE : StringFormat (format_date , TypeValidator (date , exclude = datetime )),
46- SchemaFormat .DATETIME : StringFormat (format_datetime , TypeValidator (datetime )),
47- SchemaFormat .BINARY : StringFormat (binary_type , TypeValidator (binary_type )),
45+ SchemaFormat .NONE : Format (text_type , TypeValidator (text_type )),
46+ SchemaFormat .DATE : Format (format_date , TypeValidator (date , exclude = datetime )),
47+ SchemaFormat .DATETIME : Format (format_datetime , TypeValidator (datetime )),
48+ SchemaFormat .BINARY : Format (binary_type , TypeValidator (binary_type )),
4849 }
4950
5051 TYPE_VALIDATOR_CALLABLE_GETTER = {
@@ -99,7 +100,6 @@ def __init__(
99100
100101 self ._all_required_properties_cache = None
101102 self ._all_optional_properties_cache = None
102- self .custom_formatters = None
103103
104104 def __getitem__ (self , name ):
105105 return self .properties [name ]
@@ -143,25 +143,27 @@ def get_all_required_properties_names(self):
143143
144144 return set (required )
145145
146- def get_cast_mapping (self ):
146+ def get_cast_mapping (self , custom_formatters = None ):
147+ pass_defaults = lambda f : functools .partial (
148+ f , custom_formatters = custom_formatters )
147149 mapping = self .DEFAULT_CAST_CALLABLE_GETTER .copy ()
148150 mapping .update ({
149- SchemaType .STRING : self ._unmarshal_string ,
150- SchemaType .ANY : self ._unmarshal_any ,
151- SchemaType .ARRAY : self ._unmarshal_collection ,
152- SchemaType .OBJECT : self ._unmarshal_object ,
151+ SchemaType .STRING : pass_defaults ( self ._unmarshal_string ) ,
152+ SchemaType .ANY : pass_defaults ( self ._unmarshal_any ) ,
153+ SchemaType .ARRAY : pass_defaults ( self ._unmarshal_collection ) ,
154+ SchemaType .OBJECT : pass_defaults ( self ._unmarshal_object ) ,
153155 })
154156
155157 return defaultdict (lambda : lambda x : x , mapping )
156158
157- def cast (self , value ):
159+ def cast (self , value , custom_formatters = None ):
158160 """Cast value to schema type"""
159161 if value is None :
160162 if not self .nullable :
161163 raise InvalidSchemaValue ("Null value for non-nullable schema" )
162164 return self .default
163165
164- cast_mapping = self .get_cast_mapping ()
166+ cast_mapping = self .get_cast_mapping (custom_formatters = custom_formatters )
165167
166168 if self .type is not SchemaType .STRING and value == '' :
167169 return None
@@ -179,9 +181,7 @@ def unmarshal(self, value, custom_formatters=None):
179181 if self .deprecated :
180182 warnings .warn ("The schema is deprecated" , DeprecationWarning )
181183
182- self .custom_formatters = custom_formatters
183-
184- casted = self .cast (value )
184+ casted = self .cast (value , custom_formatters = custom_formatters )
185185
186186 if casted is None and not self .required :
187187 return None
@@ -194,13 +194,13 @@ def unmarshal(self, value, custom_formatters=None):
194194
195195 return casted
196196
197- def _unmarshal_string (self , value ):
197+ def _unmarshal_string (self , value , custom_formatters = None ):
198198 try :
199199 schema_format = SchemaFormat (self .format )
200200 except ValueError :
201201 msg = "Unsupported {0} format unmarshalling" .format (self .format )
202- if self . custom_formatters is not None :
203- formatstring = self . custom_formatters .get (self .format )
202+ if custom_formatters is not None :
203+ formatstring = custom_formatters .get (self .format )
204204 if formatstring is None :
205205 raise OpenAPISchemaError (msg )
206206 else :
@@ -209,14 +209,14 @@ def _unmarshal_string(self, value):
209209 formatstring = self .STRING_FORMAT_CALLABLE_GETTER [schema_format ]
210210
211211 try :
212- return formatstring .format (value )
212+ return formatstring .unmarshal (value )
213213 except ValueError :
214214 raise InvalidSchemaValue (
215215 "Failed to format value of {0} to {1}" .format (
216216 value , self .format )
217217 )
218218
219- def _unmarshal_any (self , value ):
219+ def _unmarshal_any (self , value , custom_formatters = None ):
220220 types_resolve_order = [
221221 SchemaType .OBJECT , SchemaType .ARRAY , SchemaType .BOOLEAN ,
222222 SchemaType .INTEGER , SchemaType .NUMBER , SchemaType .STRING ,
@@ -233,14 +233,16 @@ def _unmarshal_any(self, value):
233233 raise NoValidSchema (
234234 "No valid schema found for value {0}" .format (value ))
235235
236- def _unmarshal_collection (self , value ):
236+ def _unmarshal_collection (self , value , custom_formatters = None ):
237237 if self .items is None :
238238 raise UndefinedItemsSchema ("Undefined items' schema" )
239239
240- f = functools .partial (self .items .unmarshal , custom_formatters = self .custom_formatters )
240+ f = functools .partial (self .items .unmarshal ,
241+ custom_formatters = custom_formatters )
241242 return list (map (f , value ))
242243
243- def _unmarshal_object (self , value , model_factory = None ):
244+ def _unmarshal_object (self , value , model_factory = None ,
245+ custom_formatters = None ):
244246 if not isinstance (value , (dict , )):
245247 raise InvalidSchemaValue (
246248 "Value of {0} not a dict" .format (value ))
@@ -252,7 +254,7 @@ def _unmarshal_object(self, value, model_factory=None):
252254 for one_of_schema in self .one_of :
253255 try :
254256 found_props = self ._unmarshal_properties (
255- value , one_of_schema )
257+ value , one_of_schema , custom_formatters = custom_formatters )
256258 except OpenAPISchemaError :
257259 pass
258260 else :
@@ -267,11 +269,13 @@ def _unmarshal_object(self, value, model_factory=None):
267269 "Exactly one valid schema should be valid, None found." )
268270
269271 else :
270- properties = self ._unmarshal_properties (value )
272+ properties = self ._unmarshal_properties (
273+ value , custom_formatters = custom_formatters )
271274
272275 return model_factory .create (properties , name = self .model )
273276
274- def _unmarshal_properties (self , value , one_of_schema = None ):
277+ def _unmarshal_properties (self , value , one_of_schema = None ,
278+ custom_formatters = None ):
275279 all_props = self .get_all_properties ()
276280 all_props_names = self .get_all_properties_names ()
277281 all_req_props_names = self .get_all_required_properties_names ()
@@ -293,7 +297,7 @@ def _unmarshal_properties(self, value, one_of_schema=None):
293297 for prop_name in extra_props :
294298 prop_value = value [prop_name ]
295299 properties [prop_name ] = self .additional_properties .unmarshal (
296- prop_value , self . custom_formatters )
300+ prop_value , custom_formatters = custom_formatters )
297301
298302 for prop_name , prop in iteritems (all_props ):
299303 try :
@@ -305,9 +309,11 @@ def _unmarshal_properties(self, value, one_of_schema=None):
305309 if not prop .nullable and not prop .default :
306310 continue
307311 prop_value = prop .default
308- properties [prop_name ] = prop .unmarshal (prop_value , self .custom_formatters )
312+ properties [prop_name ] = prop .unmarshal (
313+ prop_value , custom_formatters = custom_formatters )
309314
310- self ._validate_properties (properties , one_of_schema = one_of_schema )
315+ self ._validate_properties (properties , one_of_schema = one_of_schema ,
316+ custom_formatters = custom_formatters )
311317
312318 return properties
313319
@@ -320,9 +326,12 @@ def get_validator_mapping(self):
320326 SchemaType .NUMBER : self ._validate_number ,
321327 }
322328
323- return defaultdict (lambda : lambda x : x , mapping )
329+ def default (x , ** kw ):
330+ return x
331+
332+ return defaultdict (lambda : default , mapping )
324333
325- def validate (self , value ):
334+ def validate (self , value , custom_formatters = None ):
326335 if value is None :
327336 if not self .nullable :
328337 raise InvalidSchemaValue ("Null value for non-nullable schema" )
@@ -340,11 +349,11 @@ def validate(self, value):
340349 # structure validation
341350 validator_mapping = self .get_validator_mapping ()
342351 validator_callable = validator_mapping [self .type ]
343- validator_callable (value )
352+ validator_callable (value , custom_formatters = custom_formatters )
344353
345354 return value
346355
347- def _validate_collection (self , value ):
356+ def _validate_collection (self , value , custom_formatters = None ):
348357 if self .items is None :
349358 raise OpenAPISchemaError ("Schema for collection not defined" )
350359
@@ -375,7 +384,9 @@ def _validate_collection(self, value):
375384 if self .unique_items and len (set (value )) != len (value ):
376385 raise InvalidSchemaValue ("Value may not contain duplicate items" )
377386
378- return list (map (self .items .validate , value ))
387+ f = functools .partial (self .items .validate ,
388+ custom_formatters = custom_formatters )
389+ return list (map (f , value ))
379390
380391 def _validate_number (self , value ):
381392 if self .minimum is not None :
@@ -408,13 +419,13 @@ def _validate_number(self, value):
408419 value , self .multiple_of )
409420 )
410421
411- def _validate_string (self , value ):
422+ def _validate_string (self , value , custom_formatters = None ):
412423 try :
413424 schema_format = SchemaFormat (self .format )
414425 except ValueError :
415426 msg = "Unsupported {0} format validation" .format (self .format )
416- if self . custom_formatters is not None :
417- formatstring = self . custom_formatters .get (self .format )
427+ if custom_formatters is not None :
428+ formatstring = custom_formatters .get (self .format )
418429 if formatstring is None :
419430 raise OpenAPISchemaError (msg )
420431 else :
@@ -459,14 +470,16 @@ def _validate_string(self, value):
459470
460471 return True
461472
462- def _validate_object (self , value ):
473+ def _validate_object (self , value , custom_formatters = None ):
463474 properties = value .__dict__
464475
465476 if self .one_of :
466477 valid_one_of_schema = None
467478 for one_of_schema in self .one_of :
468479 try :
469- self ._validate_properties (properties , one_of_schema )
480+ self ._validate_properties (
481+ properties , one_of_schema ,
482+ custom_formatters = custom_formatters )
470483 except OpenAPISchemaError :
471484 pass
472485 else :
@@ -481,7 +494,8 @@ def _validate_object(self, value):
481494 "Exactly one valid schema should be valid, None found." )
482495
483496 else :
484- self ._validate_properties (properties )
497+ self ._validate_properties (properties ,
498+ custom_formatters = custom_formatters )
485499
486500 if self .min_properties is not None :
487501 if self .min_properties < 0 :
@@ -512,7 +526,8 @@ def _validate_object(self, value):
512526
513527 return True
514528
515- def _validate_properties (self , value , one_of_schema = None ):
529+ def _validate_properties (self , value , one_of_schema = None ,
530+ custom_formatters = None ):
516531 all_props = self .get_all_properties ()
517532 all_props_names = self .get_all_properties_names ()
518533 all_req_props_names = self .get_all_required_properties_names ()
@@ -533,7 +548,7 @@ def _validate_properties(self, value, one_of_schema=None):
533548 for prop_name in extra_props :
534549 prop_value = value [prop_name ]
535550 self .additional_properties .validate (
536- prop_value )
551+ prop_value , custom_formatters = custom_formatters )
537552
538553 for prop_name , prop in iteritems (all_props ):
539554 try :
@@ -545,6 +560,6 @@ def _validate_properties(self, value, one_of_schema=None):
545560 if not prop .nullable and not prop .default :
546561 continue
547562 prop_value = prop .default
548- prop .validate (prop_value )
563+ prop .validate (prop_value , custom_formatters = custom_formatters )
549564
550565 return True
0 commit comments