4444import com .google .common .primitives .Primitives ;
4545import com .google .common .primitives .Shorts ;
4646import com .google .common .primitives .SignedBytes ;
47+ import com .redis .lettucemod .search .Field ;
4748import com .redis .lettucemod .search .Group ;
4849import com .redis .lettucemod .search .Reducer ;
4950import com .redis .lettucemod .search .Reducers .Avg ;
@@ -78,14 +79,11 @@ public class RediSearchQueryBuilder {
7879 (alias , field ) -> Avg .property (field ).as (alias ).build (), MetricAggregation .COUNT ,
7980 (alias , field ) -> Count .as (alias ));
8081
81- private RediSearchQueryBuilder () {
82- }
83-
84- public static String buildQuery (TupleDomain <ColumnHandle > tupleDomain ) {
82+ public String buildQuery (TupleDomain <ColumnHandle > tupleDomain ) {
8583 return buildQuery (tupleDomain , Map .of ());
8684 }
8785
88- public static String buildQuery (TupleDomain <ColumnHandle > tupleDomain , Map <String , String > wildcards ) {
86+ public String buildQuery (TupleDomain <ColumnHandle > tupleDomain , Map <String , String > wildcards ) {
8987 List <Node > nodes = new ArrayList <>();
9088 Optional <Map <ColumnHandle , Domain >> domains = tupleDomain .getDomains ();
9189 if (domains .isPresent ()) {
@@ -94,7 +92,7 @@ public static String buildQuery(TupleDomain<ColumnHandle> tupleDomain, Map<Strin
9492 Domain domain = entry .getValue ();
9593 checkArgument (!domain .isNone (), "Unexpected NONE domain for %s" , column .getName ());
9694 if (!domain .isAll ()) {
97- buildPredicate (column . getName () , domain , column . getType () ).ifPresent (nodes ::add );
95+ buildPredicate (column , domain ).ifPresent (nodes ::add );
9896 }
9997 }
10098 }
@@ -107,8 +105,8 @@ public static String buildQuery(TupleDomain<ColumnHandle> tupleDomain, Map<Strin
107105 return QueryBuilder .intersect (nodes .toArray (new Node [0 ])).toString ();
108106 }
109107
110- private static Optional <Node > buildPredicate (String name , Domain domain , Type type ) {
111- String columnName = RedisModulesUtils .escapeTag (name );
108+ private Optional <Node > buildPredicate (RediSearchColumnHandle column , Domain domain ) {
109+ String columnName = RedisModulesUtils .escapeTag (column . getName () );
112110 checkArgument (domain .getType ().isOrderable (), "Domain type must be orderable" );
113111 if (domain .getValues ().isNone ()) {
114112 return Optional .empty ();
@@ -120,21 +118,29 @@ private static Optional<Node> buildPredicate(String name, Domain domain, Type ty
120118 List <Node > disjuncts = new ArrayList <>();
121119 for (Range range : domain .getValues ().getRanges ().getOrderedRanges ()) {
122120 if (range .isSingleValue ()) {
123- singleValues .add (translateValue (range .getSingleValue (), type ));
121+ singleValues .add (translateValue (range .getSingleValue (), column . getType () ));
124122 } else {
125123 List <Value > rangeConjuncts = new ArrayList <>();
126124 if (!range .isLowUnbounded ()) {
127- Object translated = translateValue (range .getLowBoundedValue (), type );
125+ Object translated = translateValue (range .getLowBoundedValue (), column . getType () );
128126 if (translated instanceof Number numericValue ) {
129127 double doubleValue = numericValue .doubleValue ();
130128 rangeConjuncts .add (range .isLowInclusive () ? Values .ge (doubleValue ) : Values .gt (doubleValue ));
129+ } else {
130+ throw new UnsupportedOperationException (
131+ String .format ("Range constraint not supported for type %s (column: '%s')" ,
132+ column .getType (), column .getName ()));
131133 }
132134 }
133135 if (!range .isHighUnbounded ()) {
134- Object translated = translateValue (range .getHighBoundedValue (), type );
136+ Object translated = translateValue (range .getHighBoundedValue (), column . getType () );
135137 if (translated instanceof Number numericValue ) {
136138 double doubleValue = numericValue .doubleValue ();
137139 rangeConjuncts .add (range .isHighInclusive () ? Values .le (doubleValue ) : Values .lt (doubleValue ));
140+ } else {
141+ throw new UnsupportedOperationException (
142+ String .format ("Range constraint not supported for type %s (column: '%s')" ,
143+ column .getType (), column .getName ()));
138144 }
139145 }
140146 // If conjuncts is null, then the range was ALL, which should already have been
@@ -145,17 +151,18 @@ private static Optional<Node> buildPredicate(String name, Domain domain, Type ty
145151 }
146152 }
147153 if (singleValues .size () == 1 ) {
148- disjuncts .add (QueryBuilder .intersect (columnName , value (Iterables .getOnlyElement (singleValues ), type )));
154+ disjuncts .add (QueryBuilder .intersect (columnName , value (Iterables .getOnlyElement (singleValues ), column )));
149155 } else if (singleValues .size () > 1 ) {
150156 disjuncts .add (QueryBuilder .union (columnName ,
151- singleValues .stream ().map (v -> value (v , type )).toArray (Value []::new )));
157+ singleValues .stream ().map (v -> value (v , column )).toArray (Value []::new )));
152158 }
153159 return Optional .of (QueryBuilder .union (disjuncts .toArray (Node []::new )));
154160 }
155161
156- private static Value value (Object trinoNativeValue , Type type ) {
162+ private Value value (Object trinoNativeValue , RediSearchColumnHandle column ) {
157163 requireNonNull (trinoNativeValue , "trinoNativeValue is null" );
158- requireNonNull (type , "type is null" );
164+ requireNonNull (column , "column is null" );
165+ Type type = column .getType ();
159166 if (type == DOUBLE ) {
160167 return Values .eq ((Double ) trinoNativeValue );
161168 }
@@ -172,12 +179,15 @@ private static Value value(Object trinoNativeValue, Type type) {
172179 return Values .eq ((Long ) trinoNativeValue );
173180 }
174181 if (type instanceof VarcharType ) {
175- return Values .tags (RedisModulesUtils .escapeTag ((String ) trinoNativeValue ));
182+ if (column .getFieldType () == Field .Type .TAG ) {
183+ return Values .tags (RedisModulesUtils .escapeTag ((String ) trinoNativeValue ));
184+ }
185+ return Values .value ((String ) trinoNativeValue );
176186 }
177187 throw new UnsupportedOperationException ("Type " + type + " not supported" );
178188 }
179189
180- private static Object translateValue (Object trinoNativeValue , Type type ) {
190+ private Object translateValue (Object trinoNativeValue , Type type ) {
181191 requireNonNull (trinoNativeValue , "trinoNativeValue is null" );
182192 requireNonNull (type , "type is null" );
183193 checkArgument (Primitives .wrap (type .getJavaType ()).isInstance (trinoNativeValue ),
@@ -207,20 +217,20 @@ private static Object translateValue(Object trinoNativeValue, Type type) {
207217 throw new IllegalArgumentException ("Unhandled type: " + type );
208218 }
209219
210- private static Reducer reducer (MetricAggregation aggregation ) {
220+ private Reducer reducer (MetricAggregation aggregation ) {
211221 Optional <RediSearchColumnHandle > column = aggregation .getColumnHandle ();
212222 String field = column .isPresent () ? column .get ().getName () : null ;
213223 return CONVERTERS .get (aggregation .getFunctionName ()).apply (aggregation .getAlias (), field );
214224 }
215225
216- public static Optional <Group > group (RediSearchTableHandle table ) {
226+ public Optional <Group > group (RediSearchTableHandle table ) {
217227 List <TermAggregation > terms = table .getTermAggregations ();
218228 List <MetricAggregation > aggregates = table .getMetricAggregations ();
219229 List <String > groupFields = new ArrayList <>();
220230 if (terms != null && !terms .isEmpty ()) {
221231 groupFields = terms .stream ().map (TermAggregation ::getTerm ).toList ();
222232 }
223- List <Reducer > reducers = aggregates .stream ().map (RediSearchQueryBuilder ::reducer ).toList ();
233+ List <Reducer > reducers = aggregates .stream ().map (this ::reducer ).toList ();
224234 if (reducers .isEmpty ()) {
225235 return Optional .empty ();
226236 }
0 commit comments