Skip to content

Commit 498184e

Browse files
author
Julien Ruaux
committed
feat: Using FT.AGGREGATE for all queries. Resolves #17
1 parent ae873bf commit 498184e

11 files changed

+147
-290
lines changed

src/main/java/com/redis/trino/RediSearchBuiltinField.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package com.redis.trino;
22

33
import static com.google.common.collect.ImmutableMap.toImmutableMap;
4-
import static io.trino.spi.type.RealType.REAL;
54
import static io.trino.spi.type.VarcharType.VARCHAR;
65
import static java.util.Arrays.stream;
76
import static java.util.function.Function.identity;
@@ -16,7 +15,7 @@
1615

1716
enum RediSearchBuiltinField {
1817

19-
ID("_id", VARCHAR, Field.Type.TAG), SCORE("_score", REAL, Field.Type.NUMERIC);
18+
KEY("__key", VARCHAR, Field.Type.TAG);
2019

2120
private static final Map<String, RediSearchBuiltinField> COLUMNS_BY_NAME = stream(values())
2221
.collect(toImmutableMap(RediSearchBuiltinField::getName, identity()));
@@ -58,4 +57,8 @@ public ColumnMetadata getMetadata() {
5857
public RediSearchColumnHandle getColumnHandle() {
5958
return new RediSearchColumnHandle(name, type, fieldType, true, false);
6059
}
60+
61+
public static boolean isKeyColumn(String columnName) {
62+
return KEY.name.equals(columnName);
63+
}
6164
}

src/main/java/com/redis/trino/RediSearchMetadata.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
import com.google.common.collect.ImmutableMap;
5050
import com.redis.lettucemod.search.Field;
5151
import com.redis.lettucemod.search.querybuilder.Values;
52-
import com.redis.trino.RediSearchTableHandle.Type;
5352

5453
import io.airlift.log.Logger;
5554
import io.airlift.slice.Slice;
@@ -250,7 +249,7 @@ public Optional<ConnectorOutputMetadata> finishInsert(ConnectorSession session,
250249
@Override
251250
public RediSearchColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session,
252251
ConnectorTableHandle tableHandle) {
253-
return RediSearchBuiltinField.ID.getColumnHandle();
252+
return RediSearchBuiltinField.KEY.getColumnHandle();
254253
}
255254

256255
@Override
@@ -268,16 +267,16 @@ public void finishDelete(ConnectorSession session, ConnectorTableHandle tableHan
268267
@Override
269268
public RediSearchColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle,
270269
List<ColumnHandle> updatedColumns) {
271-
return RediSearchBuiltinField.ID.getColumnHandle();
270+
return RediSearchBuiltinField.KEY.getColumnHandle();
272271
}
273272

274273
@Override
275274
public RediSearchTableHandle beginUpdate(ConnectorSession session, ConnectorTableHandle tableHandle,
276275
List<ColumnHandle> updatedColumns, RetryMode retryMode) {
277276
checkRetry(retryMode);
278277
RediSearchTableHandle table = (RediSearchTableHandle) tableHandle;
279-
return new RediSearchTableHandle(table.getType(), table.getSchemaTableName(), table.getConstraint(),
280-
table.getLimit(), table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(),
278+
return new RediSearchTableHandle(table.getSchemaTableName(), table.getConstraint(), table.getLimit(),
279+
table.getTermAggregations(), table.getMetricAggregations(), table.getWildcards(),
281280
updatedColumns.stream().map(RediSearchColumnHandle.class::cast).collect(toImmutableList()));
282281
}
283282

@@ -306,11 +305,9 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
306305
return Optional.empty();
307306
}
308307

309-
return Optional.of(new LimitApplicationResult<>(
310-
new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(),
311-
OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations(),
312-
handle.getWildcards(), handle.getUpdatedColumns()),
313-
true, false));
308+
return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getSchemaTableName(),
309+
handle.getConstraint(), OptionalLong.of(limit), handle.getTermAggregations(),
310+
handle.getMetricAggregations(), handle.getWildcards(), handle.getUpdatedColumns()), true, false));
314311
}
315312

