33import six
44from six .moves import xrange
55from itertools import * # noqa
6+ from .exceptions import JSONPathError
67
# Module-level logger, named after this module
89logger = logging .getLogger (__name__ )
1112# ... could be a kwarg pervasively but uses are rare and simple today
1213auto_id_field = None
1314
15+ NOT_SET = object ()
16+ LIST_KEY = object ()
17+
1418
1519class JSONPath (object ):
1620 """
@@ -27,6 +31,9 @@ def find(self, data):
2731 """
2832 raise NotImplementedError ()
2933
def find_or_create(self, data):
    """
    Default "find or create" behaviour for a path: create nothing.

    This base implementation simply delegates to `self.find(data)` and
    returns its result unchanged. Path types that are able to build
    missing structure override this method.
    """
    return self.find(data)
36+
3037 def update (self , data , val ):
3138 """
3239 Returns `data` with the specified path replaced by `val`. Only updates
@@ -35,6 +42,9 @@ def update(self, data, val):
3542
3643 raise NotImplementedError ()
3744
def update_or_create(self, data, val):
    """
    Default "update or create" behaviour for a path: create nothing.

    This base implementation simply delegates to `self.update(data, val)`
    and returns its result unchanged. Path types that are able to build
    missing structure override this method.
    """
    return self.update(data, val)
47+
3848 def filter (self , fn , data ):
3949 """
4050 Returns `data` with the specified path filtering nodes according
@@ -261,6 +271,23 @@ def update(self, data, val):
261271 self .right .update (datum .value , val )
262272 return data
263273
def find_or_create(self, datum):
    """
    Find all matches of `left` followed by `right`, creating as it goes.

    `datum` is wrapped with `DatumInContext.wrap`, every match produced by
    `self.left.find_or_create` is fed to `self.right.find_or_create`, and
    the results are returned as one flat list, in order.
    """
    datum = DatumInContext.wrap(datum)
    return [
        submatch
        for subdata in self.left.find_or_create(datum)
        # Extra special case: auto ids do not have children, so cut them
        # off right here rather than auto id the auto id.
        if not isinstance(subdata, AutoIdForDatum)
        for submatch in self.right.find_or_create(subdata)
    ]
285+
def update_or_create(self, data, val):
    """
    Apply `val` at this path, creating the intermediate nodes if needed.

    Every match of `self.left.find_or_create(data)` has its value handed
    to `self.right.update_or_create`. The return value is
    `_clean_list_keys(data)`, which swaps the `LIST_KEY` placeholder dicts
    left behind by list creation for the lists they hold.
    """
    for parent in self.left.find_or_create(data):
        self.right.update_or_create(parent.value, val)
    return _clean_list_keys(data)
290+
264291 def filter (self , fn , data ):
265292 for datum in self .left .find (data ):
266293 self .right .filter (fn , datum .value )
@@ -497,15 +524,20 @@ class Fields(JSONPath):
def __init__(self, *fields):
    """
    Remember the field names given as positional arguments.

    They are stored, as a tuple, on `self.fields`.
    """
    self.fields = fields
499526
@staticmethod
def get_field_datum(datum, field, create=False):
    """
    Return the datum for `field` of `datum`, or None when there is none.

    * If `field` equals the module-level `auto_id_field`, an
      `AutoIdForDatum` wrapping `datum` is returned.
    * Otherwise the field is looked up with `datum.value.get`. The
      `NOT_SET` sentinel is the lookup default so that a stored `None`
      is still told apart from a missing key.
    * A missing key yields None, unless `create` is true, in which case
      an empty dict is stored under `field` and returned as the value.
    * Values that do not support `.get` / item assignment (the lookup
      raises TypeError or AttributeError) yield None.

    `create` defaults to False so that two-argument callers keep working.
    """
    if field == auto_id_field:
        return AutoIdForDatum(datum)
    # Only the mapping access is guarded; building the result is not
    # expected to fail and should not be silently swallowed if it does.
    try:
        field_value = datum.value.get(field, NOT_SET)
        if field_value is NOT_SET:
            if not create:
                return None
            datum.value[field] = field_value = {}
    except (TypeError, AttributeError):
        return None
    return DatumInContext(field_value, path=Fields(field), context=datum)
509541
510542 def reified_fields (self , datum ):
511543 if '*' not in self .fields :
@@ -518,15 +550,28 @@ def reified_fields(self, datum):
518550 return ()
519551
def find(self, datum):
    """
    Return the datums for this path's fields that exist in `datum`.
    """
    return self._find_base(datum, create=False)

def find_or_create(self, datum):
    """
    Like `find`, but asks for missing fields to be created on the way.
    """
    return self._find_base(datum, create=True)

def _find_base(self, datum, create):
    """
    Shared implementation of `find` and `find_or_create`.

    `datum` is wrapped with `DatumInContext.wrap`; each field yielded by
    `self.reified_fields` is resolved through `self.get_field_datum`
    (forwarding `create`), and the None results are dropped.
    """
    datum = DatumInContext.wrap(datum)
    found = []
    for field in self.reified_fields(datum):
        field_datum = self.get_field_datum(datum, field, create)
        if field_datum is not None:
            found.append(field_datum)
    return found
