Skip to content

Commit 3f756d9

Browse files
author
Julien Ruaux
committed
feat: Added filter clauses for RediSearch TEXT fields
1 parent 0a0754b commit 3f756d9

File tree

4 files changed

+56
-45
lines changed

4 files changed

+56
-45
lines changed

src/main/java/com/redis/trino/RediSearchMetadata.java

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -273,23 +273,6 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
273273
ConnectorTableHandle table, Constraint constraint) {
274274
RediSearchTableHandle handle = (RediSearchTableHandle) table;
275275

276-
Map<ColumnHandle, Domain> supported = new HashMap<>();
277-
Map<ColumnHandle, Domain> unsupported = new HashMap<>();
278-
Map<ColumnHandle, Domain> domains = constraint.getSummary().getDomains()
279-
.orElseThrow(() -> new IllegalArgumentException("constraint summary is NONE"));
280-
for (Map.Entry<ColumnHandle, Domain> entry : domains.entrySet()) {
281-
RediSearchColumnHandle column = (RediSearchColumnHandle) entry.getKey();
282-
283-
if (column.isSupportsPredicates()) {
284-
supported.put(column, entry.getValue());
285-
} else {
286-
unsupported.put(column, entry.getValue());
287-
}
288-
}
289-
290-
TupleDomain<ColumnHandle> oldDomain = handle.getConstraint();
291-
TupleDomain<ColumnHandle> newDomain = oldDomain.intersect(TupleDomain.withColumnDomains(supported));
292-
293276
ConnectorExpression oldExpression = constraint.getExpression();
294277
Map<String, String> newWildcards = new HashMap<>(handle.getWildcards());
295278
List<ConnectorExpression> expressions = ConnectorExpressions.extractConjuncts(constraint.getExpression());
@@ -319,6 +302,22 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
319302
notHandledExpressions.add(expression);
320303
}
321304