316313
@Override
@@ -372,7 +369,7 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
372369
return Optional.empty();
373370
}
374371

375-
handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(),
372+
handle = new RediSearchTableHandle(handle.getSchemaTableName(), newDomain, handle.getLimit(),
376373
handle.getTermAggregations(), handle.getMetricAggregations(), newWildcards, handle.getUpdatedColumns());
377374

378375
return Optional.of(new ConstraintApplicationResult<>(handle, TupleDomain.withColumnDomains(unsupported),
@@ -498,9 +495,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
498495
if (aggregationList.isEmpty()) {
499496
return Optional.empty();
500497
}
501-
RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(),
502-
table.getConstraint(), table.getLimit(), terms.build(), aggregationList, table.getWildcards(),
503-
table.getUpdatedColumns());
498+
RediSearchTableHandle tableHandle = new RediSearchTableHandle(table.getSchemaTableName(), table.getConstraint(),
499+
table.getLimit(), terms.build(), aggregationList, table.getWildcards(), table.getUpdatedColumns());
504500
return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(),
505501
resultAssignments.build(), Map.of(), false));
506502
}

src/main/java/com/redis/trino/RediSearchPageSink.java

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,26 +100,29 @@ public CompletableFuture<?> appendPage(Page page) {
100100
String prefix = prefix().orElse(schemaTableName.getTableName() + KEY_SEPARATOR);
101101
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
102102
connection.setAutoFlushCommands(false);
103-
RedisModulesAsyncCommands<String, String> commands = connection.async();
104-
List<RedisFuture<?>> futures = new ArrayList<>();
105-
for (int position = 0; position < page.getPositionCount(); position++) {
106-
String key = prefix + factory.create().toString();
107-
Map<String, String> map = new HashMap<>();
108-
for (int channel = 0; channel < page.getChannelCount(); channel++) {
109-
RediSearchColumnHandle column = columns.get(channel);
110-
Block block = page.getBlock(channel);
111-
if (block.isNull(position)) {
112-
continue;
103+
try {
104+
RedisModulesAsyncCommands<String, String> commands = connection.async();
105+
List<RedisFuture<?>> futures = new ArrayList<>();
106+
for (int position = 0; position < page.getPositionCount(); position++) {
107+
String key = prefix + factory.create().toString();
108+
Map<String, String> map = new HashMap<>();
109+
for (int channel = 0; channel < page.getChannelCount(); channel++) {
110+
RediSearchColumnHandle column = columns.get(channel);
111+
Block block = page.getBlock(channel);
112+
if (block.isNull(position)) {
113+
continue;
114+
}
115+
String value = value(column.getType(), block, position);
116+
map.put(column.getName(), value);
113117
}
114-
String value = value(column.getType(), block, position);
115-
map.put(column.getName(), value);
118+
RedisFuture<Long> future = commands.hset(key, map);
119+
futures.add(future);
116120
}
117-
RedisFuture<Long> future = commands.hset(key, map);
118-
futures.add(future);
121+
connection.flushCommands();
122+
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
123+
} finally {
124+
connection.setAutoFlushCommands(true);
119125
}
120-
connection.flushCommands();
121-
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
122-
connection.setAutoFlushCommands(true);
123126
return NOT_BLOCKED;
124127
}
125128

src/main/java/com/redis/trino/RediSearchPageSource.java

Lines changed: 93 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@
4141
import com.fasterxml.jackson.core.JsonGenerator;
4242
import com.redis.lettucemod.api.StatefulRedisModulesConnection;
4343
import com.redis.lettucemod.api.async.RedisModulesAsyncCommands;
44-
import com.redis.lettucemod.search.Document;
44+
import com.redis.lettucemod.search.AggregateWithCursorResults;
4545

46+
import io.airlift.log.Logger;
4647
import io.airlift.slice.Slice;
4748
import io.airlift.slice.SliceOutput;
4849
import io.lettuce.core.LettuceFutures;
@@ -57,28 +58,29 @@
5758

5859
public class RediSearchPageSource implements UpdatablePageSource {
5960

61+
private static final Logger log = Logger.get(RediSearchPageSource.class);
62+
6063
private static final int ROWS_PER_REQUEST = 1024;
6164

6265
private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter();
6366
private final RediSearchSession session;
6467
private final RediSearchTableHandle table;
65-
private final Iterator<Document<String, String>> cursor;
6668
private final String[] columnNames;
6769
private final List<Type> columnTypes;
68-
private final PageBuilder pageBuilder;
69-
70-
private Document<String, String> currentDoc;
70+
private final CursorIterator iterator;
71+
private Map<String, Object> currentDoc;
7172
private long count;
7273
private boolean finished;
7374

75+
private final PageBuilder pageBuilder;
76+
7477
public RediSearchPageSource(RediSearchSession session, RediSearchTableHandle table,
7578
List<RediSearchColumnHandle> columns) {
7679
this.session = session;
7780
this.table = table;
7881
this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).toArray(String[]::new);
79-
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType)
80-
.collect(Collectors.toList());
81-
this.cursor = session.search(table, columnNames).iterator();
82+
this.iterator = new CursorIterator(session, table, columnNames);
83+
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType).collect(Collectors.toList());
8284
this.currentDoc = null;
8385
this.pageBuilder = new PageBuilder(columnTypes);
8486
}
@@ -108,26 +110,24 @@ public Page getNextPage() {
108110
verify(pageBuilder.isEmpty());
109111
count = 0;
110112
for (int i = 0; i < ROWS_PER_REQUEST; i++) {
111-
if (!cursor.hasNext()) {
113+
if (!iterator.hasNext()) {
112114
finished = true;
113115
break;
114116
}
115-
currentDoc = cursor.next();
117+
currentDoc = iterator.next();
116118
count++;
117119

118120
pageBuilder.declarePosition();
119121
for (int column = 0; column < columnTypes.size(); column++) {
120122
BlockBuilder output = pageBuilder.getBlockBuilder(column);
121-
String columnName = columnNames[column];
122-
String value = currentValue(columnName);
123+
Object value = currentValue(columnNames[column]);
123124
if (value == null) {
124125
output.appendNull();
125126
} else {
126-
writer.appendTo(columnTypes.get(column), value, output);
127+
writer.appendTo(columnTypes.get(column), value.toString(), output);
127128
}
128129
}
129130
}
130-
131131
Page page = pageBuilder.build();
132132
pageBuilder.reset();
133133
return page;
@@ -149,42 +149,33 @@ public void updateRows(Page page, List<Integer> columnValueAndRowIdChannels) {
149149
columnValueAndRowIdChannels.size() - 1);
150150
StatefulRedisModulesConnection<String, String> connection = session.getConnection();
151151
connection.setAutoFlushCommands(false);
152-
RedisModulesAsyncCommands<String, String> commands = connection.async();
153-
List<RedisFuture<?>> futures = new ArrayList<>();
154-
for (int position = 0; position < page.getPositionCount(); position++) {
155-
Block rowIdBlock = page.getBlock(rowIdChannel);
156-
if (rowIdBlock.isNull(position)) {
157-
continue;
158-
}
159-
String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8();
160-
Map<String, String> map = new HashMap<>();
161-
for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
162-
RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel));
163-
Block block = page.getBlock(channel);
164-
if (block.isNull(position)) {
152+
try {
153+
RedisModulesAsyncCommands<String, String> commands = connection.async();
154+
List<RedisFuture<?>> futures = new ArrayList<>();
155+
for (int position = 0; position < page.getPositionCount(); position++) {
156+
Block rowIdBlock = page.getBlock(rowIdChannel);
157+
if (rowIdBlock.isNull(position)) {
165158
continue;
166159
}
167-
String value = RediSearchPageSink.value(column.getType(), block, position);
168-
map.put(column.getName(), value);
169-
}
170-
RedisFuture<Long> future = commands.hset(key, map);
171-
futures.add(future);
172-
}
173-
connection.flushCommands();
174-
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
175-
connection.setAutoFlushCommands(true);
176-
}
177-
178-
private String currentValue(String columnName) {
179-
if (RediSearchBuiltinField.isBuiltinColumn(columnName)) {
180-
if (RediSearchBuiltinField.ID.getName().equals(columnName)) {
181-
return currentDoc.getId();
182-
}
183-
if (RediSearchBuiltinField.SCORE.getName().equals(columnName)) {
184-
return String.valueOf(currentDoc.getScore());
160+
String key = VarcharType.VARCHAR.getSlice(rowIdBlock, position).toStringUtf8();
161+
Map<String, String> map = new HashMap<>();
162+
for (int channel = 0; channel < columnChannelMapping.size(); channel++) {
163+
RediSearchColumnHandle column = table.getUpdatedColumns().get(columnChannelMapping.get(channel));
164+
Block block = page.getBlock(channel);
165+
if (block.isNull(position)) {
166+
continue;
167+
}
168+
String value = RediSearchPageSink.value(column.getType(), block, position);
169+
map.put(column.getName(), value);
170+
}
171+
RedisFuture<Long> future = commands.hset(key, map);
172+
futures.add(future);
185173
}
174+
connection.flushCommands();
175+
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
176+
} finally {
177+
connection.setAutoFlushCommands(true);
186178
}
187-
return currentDoc.get(columnName);
188179
}
189180

190181
@Override
@@ -194,12 +185,67 @@ public CompletableFuture<Collection<Slice>> finish() {
194185
return future;
195186
}
196187

188+
private Object currentValue(String columnName) {
189+
if (RediSearchBuiltinField.isKeyColumn(columnName)) {
190+
return currentDoc.get(RediSearchBuiltinField.KEY.getName());
191+
}
192+
return currentDoc.get(columnName);
193+
}
194+
197195
public static JsonGenerator createJsonGenerator(JsonFactory factory, SliceOutput output) throws IOException {
198196
return factory.createGenerator((OutputStream) output);
199197
}
200198

201199
@Override
202200
public void close() {
203-
// nothing to do
201+
try {
202+
iterator.close();
203+
} catch (Exception e) {
204+
log.error(e, "Could not close cursor iterator");
205+
}
206+
}
207+
208+
private static class CursorIterator implements Iterator<Map<String, Object>>, AutoCloseable {
209+
210+
private final RediSearchSession session;
211+
private final RediSearchTableHandle table;
212+
private Iterator<Map<String, Object>> iterator;
213+
private long cursor;
214+
215+
public CursorIterator(RediSearchSession session, RediSearchTableHandle table, String[] columnNames) {
216+
this.session = session;
217+
this.table = table;
218+
read(session.aggregate(table, columnNames));
219+
}
220+
221+
private void read(AggregateWithCursorResults<String> results) {
222+
this.iterator = results.iterator();
223+
this.cursor = results.getCursor();
224+
}
225+
226+
@Override
227+
public boolean hasNext() {
228+
while (!iterator.hasNext()) {
229+
if (cursor == 0) {
230+
return false;
231+
}
232+
read(session.cursorRead(table, cursor));
233+
}
234+
return true;
235+
}
236+
237+
@Override
238+
public Map<String, Object> next() {
239+
return iterator.next();
240+
}
241+
242+
@Override
243+
public void close() throws Exception {
244+
if (cursor == 0) {
245+
return;
246+
}
247+
session.cursorDelete(table, cursor);
248+
}
249+
204250
}
205251
}

0 commit comments

Comments
 (0)