Skip to content

Commit 6e80db1

Browse files
committed
feat: add support for contains_any() and contains_all(), (Resolves #82)
1 parent 5de87a7 commit 6e80db1

17 files changed

+501
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ By default, when using validation in this library, it will cap the complexity of
142142
For example if you were searching for (a=1 OR b=2) AND (c=3 OR d=4 OR e=5), we compute that there might be 6 index intersections needed, (a=1,c=3),(a=1,d=4),(a=1,e=5),... This provides a heuristic to cap costs and prevent
143143
runaway queries from being generated. It was actually intended that we look at the number of index scans needed, and maybe that's a closer measure to expense in the DB, but the math would only be slightly different.
144144

145-
Over time this value and argument might change as we get more experience, in the interm you can use 0 as a value to allow everything (say if the collection is small).
145+
Over time this value and argument might change as we get more experience, in the interim you can use 0 as a value to allow everything (say if the collection is small).
146146

147147
#### Regular Expressions
148148

external/epsearchast/v3/ast.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ type AstVisitor interface {
122122
VisitLike(astNode *AstNode) (bool, error)
123123
VisitILike(astNode *AstNode) (bool, error)
124124
VisitContains(astNode *AstNode) (bool, error)
125+
VisitContainsAny(astNode *AstNode) (bool, error)
126+
VisitContainsAll(astNode *AstNode) (bool, error)
125127
VisitText(astNode *AstNode) (bool, error)
126128
VisitIsNull(astNode *AstNode) (bool, error)
127129
}
@@ -171,6 +173,10 @@ func (a *AstNode) accept(v AstVisitor) error {
171173
descend, err = v.VisitILike(a)
172174
case "CONTAINS":
173175
descend, err = v.VisitContains(a)
176+
case "CONTAINS_ANY":
177+
descend, err = v.VisitContainsAny(a)
178+
case "CONTAINS_ALL":
179+
descend, err = v.VisitContainsAll(a)
174180
case "TEXT":
175181
descend, err = v.VisitText(a)
176182
case "IS_NULL":
@@ -232,7 +238,7 @@ func (a *AstNode) checkValid() error {
232238
if len(a.Children) < 2 {
233239
return fmt.Errorf("or should have at least two children")
234240
}
235-
case "IN":
241+
case "IN", "CONTAINS_ANY", "CONTAINS_ALL":
236242
if len(a.Children) > 0 {
237243
return fmt.Errorf("operator %v should not have any children", strings.ToLower(a.NodeType))
238244
}

external/epsearchast/v3/ast_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,40 @@ func TestValidObjectWithContainsReturnsAst(t *testing.T) {
217217
require.NotNil(t, astNode)
218218
}
219219

220+
func TestValidObjectWithContainsAnyReturnsAst(t *testing.T) {
221+
// Fixture Setup
222+
// language=JSON
223+
jsonTxt := `
224+
{
225+
"type": "CONTAINS_ANY",
226+
"args": [ "status", "paid", "pending", "failed"]
227+
}
228+
`
229+
// Execute SUT
230+
astNode, err := GetAst(jsonTxt)
231+
232+
// Verify
233+
require.NoError(t, err)
234+
require.NotNil(t, astNode)
235+
}
236+
237+
func TestValidObjectWithContainsAllReturnsAst(t *testing.T) {
238+
// Fixture Setup
239+
// language=JSON
240+
jsonTxt := `
241+
{
242+
"type": "CONTAINS_ALL",
243+
"args": [ "tags", "important", "urgent"]
244+
}
245+
`
246+
// Execute SUT
247+
astNode, err := GetAst(jsonTxt)
248+
249+
// Verify
250+
require.NoError(t, err)
251+
require.NotNil(t, astNode)
252+
}
253+
220254
func TestValidObjectWithTextReturnsAst(t *testing.T) {
221255
// Fixture Setup
222256
// language=JSON

external/epsearchast/v3/ast_visitor_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,66 @@ func TestPreAndPostAndInCalledOnAccept(t *testing.T) {
218218

219219
}
220220

221+
func TestPreAndPostAndContainsAnyCalledOnAccept(t *testing.T) {
222+
// Fixture Setup
223+
// language=JSON
224+
jsonTxt := `
225+
{
226+
"type": "CONTAINS_ANY",
227+
"args": [
228+
"tags",
229+
"important",
230+
"urgent"
231+
]
232+
}
233+
`
234+
235+
mockObj := new(MyMockedVisitor)
236+
mockObj.On("PreVisit").Return(nil).
237+
On("PostVisit").Return(nil).
238+
On("VisitContainsAny", mock.Anything).Return(true, nil)
239+
240+
astNode, err := GetAst(jsonTxt)
241+
require.NoError(t, err)
242+
243+
// Execute SUT
244+
err = astNode.Accept(mockObj)
245+
246+
// Verification
247+
require.NoError(t, err)
248+
249+
}
250+
251+
func TestPreAndPostAndContainsAllCalledOnAccept(t *testing.T) {
252+
// Fixture Setup
253+
// language=JSON
254+
jsonTxt := `
255+
{
256+
"type": "CONTAINS_ALL",
257+
"args": [
258+
"tags",
259+
"important",
260+
"urgent"
261+
]
262+
}
263+
`
264+
265+
mockObj := new(MyMockedVisitor)
266+
mockObj.On("PreVisit").Return(nil).
267+
On("PostVisit").Return(nil).
268+
On("VisitContainsAll", mock.Anything).Return(true, nil)
269+
270+
astNode, err := GetAst(jsonTxt)
271+
require.NoError(t, err)
272+
273+
// Execute SUT
274+
err = astNode.Accept(mockObj)
275+
276+
// Verification
277+
require.NoError(t, err)
278+
279+
}
280+
221281
func TestPreOnAcceptWithError(t *testing.T) {
222282
// Fixture Setup
223283
// language=JSON
@@ -808,6 +868,16 @@ func (m *MyMockedVisitor) VisitContains(astNode *AstNode) (bool, error) {
808868
return args.Bool(0), args.Error(1)
809869
}
810870

871+
func (m *MyMockedVisitor) VisitContainsAny(astNode *AstNode) (bool, error) {
872+
args := m.Called(astNode)
873+
return args.Bool(0), args.Error(1)
874+
}
875+
876+
func (m *MyMockedVisitor) VisitContainsAll(astNode *AstNode) (bool, error) {
877+
args := m.Called(astNode)
878+
return args.Bool(0), args.Error(1)
879+
}
880+
811881
func (m *MyMockedVisitor) VisitText(astNode *AstNode) (bool, error) {
812882
args := m.Called(astNode)
813883
return args.Bool(0), args.Error(1)

external/epsearchast/v3/es/es_query_builder.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,33 @@ func (d DefaultEsQueryBuilder) VisitContains(first, second string) (*JsonObject,
208208
return d.buildQueryWithBuilder(b, first, second)
209209
}
210210

211+
func (d DefaultEsQueryBuilder) VisitContainsAny(args ...string) (*JsonObject, error) {
212+
b := d.GetTermsQueryBuilderForArrayField()
213+
214+
return d.buildQueryWithBuilder(b, args...)
215+
}
216+
217+
func (d DefaultEsQueryBuilder) VisitContainsAll(args ...string) (*JsonObject, error) {
218+
// Build individual term queries for each value
219+
b := d.GetTermQueryBuilderForArrayField()
220+
221+
var termQueries []*JsonObject
222+
for _, value := range args[1:] {
223+
query, err := d.buildQueryWithBuilder(b, args[0], value)
224+
if err != nil {
225+
return nil, err
226+
}
227+
termQueries = append(termQueries, query)
228+
}
229+
230+
// Wrap in a bool query with must clause
231+
return &JsonObject{
232+
"bool": map[string]any{
233+
"must": termQueries,
234+
},
235+
}, nil
236+
}
237+
211238
func (d DefaultEsQueryBuilder) GetTermQueryBuilderForArrayField() func(args ...string) *JsonObject {
212239
return func(args ...string) *JsonObject {
213240
return &JsonObject{
@@ -218,6 +245,16 @@ func (d DefaultEsQueryBuilder) GetTermQueryBuilderForArrayField() func(args ...s
218245
}
219246
}
220247

248+
func (d DefaultEsQueryBuilder) GetTermsQueryBuilderForArrayField() func(args ...string) *JsonObject {
249+
return func(args ...string) *JsonObject {
250+
return &JsonObject{
251+
"terms": map[string]any{
252+
d.GetFieldMapping(args[0]).Array: args[1:],
253+
},
254+
}
255+
}
256+
}
257+
221258
func (d DefaultEsQueryBuilder) VisitText(first, second string) (*JsonObject, error) {
222259
b := d.BuildMatchBoolPrefixQuery()
223260

external/epsearchast/v3/es/es_query_builder_int_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,94 @@ func TestSmokeTestElasticSearchWithFilters(t *testing.T) {
935935
}`,
936936
count: 1,
937937
},
938+
{
939+
//language=JSON
940+
filter: `{
941+
"type": "CONTAINS_ANY",
942+
"args": ["array_field", "a", "c"]
943+
}`,
944+
count: 3,
945+
},
946+
{
947+
//language=JSON
948+
filter: `{
949+
"type": "CONTAINS_ANY",
950+
"args": ["array_field", "a", "d"]
951+
}`,
952+
count: 2,
953+
},
954+
{
955+
//language=JSON
956+
filter: `{
957+
"type": "CONTAINS_ANY",
958+
"args": ["array_field", "z"]
959+
}`,
960+
count: 0,
961+
},
962+
{
963+
//language=JSON
964+
filter: `{
965+
"type": "CONTAINS_ALL",
966+
"args": ["array_field", "a", "b"]
967+
}`,
968+
count: 1,
969+
},
970+
{
971+
//language=JSON
972+
filter: `{
973+
"type": "CONTAINS_ALL",
974+
"args": ["array_field", "c"]
975+
}`,
976+
count: 2,
977+
},
978+
{
979+
//language=JSON
980+
filter: `{
981+
"type": "CONTAINS_ALL",
982+
"args": ["array_field", "c", "d"]
983+
}`,
984+
count: 1,
985+
},
986+
{
987+
//language=JSON
988+
filter: `{
989+
"type": "CONTAINS_ALL",
990+
"args": ["array_field", "d", "c"]
991+
}`,
992+
count: 1,
993+
},
994+
{
995+
//language=JSON
996+
filter: `{
997+
"type": "CONTAINS_ALL",
998+
"args": ["array_field", "a", "c"]
999+
}`,
1000+
count: 0,
1001+
},
1002+
{
1003+
//language=JSON
1004+
filter: `{
1005+
"type": "CONTAINS_ANY",
1006+
"args": ["array_field", "d", "a"]
1007+
}`,
1008+
count: 2,
1009+
},
1010+
{
1011+
//language=JSON
1012+
filter: `{
1013+
"type": "CONTAINS_ANY",
1014+
"args": ["array_field", "a", "b", "c"]
1015+
}`,
1016+
count: 3,
1017+
},
1018+
{
1019+
//language=JSON
1020+
filter: `{
1021+
"type": "CONTAINS_ALL",
1022+
"args": ["array_field", "a", "b", "c"]
1023+
}`,
1024+
count: 0,
1025+
},
9381026
}
9391027

9401028
for _, tc := range testCases {

external/epsearchast/v3/gorm/gorm_query_builder.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package epsearchast_v3_gorm
33
import (
44
"fmt"
55
epsearchast_v3 "github.com/elasticpath/epcc-search-ast-helper/external/epsearchast/v3"
6+
"github.com/lib/pq"
67
"strings"
78
)
89

@@ -114,6 +115,20 @@ func (g DefaultGormQueryBuilder) VisitContains(first, second string) (*SubQuery,
114115
}, nil
115116
}
116117

118+
func (g DefaultGormQueryBuilder) VisitContainsAny(args ...string) (*SubQuery, error) {
119+
return &SubQuery{
120+
Clause: fmt.Sprintf("%s && ?", args[0]),
121+
Args: []interface{}{pq.Array(args[1:])},
122+
}, nil
123+
}
124+
125+
func (g DefaultGormQueryBuilder) VisitContainsAll(args ...string) (*SubQuery, error) {
126+
return &SubQuery{
127+
Clause: fmt.Sprintf("%s @> ?", args[0]),
128+
Args: []interface{}{pq.Array(args[1:])},
129+
}, nil
130+
}
131+
117132
func (g DefaultGormQueryBuilder) VisitText(first, second string) (*SubQuery, error) {
118133
return &SubQuery{
119134
Clause: fmt.Sprintf("to_tsvector('english', %s) @@ plainto_tsquery('english', ?)", first),

0 commit comments

Comments
 (0)