@@ -26,7 +26,11 @@ def find(self, data):
2626 raise NotImplementedError ()
2727
2828 def update (self , data , val ):
29- "Returns `data` with the specified path replaced by `val`"
29+ """
30+ Returns `data` with the specified path replaced by `val`. Only updates
31+ if the specified path exists.
32+ """
33+
3034 raise NotImplementedError ()
3135
3236 def child (self , child ):
@@ -227,6 +231,11 @@ def find(self, datum):
227231 if not isinstance (subdata , AutoIdForDatum )
228232 for submatch in self .right .find (subdata )]
229233
234+ def update (self , data , val ):
235+ for datum in self .left .find (data ):
236+ self .right .update (datum .value , val )
237+ return data
238+
230239 def __eq__ (self , other ):
231240 return isinstance (other , Child ) and self .left == other .left and self .right == other .right
232241
@@ -274,6 +283,11 @@ def __init__(self, left, right):
274283 def find (self , data ):
275284 return [subdata for subdata in self .left .find (data ) if self .right .find (subdata )]
276285
286+ def update (self , data , val ):
287+ for datum in self .find (data ):
288+ datum .path .update (data , val )
289+ return data
290+
277291 def __str__ (self ):
278292 return '%s where %s' % (self .left , self .right )
279293
@@ -329,6 +343,33 @@ def match_recursively(datum):
329343 def is_singular ():
330344 return False
331345
346+ def update (self , data , val ):
347+ # Get all left matches into a list
348+ left_matches = self .left .find (data )
349+ if not isinstance (left_matches , list ):
350+ left_matches = [left_matches ]
351+
352+ def update_recursively (data ):
353+ # Update only mutable values corresponding to JSON types
354+ if not (isinstance (data , list ) or isinstance (data , dict )):
355+ return
356+
357+ self .right .update (data , val )
358+
359+ # Manually do the * or [*] to avoid coercion and recurse just the right-hand pattern
360+ if isinstance (data , list ):
361+ for i in range (0 , len (data )):
362+ update_recursively (data [i ])
363+
364+ elif isinstance (data , dict ):
365+ for field in data .keys ():
366+ update_recursively (data [field ])
367+
368+ for submatch in left_matches :
369+ update_recursively (submatch .value )
370+
371+ return data
372+
332373 def __str__ (self ):
333374 return '%s..%s' % (self .left , self .right )
334375
@@ -415,6 +456,12 @@ def find(self, datum):
415456 for field_datum in [self .get_field_datum (datum , field ) for field in self .reified_fields (datum )]
416457 if field_datum is not None ]
417458
459+ def update (self , data , val ):
460+ for field in self .reified_fields (DatumInContext .wrap (data )):
461+ if field in data :
462+ data [field ] = val
463+ return data
464+
418465 def __str__ (self ):
419466 return ',' .join (map (str , self .fields ))
420467
@@ -445,6 +492,11 @@ def find(self, datum):
445492 else :
446493 return []
447494
495+ def update (self , data , val ):
496+ if len (data ) > self .index :
497+ data [self .index ] = val
498+ return data
499+
448500 def __eq__ (self , other ):
449501 return isinstance (other , Index ) and self .index == other .index
450502
@@ -495,6 +547,11 @@ def find(self, datum):
495547 else :
496548 return [DatumInContext (datum .value [i ], path = Index (i ), context = datum ) for i in range (0 , len (datum .value ))[self .start :self .end :self .step ]]
497549
550+ def update (self , data , val ):
551+ for datum in self .find (data ):
552+ datum .path .update (data , val )
553+ return data
554+
498555 def __str__ (self ):
499556 if self .start == None and self .end == None and self .step == None :
500557 return '[*]'
0 commit comments