@@ -236,6 +236,12 @@ def negate(self):
236236 def __hash__ (self ):
237237 return hash (self .operator )
238238
239+ def __str__ (self ):
240+ return self .operator
241+
242+ def __repr__ (self ):
243+ return self .operator
244+
239245
240246class SearchCombinable :
241247 def _combine (self , other , connector ):
@@ -288,12 +294,12 @@ def _get_query_index(self, fields, compiler):
288294 return search_indexes ["name" ]
289295 return "default"
290296
291- def search_operator (self , compiler , connection ):
297+ def search_operator (self ):
292298 raise NotImplementedError
293299
294300 def as_mql (self , compiler , connection ):
295301 index = self ._get_query_index (self .get_search_fields (), compiler )
296- return {"$search" : {** self .search_operator (compiler , connection ), "index" : index }}
302+ return {"$search" : {** self .search_operator (), "index" : index }}
297303
298304
299305class SearchAutocomplete (SearchExpression ):
@@ -307,7 +313,7 @@ def __init__(self, path, query, fuzzy=None, score=None):
307313 def get_search_fields (self ):
308314 return {self .path }
309315
310- def search_operator (self , compiler , connection ):
316+ def search_operator (self ):
311317 params = {
312318 "path" : self .path ,
313319 "query" : self .query ,
@@ -329,7 +335,7 @@ def __init__(self, path, value, score=None):
329335 def get_search_fields (self ):
330336 return {self .path }
331337
332- def search_operator (self , compiler , connection ):
338+ def search_operator (self ):
333339 params = {
334340 "path" : self .path ,
335341 "value" : self .value ,
@@ -348,7 +354,7 @@ def __init__(self, path, score=None):
348354 def get_search_fields (self ):
349355 return {self .path }
350356
351- def search_operator (self , compiler , connection ):
357+ def search_operator (self ):
352358 params = {
353359 "path" : self .path ,
354360 }
@@ -367,7 +373,7 @@ def __init__(self, path, value, score=None):
367373 def get_search_fields (self ):
368374 return {self .path }
369375
370- def search_operator (self , compiler , connection ):
376+ def search_operator (self ):
371377 params = {
372378 "path" : self .path ,
373379 "value" : self .value ,
@@ -389,7 +395,7 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None):
389395 def get_search_fields (self ):
390396 return {self .path }
391397
392- def search_operator (self , compiler , connection ):
398+ def search_operator (self ):
393399 params = {
394400 "path" : self .path ,
395401 "query" : self .query ,
@@ -413,7 +419,7 @@ def __init__(self, path, query, score=None):
413419 def get_search_fields (self ):
414420 return {self .path }
415421
416- def search_operator (self , compiler , connection ):
422+ def search_operator (self ):
417423 params = {
418424 "defaultPath" : self .path ,
419425 "query" : self .query ,
@@ -436,7 +442,7 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None):
436442 def get_search_fields (self ):
437443 return {self .path }
438444
439- def search_operator (self , compiler , connection ):
445+ def search_operator (self ):
440446 params = {
441447 "path" : self .path ,
442448 }
@@ -464,7 +470,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
464470 def get_search_fields (self ):
465471 return {self .path }
466472
467- def search_operator (self , compiler , connection ):
473+ def search_operator (self ):
468474 params = {
469475 "path" : self .path ,
470476 "query" : self .query ,
@@ -489,7 +495,7 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None,
489495 def get_search_fields (self ):
490496 return {self .path }
491497
492- def search_operator (self , compiler , connection ):
498+ def search_operator (self ):
493499 params = {
494500 "path" : self .path ,
495501 "query" : self .query ,
@@ -516,7 +522,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
516522 def get_search_fields (self ):
517523 return {self .path }
518524
519- def search_operator (self , compiler , connection ):
525+ def search_operator (self ):
520526 params = {
521527 "path" : self .path ,
522528 "query" : self .query ,
@@ -539,7 +545,7 @@ def __init__(self, path, relation, geometry, score=None):
539545 def get_search_fields (self ):
540546 return {self .path }
541547
542- def search_operator (self , compiler , connection ):
548+ def search_operator (self ):
543549 params = {
544550 "path" : self .path ,
545551 "relation" : self .relation ,
@@ -558,7 +564,7 @@ def __init__(self, path, kind, geo_object, score=None):
558564 self .score = score
559565 super ().__init__ ()
560566
561- def search_operator (self , compiler , connection ):
567+ def search_operator (self ):
562568 params = {
563569 "path" : self .path ,
564570 self .kind : self .geo_object ,
@@ -577,7 +583,7 @@ def __init__(self, documents, score=None):
577583 self .score = score
578584 super ().__init__ ()
579585
580- def search_operator (self , compiler , connection ):
586+ def search_operator (self ):
581587 params = {
582588 "like" : self .documents ,
583589 }
@@ -670,29 +676,23 @@ def get_search_fields(self):
670676 fields .update (clause .get_search_fields ())
671677 return fields
672678
673- def search_operator (self , compiler , connection ):
679+ def search_operator (self ):
674680 params = {}
675681 if self .must :
676- params ["must" ] = [clause .search_operator (compiler , connection ) for clause in self .must ]
682+ params ["must" ] = [clause .search_operator () for clause in self .must ]
677683 if self .must_not :
678- params ["mustNot" ] = [
679- clause .search_operator (compiler , connection ) for clause in self .must_not
680- ]
684+ params ["mustNot" ] = [clause .search_operator () for clause in self .must_not ]
681685 if self .should :
682- params ["should" ] = [
683- clause .search_operator (compiler , connection ) for clause in self .should
684- ]
686+ params ["should" ] = [clause .search_operator () for clause in self .should ]
685687 if self .filter :
686- params ["filter" ] = [
687- clause .search_operator (compiler , connection ) for clause in self .filter
688- ]
688+ params ["filter" ] = [clause .search_operator () for clause in self .filter ]
689689 if self .minimum_should_match is not None :
690690 params ["minimumShouldMatch" ] = self .minimum_should_match
691691
692692 return {"compound" : params }
693693
694694 def negate (self ):
695- return CompoundExpression (must = self . must_not , must_not = self . must + self . filter )
695+ return CompoundExpression (must_not = [ self ] )
696696
697697
698698class CombinedSearchExpression (SearchExpression ):
@@ -702,7 +702,7 @@ def __init__(self, lhs, operator, rhs):
702702 self .rhs = rhs
703703
704704 @staticmethod
705- def _flatten (node , negated = False ):
705+ def resolve (node , negated = False ):
706706 if node is None :
707707 return None
708708 # Leaf, resolve the compoundExpression
@@ -711,25 +711,24 @@ def _flatten(node, negated=False):
711711 # Apply De Morgan's Laws.
712712 operator = node .operator .negate () if negated else node .operator
713713 negated = negated != (node .operator == Operator .NOT )
714- lhs_compound = node ._flatten (node .lhs , negated )
715- rhs_compound = node ._flatten (node .rhs , negated )
714+ lhs_compound = node .resolve (node .lhs , negated )
715+ rhs_compound = node .resolve (node .rhs , negated )
716716 if operator == Operator .OR :
717717 return CompoundExpression (should = [lhs_compound , rhs_compound ], minimum_should_match = 1 )
718- if node .operator == Operator .AND :
719- return CompoundExpression (
720- must = lhs_compound .must + rhs_compound .must ,
721- must_not = lhs_compound .must_not + rhs_compound .must_not ,
722- should = lhs_compound .should + rhs_compound .should ,
723- filter = lhs_compound .filter + rhs_compound .filter ,
724- )
725- # it also can be written as:
726- # this way is more consistent with OR, but the above is shorter in the debug query.
727- # return CompoundExpression(must=[lhs_compound, rhs_compound])
718+ if operator == Operator .AND :
719+ # NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
720+ # return CompoundExpression(
721+ # must=lhs_compound.must + rhs_compound.must,
722+ # must_not=lhs_compound.must_not + rhs_compound.must_not,
723+ # should=lhs_compound.should + rhs_compound.should,
724+ # filter=lhs_compound.filter + rhs_compound.filter,
725+ # )
726+ return CompoundExpression (must = [lhs_compound , rhs_compound ])
728727 # not operator
729728 return lhs_compound
730729
731730 def as_mql (self , compiler , connection ):
732- expression = self ._flatten (self )
731+ expression = self .resolve (self )
733732 return expression .as_mql (compiler , connection )
734733
735734
0 commit comments