526563
def update(self, data, val):
    """
    Update existing fields only: delegates to `_update_base` with
    `create=False` and returns its result.
    """
    return self._update_base(data, val, create=False)
566+
def update_or_create(self, data, val):
    """
    Update fields, creating the missing ones: delegates to `_update_base`
    with `create=True` and returns its result.
    """
    return self._update_base(data, val, create=True)
569+
570+ def _update_base (self , data , val , create ):
528571 if data is not None :
529572 for field in self .reified_fields (DatumInContext .wrap (data )):
573+ if field not in data and create :
574+ data [field ] = {}
530575 if field in data :
531576 if hasattr (val , '__call__' ):
532577 val (data [field ], data , field )
@@ -565,14 +610,33 @@ def __init__(self, index):
565610 self .index = index
566611
def find(self, datum):
    """
    Return the element at `self.index` of `datum` as a one-item list,
    or an empty list when there is no such element.
    """
    return self._find_base(datum, create=False)

def find_or_create(self, datum):
    """
    Like `find`, but first makes sure the indexed element exists.
    """
    return self._find_base(datum, create=True)

def _find_base(self, datum, create):
    """
    Shared implementation of `find` and `find_or_create`.

    In create mode an empty dict value is swapped for the list returned
    by `_create_list_key`, and the value is then padded with empty dicts
    by `_pad_value` so that `self.index` is a valid position.
    """
    datum = DatumInContext.wrap(datum)
    if create:
        if datum.value == {}:
            # The dict keeps a reference to the new list under LIST_KEY
            # so that `_clean_list_keys` can substitute it later.
            datum.value = _create_list_key(datum.value)
        self._pad_value(datum.value)
    if datum.value and len(datum.value) > self.index:
        return [DatumInContext(datum.value[self.index], path=self, context=datum)]
    return []
574628
def update(self, data, val):
    """
    Update the indexed element without creating anything: delegates to
    `_update_base` with `create=False` and returns its result.
    """
    return self._update_base(data, val, create=False)
631+
def update_or_create(self, data, val):
    """
    Update the indexed element, creating it if needed: delegates to
    `_update_base` with `create=True` and returns its result.
    """
    return self._update_base(data, val, create=True)
634+
635+ def _update_base (self , data , val , create ):
636+ if create :
637+ if data == {}:
638+ data = _create_list_key (data )
639+ self ._pad_value (data )
576640 if hasattr (val , '__call__' ):
577641 val .__call__ (data [self .index ], data , self .index )
578642 elif len (data ) > self .index :
@@ -590,6 +654,14 @@ def __eq__(self, other):
def __str__(self):
    """
    Render the index in path syntax, e.g. `[3]`.
    """
    return '[%i]' % (self.index,)
592656
def __repr__(self):
    """
    Debug representation: the class name followed by the index.
    """
    return '%s(index=%r)' % (self.__class__.__name__, self.index)
659+
def _pad_value(self, value):
    """
    Grow `value` in place with empty dicts until `self.index` is valid.

    Nothing happens when `value` is already longer than `self.index`.
    Each padding slot gets its own fresh dict.
    """
    missing = self.index + 1 - len(value)
    if missing > 0:
        value += [{} for __ in range(missing)]
664+
593665
594666class Slice (JSONPath ):
595667 """
@@ -668,3 +740,32 @@ def __repr__(self):
668740
669741 def __eq__ (self , other ):
670742 return isinstance (other , Slice ) and other .start == self .start and self .end == other .end and other .step == self .step
743+
744+
def _create_list_key(dict_):
    """
    Store a new list `[{}]` in `dict_` under `LIST_KEY` and return it.

    The returned list is the very object held by the dictionary, so
    changes made to it are visible through `dict_[LIST_KEY]`.

    See `_clean_list_keys()`
    """
    new_list = [{}]
    dict_[LIST_KEY] = new_list
    return new_list
753+
754+
def _clean_list_keys(dict_):
    """
    Replace {LIST_KEY: ['foo', 'bar']} with ['foo', 'bar'].

    >>> _clean_list_keys({LIST_KEY: ['foo', 'bar']})
    ['foo', 'bar']

    The structure is walked recursively through dicts and lists (lists
    nested in lists included) and updated in place. A list or a scalar
    at the top level is accepted too; anything that is neither a dict nor
    a list is returned untouched.
    """
    if isinstance(dict_, list):
        # Slice assignment keeps the identity of the list for any other
        # holder of a reference to it.
        dict_[:] = [_clean_list_keys(item) for item in dict_]
        return dict_
    if not isinstance(dict_, dict):
        return dict_
    if LIST_KEY in dict_:
        return _clean_list_keys(dict_[LIST_KEY])
    # Iterate over a snapshot of the keys: values are reassigned below.
    for key in list(dict_):
        dict_[key] = _clean_list_keys(dict_[key])
    return dict_
0 commit comments