2525
2626import static com .google .common .base .Preconditions .checkState ;
2727import static com .google .common .base .Verify .verify ;
28+ import static com .google .common .base .Verify .verifyNotNull ;
29+ import static com .google .common .collect .ImmutableSet .toImmutableSet ;
30+ import static io .airlift .slice .SliceUtf8 .getCodePointAt ;
31+ import static io .trino .spi .StandardErrorCode .INVALID_FUNCTION_ARGUMENT ;
32+ import static io .trino .spi .expression .StandardFunctions .LIKE_FUNCTION_NAME ;
2833import static java .util .Objects .requireNonNull ;
2934
35+ import java .util .ArrayList ;
3036import java .util .Collection ;
37+ import java .util .HashMap ;
3138import java .util .List ;
3239import java .util .Map ;
3340import java .util .Optional ;
3441import java .util .OptionalLong ;
42+ import java .util .Set ;
3543import java .util .concurrent .atomic .AtomicReference ;
44+ import java .util .stream .IntStream ;
3645
3746import com .google .common .collect .ImmutableList ;
3847import com .google .common .collect .ImmutableMap ;
48+ import com .redis .lettucemod .search .Field ;
49+ import com .redis .lettucemod .search .querybuilder .Values ;
3950import com .redis .trino .RediSearchTableHandle .Type ;
4051
4152import io .airlift .log .Logger ;
4253import io .airlift .slice .Slice ;
54+ import io .trino .plugin .base .expression .ConnectorExpressions ;
4355import io .trino .spi .StandardErrorCode ;
4456import io .trino .spi .TrinoException ;
4557import io .trino .spi .connector .AggregateFunction ;
6476import io .trino .spi .connector .SchemaTableName ;
6577import io .trino .spi .connector .SchemaTablePrefix ;
6678import io .trino .spi .connector .TableNotFoundException ;
79+ import io .trino .spi .expression .Call ;
6780import io .trino .spi .expression .ConnectorExpression ;
81+ import io .trino .spi .expression .Constant ;
6882import io .trino .spi .expression .Variable ;
83+ import io .trino .spi .predicate .Domain ;
6984import io .trino .spi .predicate .TupleDomain ;
7085import io .trino .spi .statistics .ComputedStatistics ;
7186
@@ -74,6 +89,9 @@ public class RediSearchMetadata implements ConnectorMetadata {
7489 private static final Logger log = Logger .get (RediSearchMetadata .class );
7590
7691 private static final String SYNTHETIC_COLUMN_NAME_PREFIX = "syntheticColumn" ;
92+ private static final Set <Integer > REDISEARCH_RESERVED_CHARACTERS = IntStream
93+ .of ('?' , '*' , '|' , '{' , '}' , '[' , ']' , '(' , ')' , '"' , '#' , '@' , '&' , '<' , '>' , '~' ).boxed ()
94+ .collect (toImmutableSet ());
7795
7896 private final RediSearchSession rediSearchSession ;
7997 private final String schemaName ;
@@ -245,27 +263,149 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
245263 return Optional .empty ();
246264 }
247265
248- return Optional .of (new LimitApplicationResult <>(
249- new RediSearchTableHandle (handle .getType (), handle .getSchemaTableName (), handle .getConstraint (),
250- OptionalLong .of (limit ), handle .getTermAggregations (), handle .getMetricAggregations ()),
251- true , false ));
266+ return Optional .of (new LimitApplicationResult <>(new RediSearchTableHandle (handle .getType (),
267+ handle .getSchemaTableName (), handle .getConstraint (), OptionalLong .of (limit ),
268+ handle .getTermAggregations (), handle .getMetricAggregations (), handle .getWildcards ()), true , false ));
252269 }
253270
254271 @ Override
255272 public Optional <ConstraintApplicationResult <ConnectorTableHandle >> applyFilter (ConnectorSession session ,
256273 ConnectorTableHandle table , Constraint constraint ) {
257274 RediSearchTableHandle handle = (RediSearchTableHandle ) table ;
258275
276+ Map <ColumnHandle , Domain > supported = new HashMap <>();
277+ Map <ColumnHandle , Domain > unsupported = new HashMap <>();
278+ Map <ColumnHandle , Domain > domains = constraint .getSummary ().getDomains ()
279+ .orElseThrow (() -> new IllegalArgumentException ("constraint summary is NONE" ));
280+ for (Map .Entry <ColumnHandle , Domain > entry : domains .entrySet ()) {
281+ RediSearchColumnHandle column = (RediSearchColumnHandle ) entry .getKey ();
282+
283+ if (column .isSupportsPredicates ()) {
284+ supported .put (column , entry .getValue ());
285+ } else {
286+ unsupported .put (column , entry .getValue ());
287+ }
288+ }
289+
259290 TupleDomain <ColumnHandle > oldDomain = handle .getConstraint ();
260- TupleDomain <ColumnHandle > newDomain = oldDomain .intersect (constraint .getSummary ());
261- if (oldDomain .equals (newDomain )) {
291+ TupleDomain <ColumnHandle > newDomain = oldDomain .intersect (TupleDomain .withColumnDomains (supported ));
292+
293+ ConnectorExpression oldExpression = constraint .getExpression ();
294+ Map <String , String > newWildcards = new HashMap <>(handle .getWildcards ());
295+ List <ConnectorExpression > expressions = ConnectorExpressions .extractConjuncts (constraint .getExpression ());
296+ List <ConnectorExpression > notHandledExpressions = new ArrayList <>();
297+ for (ConnectorExpression expression : expressions ) {
298+ if (expression instanceof Call call && isSupportedLikeCall (call )) {
299+ List <ConnectorExpression > arguments = call .getArguments ();
300+ String variableName = ((Variable ) arguments .get (0 )).getName ();
301+ RediSearchColumnHandle column = (RediSearchColumnHandle ) constraint .getAssignments ().get (variableName );
302+ verifyNotNull (column , "No assignment for %s" , variableName );
303+ String columnName = column .getName ();
304+ Object pattern = ((Constant ) arguments .get (1 )).getValue ();
305+ Optional <Slice > escape = Optional .empty ();
306+ if (arguments .size () == 3 ) {
307+ escape = Optional .of ((Slice ) (((Constant ) arguments .get (2 )).getValue ()));
308+ }
309+
310+ if (!newWildcards .containsKey (columnName ) && pattern instanceof Slice slice ) {
311+ String wildcard = likeToWildcard (slice , escape );
312+ if (column .getFieldType () == Field .Type .TAG ) {
313+ wildcard = Values .tags (wildcard ).toString ();
314+ }
315+ newWildcards .put (columnName , wildcard );
316+ continue ;
317+ }
318+ }
319+ notHandledExpressions .add (expression );
320+ }
321+
322+ ConnectorExpression newExpression = ConnectorExpressions .and (notHandledExpressions );
323+ if (oldDomain .equals (newDomain ) && oldExpression .equals (newExpression )) {
262324 return Optional .empty ();
263325 }
264326
265327 handle = new RediSearchTableHandle (handle .getType (), handle .getSchemaTableName (), newDomain , handle .getLimit (),
266- handle .getTermAggregations (), handle .getMetricAggregations ());
328+ handle .getTermAggregations (), handle .getMetricAggregations (), newWildcards );
329+
330+ return Optional .of (new ConstraintApplicationResult <>(handle , TupleDomain .withColumnDomains (unsupported ),
331+ newExpression , false ));
332+
333+ }
334+
335+ protected static boolean isSupportedLikeCall (Call call ) {
336+ if (!LIKE_FUNCTION_NAME .equals (call .getFunctionName ())) {
337+ return false ;
338+ }
339+
340+ List <ConnectorExpression > arguments = call .getArguments ();
341+ if (arguments .size () < 2 || arguments .size () > 3 ) {
342+ return false ;
343+ }
344+
345+ if (!(arguments .get (0 ) instanceof Variable ) || !(arguments .get (1 ) instanceof Constant )) {
346+ return false ;
347+ }
348+
349+ if (arguments .size () == 3 ) {
350+ return arguments .get (2 ) instanceof Constant ;
351+ }
352+
353+ return true ;
354+ }
355+
356+ private static char getEscapeChar (Slice escape ) {
357+ String escapeString = escape .toStringUtf8 ();
358+ if (escapeString .length () == 1 ) {
359+ return escapeString .charAt (0 );
360+ }
361+ throw new TrinoException (INVALID_FUNCTION_ARGUMENT , "Escape string must be a single character" );
362+ }
363+
364+ protected static String likeToWildcard (Slice pattern , Optional <Slice > escape ) {
365+ Optional <Character > escapeChar = escape .map (RediSearchMetadata ::getEscapeChar );
366+ StringBuilder wildcard = new StringBuilder ();
367+ boolean escaped = false ;
368+ int position = 0 ;
369+ while (position < pattern .length ()) {
370+ int currentChar = getCodePointAt (pattern , position );
371+ position += 1 ;
372+ checkEscape (!escaped || currentChar == '%' || currentChar == '_' || currentChar == escapeChar .get ());
373+ if (!escaped && escapeChar .isPresent () && currentChar == escapeChar .get ()) {
374+ escaped = true ;
375+ } else {
376+ switch (currentChar ) {
377+ case '%' :
378+ wildcard .append (escaped ? "%" : "*" );
379+ escaped = false ;
380+ break ;
381+ case '_' :
382+ wildcard .append (escaped ? "_" : "?" );
383+ escaped = false ;
384+ break ;
385+ case '\\' :
386+ wildcard .append ("\\ \\ " );
387+ break ;
388+ default :
389+ // escape special RediSearch characters
390+ if (REDISEARCH_RESERVED_CHARACTERS .contains (currentChar )) {
391+ wildcard .append ('\\' );
392+ }
393+
394+ wildcard .appendCodePoint (currentChar );
395+ escaped = false ;
396+ }
397+ }
398+ }
267399
268- return Optional .of (new ConstraintApplicationResult <>(handle , constraint .getSummary (), false ));
400+ checkEscape (!escaped );
401+ return wildcard .toString ();
402+ }
403+
404+ private static void checkEscape (boolean condition ) {
405+ if (!condition ) {
406+ throw new TrinoException (INVALID_FUNCTION_ARGUMENT ,
407+ "Escape character must be followed by '%', '_' or the escape character itself" );
408+ }
269409 }
270410
271411 @ Override
@@ -291,7 +431,9 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
291431 if (metricAggregation .isEmpty ()) {
292432 return Optional .empty ();
293433 }
294- RediSearchColumnHandle newColumn = new RediSearchColumnHandle (colName , function .getOutputType (), false );
434+ io .trino .spi .type .Type outputType = function .getOutputType ();
435+ RediSearchColumnHandle newColumn = new RediSearchColumnHandle (colName , outputType ,
436+ RediSearchSession .toFieldType (outputType ), false , true );
295437 projections .add (new Variable (colName , function .getOutputType ()));
296438 resultAssignments .add (new Assignment (colName , newColumn , function .getOutputType ()));
297439 metricAggregations .add (metricAggregation .get ());
@@ -308,7 +450,7 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
308450 return Optional .empty ();
309451 }
310452 RediSearchTableHandle tableHandle = new RediSearchTableHandle (Type .AGGREGATE , table .getSchemaTableName (),
311- table .getConstraint (), table .getLimit (), termAggregations .build (), metrics );
453+ table .getConstraint (), table .getLimit (), termAggregations .build (), metrics , table . getWildcards () );
312454 return Optional .of (new AggregationApplicationResult <>(tableHandle , projections .build (),
313455 resultAssignments .build (), Map .of (), false ));
314456 }
@@ -325,7 +467,7 @@ public void rollback() {
325467 Optional .ofNullable (rollbackAction .getAndSet (null )).ifPresent (Runnable ::run );
326468 }
327469
328- private static SchemaTableName getTableName (ConnectorTableHandle tableHandle ) {
470+ private SchemaTableName getTableName (ConnectorTableHandle tableHandle ) {
329471 return ((RediSearchTableHandle ) tableHandle ).getSchemaTableName ();
330472 }
331473
@@ -338,8 +480,8 @@ private ConnectorTableMetadata getTableMetadata(ConnectorSession session, Schema
338480 return new ConnectorTableMetadata (tableName , columns );
339481 }
340482
341- private static List <RediSearchColumnHandle > buildColumnHandles (ConnectorTableMetadata tableMetadata ) {
342- return tableMetadata .getColumns ().stream ()
343- . map ( m -> new RediSearchColumnHandle ( m . getName (), m . getType (), m .isHidden ())).toList ();
483+ private List <RediSearchColumnHandle > buildColumnHandles (ConnectorTableMetadata tableMetadata ) {
484+ return tableMetadata .getColumns ().stream (). map ( m -> new RediSearchColumnHandle ( m . getName (), m . getType (),
485+ RediSearchSession . toFieldType ( m . getType ()) , m .isHidden (), true )).toList ();
344486 }
345487}
0 commit comments