305+
Map<ColumnHandle, Domain> supported = new HashMap<>();
306+
Map<ColumnHandle, Domain> unsupported = new HashMap<>();
307+
Map<ColumnHandle, Domain> domains = constraint.getSummary().getDomains()
308+
.orElseThrow(() -> new IllegalArgumentException("constraint summary is NONE"));
309+
for (Map.Entry<ColumnHandle, Domain> entry : domains.entrySet()) {
310+
RediSearchColumnHandle column = (RediSearchColumnHandle) entry.getKey();
311+
312+
if (column.isSupportsPredicates() && !newWildcards.containsKey(column.getName())) {
313+
supported.put(column, entry.getValue());
314+
} else {
315+
unsupported.put(column, entry.getValue());
316+
}
317+
}
318+
319+
TupleDomain<ColumnHandle> oldDomain = handle.getConstraint();
320+
TupleDomain<ColumnHandle> newDomain = oldDomain.intersect(TupleDomain.withColumnDomains(supported));
322321
ConnectorExpression newExpression = ConnectorExpressions.and(notHandledExpressions);
323322
if (oldDomain.equals(newDomain) && oldExpression.equals(newExpression)) {
324323
return Optional.empty();

src/main/java/com/redis/trino/RediSearchQueryBuilder.java

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.google.common.primitives.Primitives;
4545
import com.google.common.primitives.Shorts;
4646
import com.google.common.primitives.SignedBytes;
47+
import com.redis.lettucemod.search.Field;
4748
import com.redis.lettucemod.search.Group;
4849
import com.redis.lettucemod.search.Reducer;
4950
import com.redis.lettucemod.search.Reducers.Avg;
@@ -78,14 +79,11 @@ public class RediSearchQueryBuilder {
7879
(alias, field) -> Avg.property(field).as(alias).build(), MetricAggregation.COUNT,
7980
(alias, field) -> Count.as(alias));
8081

81-
private RediSearchQueryBuilder() {
82-
}
83-
84-
public static String buildQuery(TupleDomain<ColumnHandle> tupleDomain) {
82+
public String buildQuery(TupleDomain<ColumnHandle> tupleDomain) {
8583
return buildQuery(tupleDomain, Map.of());
8684
}
8785

88-
public static String buildQuery(TupleDomain<ColumnHandle> tupleDomain, Map<String, String> wildcards) {
86+
public String buildQuery(TupleDomain<ColumnHandle> tupleDomain, Map<String, String> wildcards) {
8987
List<Node> nodes = new ArrayList<>();
9088
Optional<Map<ColumnHandle, Domain>> domains = tupleDomain.getDomains();
9189
if (domains.isPresent()) {
@@ -94,7 +92,7 @@ public static String buildQuery(TupleDomain<ColumnHandle> tupleDomain, Map<Strin
9492
Domain domain = entry.getValue();
9593
checkArgument(!domain.isNone(), "Unexpected NONE domain for %s", column.getName());
9694
if (!domain.isAll()) {
97-
buildPredicate(column.getName(), domain, column.getType()).ifPresent(nodes::add);
95+
buildPredicate(column, domain).ifPresent(nodes::add);
9896
}
9997
}
10098
}
@@ -107,8 +105,8 @@ public static String buildQuery(TupleDomain<ColumnHandle> tupleDomain, Map<Strin
107105
return QueryBuilder.intersect(nodes.toArray(new Node[0])).toString();
108106
}
109107

110-
private static Optional<Node> buildPredicate(String name, Domain domain, Type type) {
111-
String columnName = RedisModulesUtils.escapeTag(name);
108+
private Optional<Node> buildPredicate(RediSearchColumnHandle column, Domain domain) {
109+
String columnName = RedisModulesUtils.escapeTag(column.getName());
112110
checkArgument(domain.getType().isOrderable(), "Domain type must be orderable");
113111
if (domain.getValues().isNone()) {
114112
return Optional.empty();
@@ -120,21 +118,29 @@ private static Optional<Node> buildPredicate(String name, Domain domain, Type ty
120118
List<Node> disjuncts = new ArrayList<>();
121119
for (Range range : domain.getValues().getRanges().getOrderedRanges()) {
122120
if (range.isSingleValue()) {
123-
singleValues.add(translateValue(range.getSingleValue(), type));
121+
singleValues.add(translateValue(range.getSingleValue(), column.getType()));
124122
} else {
125123
List<Value> rangeConjuncts = new ArrayList<>();
126124
if (!range.isLowUnbounded()) {
127-
Object translated = translateValue(range.getLowBoundedValue(), type);
125+
Object translated = translateValue(range.getLowBoundedValue(), column.getType());
128126
if (translated instanceof Number numericValue) {
129127
double doubleValue = numericValue.doubleValue();
130128
rangeConjuncts.add(range.isLowInclusive() ? Values.ge(doubleValue) : Values.gt(doubleValue));
129+
} else {
130+
throw new UnsupportedOperationException(
131+
String.format("Range constraint not supported for type %s (column: '%s')",
132+
column.getType(), column.getName()));
131133
}
132134
}
133135
if (!range.isHighUnbounded()) {
134-
Object translated = translateValue(range.getHighBoundedValue(), type);
136+
Object translated = translateValue(range.getHighBoundedValue(), column.getType());
135137
if (translated instanceof Number numericValue) {
136138
double doubleValue = numericValue.doubleValue();
137139
rangeConjuncts.add(range.isHighInclusive() ? Values.le(doubleValue) : Values.lt(doubleValue));
140+
} else {
141+
throw new UnsupportedOperationException(
142+
String.format("Range constraint not supported for type %s (column: '%s')",
143+
column.getType(), column.getName()));
138144
}
139145
}
140146
// If conjuncts is null, then the range was ALL, which should already have been
@@ -145,17 +151,18 @@ private static Optional<Node> buildPredicate(String name, Domain domain, Type ty
145151
}
146152
}
147153
if (singleValues.size() == 1) {
148-
disjuncts.add(QueryBuilder.intersect(columnName, value(Iterables.getOnlyElement(singleValues), type)));
154+
disjuncts.add(QueryBuilder.intersect(columnName, value(Iterables.getOnlyElement(singleValues), column)));
149155
} else if (singleValues.size() > 1) {
150156
disjuncts.add(QueryBuilder.union(columnName,
151-
singleValues.stream().map(v -> value(v, type)).toArray(Value[]::new)));
157+
singleValues.stream().map(v -> value(v, column)).toArray(Value[]::new)));
152158
}
153159
return Optional.of(QueryBuilder.union(disjuncts.toArray(Node[]::new)));
154160
}
155161

156-
private static Value value(Object trinoNativeValue, Type type) {
162+
private Value value(Object trinoNativeValue, RediSearchColumnHandle column) {
157163
requireNonNull(trinoNativeValue, "trinoNativeValue is null");
158-
requireNonNull(type, "type is null");
164+
requireNonNull(column, "column is null");
165+
Type type = column.getType();
159166
if (type == DOUBLE) {
160167
return Values.eq((Double) trinoNativeValue);
161168
}
@@ -172,12 +179,15 @@ private static Value value(Object trinoNativeValue, Type type) {
172179
return Values.eq((Long) trinoNativeValue);
173180
}
174181
if (type instanceof VarcharType) {
175-
return Values.tags(RedisModulesUtils.escapeTag((String) trinoNativeValue));
182+
if (column.getFieldType() == Field.Type.TAG) {
183+
return Values.tags(RedisModulesUtils.escapeTag((String) trinoNativeValue));
184+
}
185+
return Values.value((String) trinoNativeValue);
176186
}
177187
throw new UnsupportedOperationException("Type " + type + " not supported");
178188
}
179189

180-
private static Object translateValue(Object trinoNativeValue, Type type) {
190+
private Object translateValue(Object trinoNativeValue, Type type) {
181191
requireNonNull(trinoNativeValue, "trinoNativeValue is null");
182192
requireNonNull(type, "type is null");
183193
checkArgument(Primitives.wrap(type.getJavaType()).isInstance(trinoNativeValue),
@@ -207,20 +217,20 @@ private static Object translateValue(Object trinoNativeValue, Type type) {
207217
throw new IllegalArgumentException("Unhandled type: " + type);
208218
}
209219

210-
private static Reducer reducer(MetricAggregation aggregation) {
220+
private Reducer reducer(MetricAggregation aggregation) {
211221
Optional<RediSearchColumnHandle> column = aggregation.getColumnHandle();
212222
String field = column.isPresent() ? column.get().getName() : null;
213223
return CONVERTERS.get(aggregation.getFunctionName()).apply(aggregation.getAlias(), field);
214224
}
215225

216-
public static Optional<Group> group(RediSearchTableHandle table) {
226+
public Optional<Group> group(RediSearchTableHandle table) {
217227
List<TermAggregation> terms = table.getTermAggregations();
218228
List<MetricAggregation> aggregates = table.getMetricAggregations();
219229
List<String> groupFields = new ArrayList<>();
220230
if (terms != null && !terms.isEmpty()) {
221231
groupFields = terms.stream().map(TermAggregation::getTerm).toList();
222232
}
223-
List<Reducer> reducers = aggregates.stream().map(RediSearchQueryBuilder::reducer).toList();
233+
List<Reducer> reducers = aggregates.stream().map(this::reducer).toList();
224234
if (reducers.isEmpty()) {
225235
return Optional.empty();
226236
}

src/main/java/com/redis/trino/RediSearchTranslator.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
public class RediSearchTranslator {
4242

43+
private final RediSearchQueryBuilder queryBuilder = new RediSearchQueryBuilder();
44+
4345
private final RediSearchConfig config;
4446

4547
public RediSearchTranslator(RediSearchConfig config) {
@@ -222,7 +224,7 @@ public Search build() {
222224

223225
public Search search(RediSearchTableHandle tableHandle, List<RediSearchColumnHandle> columns) {
224226
String index = index(tableHandle);
225-
String query = RediSearchQueryBuilder.buildQuery(tableHandle.getConstraint(), tableHandle.getWildcards());
227+
String query = queryBuilder.buildQuery(tableHandle.getConstraint(), tableHandle.getWildcards());
226228
Builder<String, String> options = SearchOptions.builder();
227229
options.limit(Limit.offset(0).num(limit(tableHandle)));
228230
options.returnFields(columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new));
@@ -231,9 +233,9 @@ public Search search(RediSearchTableHandle tableHandle, List<RediSearchColumnHan
231233

232234
public Aggregation aggregate(RediSearchTableHandle table) {
233235
String index = index(table);
234-
String query = RediSearchQueryBuilder.buildQuery(table.getConstraint(), table.getWildcards());
236+
String query = queryBuilder.buildQuery(table.getConstraint(), table.getWildcards());
235237
AggregateOptions.Builder<String, String> builder = AggregateOptions.builder();
236-
RediSearchQueryBuilder.group(table).ifPresent(builder::operation);
238+
queryBuilder.group(table).ifPresent(builder::operation);
237239
builder.operation(Limit.offset(0).num(limit(table)));
238240
AggregateOptions<String, String> options = builder.build();
239241
CursorOptions.Builder cursorOptions = CursorOptions.builder();

src/test/java/com/redis/trino/TestRediSearchQueryBuilder.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testBuildQuery() {
3333
ImmutableMap.of(COL1, Domain.create(ValueSet.ofRanges(range(BIGINT, 100L, false, 200L, true)), false),
3434
COL2, Domain.singleValue(createUnboundedVarcharType(), utf8Slice("a value"))));
3535

36-
String query = RediSearchQueryBuilder.buildQuery(tupleDomain);
36+
String query = new RediSearchQueryBuilder().buildQuery(tupleDomain);
3737
String expected = "((@col1:[(100.0 inf] @col1:[-inf 200.0]) @col2:{a\\ value})";
3838
assertEquals(query, expected);
3939
}
@@ -43,7 +43,7 @@ public void testBuildQueryIn() {
4343
TupleDomain<ColumnHandle> tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of(COL2,
4444
Domain.create(ValueSet.ofRanges(equal(createUnboundedVarcharType(), utf8Slice("hello")),
4545
equal(createUnboundedVarcharType(), utf8Slice("world"))), false)));
46-
String query = RediSearchQueryBuilder.buildQuery(tupleDomain);
46+
String query = new RediSearchQueryBuilder().buildQuery(tupleDomain);
4747
String expected = "(@col2:{world}|@col2:{hello})";
4848
assertEquals(query, expected);
4949

@@ -54,7 +54,7 @@ public void testBuildQueryOr() {
5454
TupleDomain<ColumnHandle> tupleDomain = TupleDomain.withColumnDomains(ImmutableMap.of(COL1,
5555
Domain.create(ValueSet.ofRanges(lessThan(BIGINT, 100L), greaterThan(BIGINT, 200L)), false)));
5656

57-
String query = RediSearchQueryBuilder.buildQuery(tupleDomain);
57+
String query = new RediSearchQueryBuilder().buildQuery(tupleDomain);
5858
String expected = "(@col1:[-inf (100.0]|@col1:[(200.0 inf])";
5959
assertEquals(query, expected);
6060
}
@@ -64,7 +64,7 @@ public void testBuildQueryNull() {
6464
TupleDomain<ColumnHandle> tupleDomain = TupleDomain.withColumnDomains(
6565
ImmutableMap.of(COL1, Domain.create(ValueSet.ofRanges(greaterThan(BIGINT, 200L)), true)));
6666

67-
String query = RediSearchQueryBuilder.buildQuery(tupleDomain);
67+
String query = new RediSearchQueryBuilder().buildQuery(tupleDomain);
6868
String expected = "@col1:[(200.0 inf]";
6969
assertEquals(query, expected);
7070
}
@@ -77,7 +77,7 @@ public void testBuildQueryInDouble() {
7777
equal(DoubleType.DOUBLE, 3.0));
7878
TupleDomain<ColumnHandle> tupleDomain = TupleDomain
7979
.withColumnDomains(ImmutableMap.of(orderkey, Domain.create(values, false)));
80-
String query = RediSearchQueryBuilder.buildQuery(tupleDomain);
80+
String query = new RediSearchQueryBuilder().buildQuery(tupleDomain);
8181
String expected = "(@orderkey:[1.0 1.0]|@orderkey:[2.0 2.0]|@orderkey:[3.0 3.0])";
8282
assertEquals(query, expected);
8383
}

0 commit comments

Comments
 (0)