@@ -268,6 +268,34 @@ def test_str(self):
268268 self .assertEqual (json .dumps (patch_obj ), patch .to_string ())
269269
270270
271+ def custom_types_dumps (obj ):
272+ def default (obj ):
273+ if isinstance (obj , decimal .Decimal ):
274+ return {'__decimal__' : str (obj )}
275+ raise TypeError ('Unknown type' )
276+
277+ return json .dumps (obj , default = default )
278+
279+
280+ def custom_types_loads (obj ):
281+ def as_decimal (dct ):
282+ if '__decimal__' in dct :
283+ return decimal .Decimal (dct ['__decimal__' ])
284+ return dct
285+
286+ return json .loads (obj , object_hook = as_decimal )
287+
288+
289+ class CustomTypesJsonPatch (jsonpatch .JsonPatch ):
290+ @staticmethod
291+ def json_dumper (obj ):
292+ return custom_types_dumps (obj )
293+
294+ @staticmethod
295+ def json_loader (obj ):
296+ return custom_types_loads (obj )
297+
298+
271299class MakePatchTestCase (unittest .TestCase ):
272300
273301 def test_apply_patch_to_copy (self ):
@@ -446,18 +474,33 @@ def test_issue103(self):
446474 self .assertEqual (res , dst )
447475 self .assertIsInstance (res ['A' ], float )
448476
449- def test_custom_types (self ):
450- def default (obj ):
451- if isinstance (obj , decimal .Decimal ):
452- return str (obj )
453- raise TypeError ('Unknown type' )
477+ def test_custom_types_diff (self ):
478+ old = {'value' : decimal .Decimal ('1.0' )}
479+ new = {'value' : decimal .Decimal ('1.00' )}
480+ generated_patch = jsonpatch .JsonPatch .from_diff (
481+ old , new , dumps = custom_types_dumps )
482+ str_patch = generated_patch .to_string (dumps = custom_types_dumps )
483+ loaded_patch = jsonpatch .JsonPatch .from_string (
484+ str_patch , loads = custom_types_loads )
485+ self .assertEqual (generated_patch , loaded_patch )
486+ new_from_patch = jsonpatch .apply_patch (old , generated_patch )
487+ self .assertEqual (new , new_from_patch )
454488
455- def dumps (obj ):
456- return json .dumps (obj , default = default )
489+ def test_custom_types_subclass (self ):
490+ old = {'value' : decimal .Decimal ('1.0' )}
491+ new = {'value' : decimal .Decimal ('1.00' )}
492+ generated_patch = CustomTypesJsonPatch .from_diff (old , new )
493+ str_patch = generated_patch .to_string ()
494+ loaded_patch = CustomTypesJsonPatch .from_string (str_patch )
495+ self .assertEqual (generated_patch , loaded_patch )
496+ new_from_patch = jsonpatch .apply_patch (old , loaded_patch )
497+ self .assertEqual (new , new_from_patch )
457498
499+ def test_custom_types_subclass_load (self ):
458500 old = {'value' : decimal .Decimal ('1.0' )}
459501 new = {'value' : decimal .Decimal ('1.00' )}
460- patch = jsonpatch .JsonPatch .from_diff (old , new , dumps = dumps )
502+ patch = CustomTypesJsonPatch .from_string (
503+ '[{"op": "replace", "path": "/value", "value": {"__decimal__": "1.00"}}]' )
461504 new_from_patch = jsonpatch .apply_patch (old , patch )
462505 self .assertEqual (new , new_from_patch )
463506
0 commit comments