From 47f23bcb9cafe8fc0e893af7430d161c53b1fb13 Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Thu, 7 Aug 2025 16:58:23 -0700 Subject: [PATCH 1/8] start storing ids and scores --- src/main/knn/KnnGraphTester.java | 145 ++++++++++++++++++------------- src/main/knn/KnnTesterUtils.java | 8 +- src/python/knnPerfTest.py | 22 ++--- 3 files changed, 102 insertions(+), 73 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index ebdc67f2..7a2e977d 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -18,7 +18,11 @@ package knn; import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.OutputStream; +import java.io.Serializable; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; import java.nio.ByteBuffer; @@ -839,12 +843,12 @@ private void printHist(int[] hist, int max, int count, int nbuckets) { } @SuppressForbidden(reason = "Prints stuff") - private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, int[][] nn) + private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, ResultIds[][] nn) throws IOException { Result[] results = new Result[numQueryVectors]; - int[][] resultIds = new int[numQueryVectors][]; + ResultIds[][] resultIds = new ResultIds[numQueryVectors][]; long elapsedMS, totalCpuTimeMS, totalVisited = 0; - int topK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK; + int annTopK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK; int fanout = (overSample > 1) ? (int) (this.fanout * overSample) : this.fanout; ExecutorService executorService; if (numSearchThread > 0) { @@ -860,7 +864,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat if (targetReader instanceof VectorReaderByte b) { targetReaderByte = b; } - log("searching " + numQueryVectors + " query vectors; topK=" + topK + ", fanout=" + fanout + "\n"); + log("searching " + numQueryVectors + " query vectors; ann-topK=" + annTopK + ", fanout=" + fanout + "\n"); long startNS; try (MMapDirectory dir = new MMapDirectory(indexPath)) { dir.setPreload((x, ctx) -> x.endsWith(".vec") || x.endsWith(".veq")); @@ -874,10 +878,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { byte[] target = targetReaderByte.nextBytes(); - doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery); + doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery); } else { float[] target = targetReader.next(); - doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin); + doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin); } } targetReader.reset(); @@ -886,10 +890,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat for (int i = 0; i < numQueryVectors; i++) { if (vectorEncoding.equals(VectorEncoding.BYTE)) { byte[] target = targetReaderByte.nextBytes(); - results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery); + results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery); } else { float[] target = targetReader.next(); - results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin); + results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin); } } ThreadDetails endThreadDetails = new ThreadDetails(); @@ -910,7 +914,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat StoredFields storedFields = reader.storedFields(); for (int i = 0; i < numQueryVectors; i++) { totalVisited += results[i].visitedCount(); - resultIds[i] = KnnTesterUtils.getResultIds(results[i].topDocs(), storedFields); + resultIds[i] = KnnTesterUtils.getResultIds(results[i].topDocs(), storedFields, this.topK); } log( "completed " @@ -930,15 +934,17 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat executorService.shutdown(); } } + // TODO: do we need to write nn here again? Didnt we already read it from some file? if (outputPath != null) { - ByteBuffer tmp = - ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); - try (OutputStream out = Files.newOutputStream(outputPath)) { - for (int i = 0; i < numQueryVectors; i++) { - tmp.asIntBuffer().put(nn[i]); - out.write(tmp.array()); - } - } + writeExactNN(nn, outputPath); +// ByteBuffer tmp = +// ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); +// try (OutputStream out = Files.newOutputStream(outputPath)) { +// for (int i = 0; i < numQueryVectors; i++) { +// tmp.asIntBuffer().put(nn[i]); +// out.write(tmp.array()); +// } +// } } else { log("checking results\n"); float recall = checkResults(resultIds, nn); @@ -1019,29 +1025,32 @@ private static Result doKnnVectorQuery( record Result(TopDocs topDocs, long visitedCount, int reentryCount) { } - private float checkResults(int[][] results, int[][] nn) { + /** Holds ids and scores for corpus docs in search results */ + record ResultIds(int id, float score) implements Serializable {} + + private float checkResults(ResultIds[][] results, ResultIds[][] expected) { int totalMatches = 0; - int totalResults = results.length * topK; - for (int i = 0; i < results.length; i++) { + int totalResults = expected.length * topK; + for (int i = 0; i < expected.length; i++) { // System.out.println("compare " + Arrays.toString(nn[i]) + " to "); // System.out.println(Arrays.toString(results[i])); - totalMatches += compareNN(nn[i], results[i]); + totalMatches += compareNN(expected[i], results[i]); } return totalMatches / (float) totalResults; } - private int compareNN(int[] expected, int[] results) { + private int compareNN(ResultIds[] expected, ResultIds[] results) { int matched = 0; Set expectedSet = new HashSet<>(); Set alreadySeen = new HashSet<>(); for (int i = 0; i < topK; i++) { - expectedSet.add(expected[i]); + expectedSet.add(expected[i].id); } - for (int docId : results) { - if (alreadySeen.add(docId) == false) { - throw new IllegalStateException("duplicate docId=" + docId); + for (ResultIds r : results) { + if (alreadySeen.add(r.id) == false) { + throw new IllegalStateException("duplicate docId=" + r.id); } - if (expectedSet.contains(docId)) { + if (expectedSet.contains(r.id)) { ++matched; } } @@ -1053,7 +1062,7 @@ private int compareNN(int[] expected, int[] results) { * The method runs "numQueryVectors" target queries and returns "topK" nearest neighbors * for each of them. Nearest Neighbors are computed using exact match. */ - private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException { + private ResultIds[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException { // look in working directory for cached nn file String hash = Integer.toString(Objects.hash(docPath, indexPath, queryPath, numDocs, numQueryVectors, topK, similarityFunction.ordinal(), parentJoin, queryStartIndex, prefilter ? selectivity : 1f, prefilter ? randomSeed : 0f), 36); String nnFileName = "nn-" + hash + ".bin"; @@ -1066,7 +1075,8 @@ private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int que long startNS = System.nanoTime(); // TODO: enable computing NN from high precision vectors when // checking low-precision recall - int[][] nn; +// int[][] nn; + ResultIds[][] nn; if (vectorEncoding.equals(VectorEncoding.BYTE)) { nn = computeExactNNByte(queryPath, queryStartIndex); } else { @@ -1089,35 +1099,52 @@ private boolean isNewer(Path path, Path... others) throws IOException { return true; } - private int[][] readExactNN(Path nnPath) throws IOException { - int[][] result = new int[numQueryVectors][]; - try (FileChannel in = FileChannel.open(nnPath)) { - IntBuffer intBuffer = - in.map(FileChannel.MapMode.READ_ONLY, 0, numQueryVectors * topK * Integer.BYTES) - .order(ByteOrder.LITTLE_ENDIAN) - .asIntBuffer(); + private ResultIds[][] readExactNN(Path nnPath) throws IOException { + log("reading true nearest neighbors from file \"" + nnPath + "\"\n"); + ResultIds[][] nn = new ResultIds[numQueryVectors][]; + try (InputStream in = Files.newInputStream(nnPath); + ObjectInputStream ois = new ObjectInputStream(in)) { for (int i = 0; i < numQueryVectors; i++) { - result[i] = new int[topK]; - intBuffer.get(result[i]); + nn[i] = (ResultIds[]) ois.readObject(); } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); } - return result; +// +// try (FileChannel in = FileChannel.open(nnPath)) { +// IntBuffer intBuffer = +// in.map(FileChannel.MapMode.READ_ONLY, 0, numQueryVectors * topK * Integer.BYTES) +// .order(ByteOrder.LITTLE_ENDIAN) +// .asIntBuffer(); +// for (int i = 0; i < numQueryVectors; i++) { +// nn[i] = new int[topK]; +// intBuffer.get(nn[i]); +// } +// } + return nn; } - private void writeExactNN(int[][] nn, Path nnPath) throws IOException { - log("writing true nearest neighbors to cache file \"" + nnPath + "\"\n"); - ByteBuffer tmp = - ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); - try (OutputStream out = Files.newOutputStream(nnPath)) { + private void writeExactNN(ResultIds[][] nn, Path nnPath) throws IOException { + log("\nwriting true nearest neighbors to cache file \"" + nnPath + "\"\n"); + try (OutputStream fileOutputStream = Files.newOutputStream(nnPath); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream)) { for (int i = 0; i < numQueryVectors; i++) { - tmp.asIntBuffer().put(nn[i]); - out.write(tmp.array()); + objectOutputStream.writeObject(nn[i]); } } + +// ByteBuffer tmp = +// ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); +// try (OutputStream out = Files.newOutputStream(nnPath)) { +// for (int i = 0; i < numQueryVectors; i++) { +// tmp.asIntBuffer().put(nn[i]); +// out.write(tmp.array()); +// } +// } } - private int[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { - int[][] result = new int[numQueryVectors][]; + private ResultIds[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { + ResultIds[][] result = new ResultIds[numQueryVectors][]; log("computing true nearest neighbors of " + numQueryVectors + " target vectors\n"); List tasks = new ArrayList<>(); try (MMapDirectory dir = new MMapDirectory(indexPath)) { @@ -1143,10 +1170,10 @@ class ComputeNNByteTask implements Callable { private final int queryOrd; private final byte[] query; - private final int[][] result; + private final ResultIds[][] result; private final IndexReader reader; - ComputeNNByteTask(int queryOrd, byte[] query, int[][] result, IndexReader reader) { + ComputeNNByteTask(int queryOrd, byte[] query, ResultIds[][] result, IndexReader reader) { this.queryOrd = queryOrd; this.query = query; this.result = result; @@ -1164,7 +1191,7 @@ public Void call() { .add(filterQuery, BooleanClause.Occur.FILTER) .build(); var topDocs = searcher.search(query, topK); - result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields()); + result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields(), topK); if ((queryOrd + 1) % 10 == 0) { log(" " + (queryOrd + 1)); } @@ -1176,9 +1203,9 @@ public Void call() { } /** Brute force computation of "true" nearest neighhbors. */ - private int[][] computeExactNN(Path queryPath, int queryStartIndex) + private ResultIds[][] computeExactNN(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { - int[][] result = new int[numQueryVectors][]; + ResultIds[][] result = new ResultIds[numQueryVectors][]; log("computing true nearest neighbors of " + numQueryVectors + " target vectors\n"); log("parentJoin = %s\n", parentJoin); try (MMapDirectory dir = new MMapDirectory(indexPath)) { @@ -1216,10 +1243,10 @@ class ComputeNNFloatTask implements Callable { private final int queryOrd; private final float[] query; - private final int[][] result; + private final ResultIds[][] result; private final IndexReader reader; - ComputeNNFloatTask(int queryOrd, float[] query, int[][] result, IndexReader reader) { + ComputeNNFloatTask(int queryOrd, float[] query, ResultIds[][] result, IndexReader reader) { this.queryOrd = queryOrd; this.query = query; this.result = result; @@ -1238,7 +1265,7 @@ public Void call() { .add(filterQuery, BooleanClause.Occur.FILTER) .build(); var topDocs = searcher.search(query, topK); - result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields()); + result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields(), topK); if ((queryOrd + 1) % 10 == 0) { log(" " + (queryOrd + 1)); } @@ -1255,10 +1282,10 @@ class ComputeExactSearchNNFloatTask implements Callable { private final int queryOrd; private final float[] query; - private final int[][] result; + private final ResultIds[][] result; private final IndexReader reader; - ComputeExactSearchNNFloatTask(int queryOrd, float[] query, int[][] result, IndexReader reader) { + ComputeExactSearchNNFloatTask(int queryOrd, float[] query, ResultIds[][] result, IndexReader reader) { this.queryOrd = queryOrd; this.query = query; this.result = result; @@ -1272,7 +1299,7 @@ public Void call() { ParentJoinBenchmarkQuery parentJoinQuery = new ParentJoinBenchmarkQuery(query, null, topK); TopDocs topHits = ParentJoinBenchmarkQuery.runExactSearch(reader, parentJoinQuery); StoredFields storedFields = reader.storedFields(); - result[queryOrd] = KnnTesterUtils.getResultIds(topHits, storedFields); + result[queryOrd] = KnnTesterUtils.getResultIds(topHits, storedFields, topK); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/main/knn/KnnTesterUtils.java b/src/main/knn/KnnTesterUtils.java index 15f6159b..ad6ff6b4 100644 --- a/src/main/knn/KnnTesterUtils.java +++ b/src/main/knn/KnnTesterUtils.java @@ -23,14 +23,16 @@ import java.io.IOException; +import static knn.KnnGraphTester.ID_FIELD; +import static knn.KnnGraphTester.ResultIds; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; public class KnnTesterUtils { /** Fetches values for the "id" field from search results */ - public static int[] getResultIds(TopDocs topDocs, StoredFields storedFields) throws IOException { - int[] resultIds = new int[topDocs.scoreDocs.length]; + public static ResultIds[] getResultIds(TopDocs topDocs, StoredFields storedFields, int k) throws IOException { + ResultIds[] resultIds = new ResultIds[k]; int i = 0; // TODO: switch to doc values for this id field? more efficent than stored fields // TODO: or, at least load the stored documents in index (Lucene docid) order to @@ -39,7 +41,7 @@ public static int[] getResultIds(TopDocs topDocs, StoredFields storedFields) thr // queries have run) for (ScoreDoc doc : topDocs.scoreDocs) { assert doc.doc != NO_MORE_DOCS: "illegal docid " + doc.doc + " returned from KNN search?"; - resultIds[i++] = Integer.parseInt(storedFields.document(doc.doc).get(KnnGraphTester.ID_FIELD)); + resultIds[i++] = new ResultIds(Integer.parseInt(storedFields.document(doc.doc).get(ID_FIELD)), doc.score); } return resultIds; } diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index 374f5f5a..4f574e2e 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -52,16 +52,16 @@ #'ndoc': (10000, 100000, 200000, 500000), #'ndoc': (2_000_000,), #'ndoc': (1_000_000,), - "ndoc": (500_000,), - #'ndoc': (50_000,), - "maxConn": (32, 64, 96), +# "ndoc": (500_000,), + 'ndoc': (10_000,), +# "maxConn": (32, 64, 96), # "maxConn": (64,), - #'maxConn': (32,), - "beamWidthIndex": (250, 500), - # "beamWidthIndex": (250,), - #'beamWidthIndex': (50,), - "fanout": (20, 50, 100, 250), - # "fanout": (50,), + 'maxConn': (32,), +# "beamWidthIndex": (250, 500), +# "beamWidthIndex": (250,), + 'beamWidthIndex': (50,), +# "fanout": (20, 50, 100, 250), + "fanout": (50,), #'quantize': None, #'quantizeBits': (32, 7, 4), "numMergeWorker": (12,), @@ -75,8 +75,8 @@ #'quantize': (True,), "quantizeBits": ( 4, - 7, - 32, +# 7, +# 32, ), # "quantizeBits": (1,), # "overSample": (5,), # extra ratio of vectors to retrieve, for testing approximate scoring, e.g. quantized indices From 5c6eb39cc0fd7557ee4c6a730fa9bef6d308a903 Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Thu, 7 Aug 2025 22:57:39 -0700 Subject: [PATCH 2/8] ndcg working --- src/main/knn/KnnGraphTester.java | 46 +++++++++++++++++++++++++------- src/main/knn/KnnTesterUtils.java | 14 ++++++++++ src/python/knnPerfTest.py | 8 +++--- 3 files changed, 56 insertions(+), 12 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index 7a2e977d..49cbd126 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -23,11 +23,6 @@ import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.Serializable; -import java.lang.management.ManagementFactory; -import java.lang.management.ThreadMXBean; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.IntBuffer; import java.nio.channels.FileChannel; import java.nio.file.Files; import java.nio.file.Path; @@ -37,6 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Deque; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -97,7 +93,6 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.search.join.CheckJoinIndex; import org.apache.lucene.store.Directory; @@ -947,7 +942,9 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat // } } else { log("checking results\n"); - float recall = checkResults(resultIds, nn); + float recall = checkRecall(resultIds, nn); + double ndcg10 = calculateNDCG(nn, resultIds, 10); + double ndcgK = calculateNDCG(nn, resultIds, topK); totalVisited /= numQueryVectors; String quantizeDesc; if (quantize) { @@ -958,8 +955,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat double reindexSec = reindexTimeMsec / 1000.0; System.out.printf( Locale.ROOT, - "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", + "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", recall, + ndcg10, + ndcgK, elapsedMS / (float) numQueryVectors, totalCpuTimeMS / (float) numQueryVectors, totalCpuTimeMS / (float) elapsedMS, @@ -1028,7 +1027,7 @@ record Result(TopDocs topDocs, long visitedCount, int reentryCount) { /** Holds ids and scores for corpus docs in search results */ record ResultIds(int id, float score) implements Serializable {} - private float checkResults(ResultIds[][] results, ResultIds[][] expected) { + private float checkRecall(ResultIds[][] results, ResultIds[][] expected) { int totalMatches = 0; int totalResults = expected.length * topK; for (int i = 0; i < expected.length; i++) { @@ -1039,6 +1038,35 @@ private float checkResults(ResultIds[][] results, ResultIds[][] expected) { return totalMatches / (float) totalResults; } + /** + * Calculates Normalized Discounted Cumulative Gain (NDCG) at K. + * + *

We use full precision vector similarity scores for relevance. Since actual + * knn search result may hold quantized scores, we use scores for the corresponding + * document "id" from {@code ideal} search results. If a document is not present + * in ideal, it is considered irrelevant, and we assign it a score of 0f. + */ + private double calculateNDCG(ResultIds[][] ideal, ResultIds[][] actual, int k) { + double ndcg = 0; + for (int i = 0; i < ideal.length; i++) { + float[] exactResultsRelevance = new float[ideal[i].length]; + HashMap idToRelevance = new HashMap(ideal[i].length); + for (int rank = 0; rank < ideal[i].length; rank++) { + exactResultsRelevance[rank] = ideal[i][rank].score(); + idToRelevance.put(ideal[i][rank].id(), ideal[i][rank].score()); + } + float[] actualResultsRelevance = new float[actual[i].length]; + for (int rank = 0; rank < actual[i].length; rank++) { + actualResultsRelevance[rank] = idToRelevance.getOrDefault(actual[i][rank].id(), 0f); + } + double idealDCG = KnnTesterUtils.dcg(exactResultsRelevance, k); + double actualDCG = KnnTesterUtils.dcg(actualResultsRelevance, k); + ndcg += (actualDCG / idealDCG); + } + ndcg /= ideal.length; + return ndcg; + } + private int compareNN(ResultIds[] expected, ResultIds[] results) { int matched = 0; Set expectedSet = new HashSet<>(); diff --git a/src/main/knn/KnnTesterUtils.java b/src/main/knn/KnnTesterUtils.java index ad6ff6b4..1934b40e 100644 --- a/src/main/knn/KnnTesterUtils.java +++ b/src/main/knn/KnnTesterUtils.java @@ -45,4 +45,18 @@ public static ResultIds[] getResultIds(TopDocs topDocs, StoredFields storedField } return resultIds; } + + /** + * Calculates Discounted Cumulative Gain @k + * @param relevance Relevance scores sorted by rank of search results. + * @param k DCG is calculated up to this rank + */ + public static double dcg(float[] relevance, int k) { + double dcg = 0; + k = Math.min(relevance.length, k); + for (int i = 0; i < k; i++) { + dcg += relevance[i] / (Math.log(2 + i) / Math.log(2)); // rank = (i+1) + } + return dcg; + } } diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index 4f574e2e..cc006995 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -53,7 +53,7 @@ #'ndoc': (2_000_000,), #'ndoc': (1_000_000,), # "ndoc": (500_000,), - 'ndoc': (10_000,), + 'ndoc': (10_000, 50_000, 100_000), # "maxConn": (32, 64, 96), # "maxConn": (64,), 'maxConn': (32,), @@ -75,8 +75,8 @@ #'quantize': (True,), "quantizeBits": ( 4, -# 7, -# 32, + 7, + 32, ), # "quantizeBits": (1,), # "overSample": (5,), # extra ratio of vectors to retrieve, for testing approximate scoring, e.g. quantized indices @@ -94,6 +94,8 @@ OUTPUT_HEADERS = [ "recall", + "ndcg@10", + "ndcg@K", "latency(ms)", "netCPU", "avgCpuCount", From a8296822c78b784f23a3f0cbe151d74f862b8d58 Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Thu, 7 Aug 2025 23:12:36 -0700 Subject: [PATCH 3/8] add rerank support --- src/main/knn/KnnGraphTester.java | 29 ++++++++++++++++++++--------- src/python/knnPerfTest.py | 8 ++++++-- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index 49cbd126..16acbbb2 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -83,6 +83,8 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DoubleValuesSourceRescorer; +import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -179,6 +181,8 @@ enum IndexType { private IndexType indexType; // oversampling, e.g. the multiple * k to gather before checking recall private float overSample; + // rerank using full precision vectors + private boolean rerank; private KnnGraphTester() { // set defaults @@ -202,6 +206,7 @@ private KnnGraphTester() { queryStartIndex = 0; indexType = IndexType.HNSW; overSample = 1f; + rerank = false; } private static FileChannel getVectorFileChannel(Path path, int dim, VectorEncoding vectorEncoding, boolean noisy) throws IOException { @@ -283,6 +288,9 @@ private void run(String... args) throws Exception { throw new IllegalArgumentException("-overSample must be >= 1"); } break; + case "-rerank": + rerank = true; + break; case "-fanout": if (iarg == args.length - 1) { throw new IllegalArgumentException("-fanout requires a following number"); @@ -932,14 +940,6 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat // TODO: do we need to write nn here again? Didnt we already read it from some file? if (outputPath != null) { writeExactNN(nn, outputPath); -// ByteBuffer tmp = -// ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); -// try (OutputStream out = Files.newOutputStream(outputPath)) { -// for (int i = 0; i < numQueryVectors; i++) { -// tmp.asIntBuffer().put(nn[i]); -// out.write(tmp.array()); -// } -// } } else { log("checking results\n"); float recall = checkRecall(resultIds, nn); @@ -1004,7 +1004,7 @@ private static Result doKnnByteVectorQuery( return new Result(docs, profiledQuery.totalVectorCount(), 0); } - private static Result doKnnVectorQuery( + private Result doKnnVectorQuery( IndexSearcher searcher, String field, float[] vector, int k, int fanout, boolean prefilter, Query filter, boolean isParentJoinQuery) throws IOException { if (isParentJoinQuery) { @@ -1018,6 +1018,17 @@ private static Result doKnnVectorQuery( .add(filter, BooleanClause.Occur.FILTER) .build(); TopDocs docs = searcher.search(query, k); + if (rerank) { + FullPrecisionFloatVectorSimilarityValuesSource valuesSource = new FullPrecisionFloatVectorSimilarityValuesSource(vector, field); + DoubleValuesSourceRescorer rescorer = new DoubleValuesSourceRescorer(valuesSource) { + @Override + protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) { + return valuePresent ? (float) sourceValue : firstPassScore; + } + }; + TopDocs rerankedDocs = rescorer.rescore(searcher, docs, topK); + return new Result(rerankedDocs, profiledQuery.totalVectorCount(), 0); + } return new Result(docs, profiledQuery.totalVectorCount(), 0); } diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index cc006995..a5590ec1 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -75,8 +75,8 @@ #'quantize': (True,), "quantizeBits": ( 4, - 7, - 32, +# 7, +# 32, ), # "quantizeBits": (1,), # "overSample": (5,), # extra ratio of vectors to retrieve, for testing approximate scoring, e.g. quantized indices @@ -89,6 +89,7 @@ "queryStartIndex": (0,), # seek to this start vector before searching, to sample different vectors # "forceMerge": (True, False), #'niter': (10,), + "rerank": (False, True), } @@ -282,6 +283,9 @@ def run_knn_benchmark(checkout, values): if "-indexType" in this_cmd and "flat" in this_cmd: skip_headers.add("maxConn") skip_headers.add("beamWidth") + if "-rerank" not in this_cmd: + skip_headers.add("ndcg@10") + skip_headers.add("ndcg@K") print_fixed_width(all_results, skip_headers) print_chart(all_results) From e69424ab973d751ef37f405802b6ad5cddf423ed Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Thu, 7 Aug 2025 23:13:50 -0700 Subject: [PATCH 4/8] remove commented code --- src/main/knn/KnnGraphTester.java | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index 16acbbb2..6a952422 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -1114,7 +1114,6 @@ private ResultIds[][] getExactNN(Path docPath, Path indexPath, Path queryPath, i long startNS = System.nanoTime(); // TODO: enable computing NN from high precision vectors when // checking low-precision recall -// int[][] nn; ResultIds[][] nn; if (vectorEncoding.equals(VectorEncoding.BYTE)) { nn = computeExactNNByte(queryPath, queryStartIndex); @@ -1149,17 +1148,6 @@ private ResultIds[][] readExactNN(Path nnPath) throws IOException { } catch (ClassNotFoundException e) { throw new RuntimeException(e); } -// -// try (FileChannel in = FileChannel.open(nnPath)) { -// IntBuffer intBuffer = -// in.map(FileChannel.MapMode.READ_ONLY, 0, numQueryVectors * topK * Integer.BYTES) -// .order(ByteOrder.LITTLE_ENDIAN) -// .asIntBuffer(); -// for (int i = 0; i < numQueryVectors; i++) { -// nn[i] = new int[topK]; -// intBuffer.get(nn[i]); -// } -// } return nn; } From 2543a6303efce613f19b194bec3fac276560aaaa Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Thu, 7 Aug 2025 23:17:54 -0700 Subject: [PATCH 5/8] run params --- src/python/knnPerfTest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index a5590ec1..cc13ef38 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -53,7 +53,7 @@ #'ndoc': (2_000_000,), #'ndoc': (1_000_000,), # "ndoc": (500_000,), - 'ndoc': (10_000, 50_000, 100_000), + 'ndoc': (10_000,), # "maxConn": (32, 64, 96), # "maxConn": (64,), 'maxConn': (32,), From 04905007c9a39afdc17f2764fb9cf8b665649d2f Mon Sep 17 00:00:00 2001 From: Vigya Sharma <> Date: Fri, 8 Aug 2025 07:04:02 +0000 Subject: [PATCH 6/8] print rerank flag in summary --- src/main/knn/KnnGraphTester.java | 12 ++---------- src/python/knnPerfTest.py | 7 +++++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index 6a952422..d2e83434 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -955,10 +955,11 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat double reindexSec = reindexTimeMsec / 1000.0; System.out.printf( Locale.ROOT, - "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", + "SUMMARY: %5.3f\t%5.3f\t%5.3f\t%s\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n", recall, ndcg10, ndcgK, + rerank, elapsedMS / (float) numQueryVectors, totalCpuTimeMS / (float) numQueryVectors, totalCpuTimeMS / (float) elapsedMS, @@ -1159,15 +1160,6 @@ private void writeExactNN(ResultIds[][] nn, Path nnPath) throws IOException { objectOutputStream.writeObject(nn[i]); } } - -// ByteBuffer tmp = -// ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); -// try (OutputStream out = Files.newOutputStream(nnPath)) { -// for (int i = 0; i < numQueryVectors; i++) { -// tmp.asIntBuffer().put(nn[i]); -// out.write(tmp.array()); -// } -// } } private ResultIds[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException { diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index cc13ef38..a4ea7068 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -53,7 +53,8 @@ #'ndoc': (2_000_000,), #'ndoc': (1_000_000,), # "ndoc": (500_000,), - 'ndoc': (10_000,), +# "ndoc": (10_000,), + 'ndoc': (100_000, 500_000, 1_000_000, 2_000_000, 10_000_000), # "maxConn": (32, 64, 96), # "maxConn": (64,), 'maxConn': (32,), @@ -61,7 +62,7 @@ # "beamWidthIndex": (250,), 'beamWidthIndex': (50,), # "fanout": (20, 50, 100, 250), - "fanout": (50,), + "fanout": (20,), #'quantize': None, #'quantizeBits': (32, 7, 4), "numMergeWorker": (12,), @@ -97,6 +98,7 @@ "recall", "ndcg@10", "ndcg@K", + "rerank", "latency(ms)", "netCPU", "avgCpuCount", @@ -284,6 +286,7 @@ def run_knn_benchmark(checkout, values): skip_headers.add("maxConn") skip_headers.add("beamWidth") if "-rerank" not in this_cmd: + skip_headers.add("rerank") skip_headers.add("ndcg@10") skip_headers.add("ndcg@K") From dd063318172e6abd2e2c2b30fdb54f228548fd6b Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Fri, 8 Aug 2025 10:52:46 -0700 Subject: [PATCH 7/8] pr ready --- src/main/knn/KnnGraphTester.java | 2 +- src/python/knnPerfTest.py | 25 ++++++++++++------------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index d2e83434..71bb99f9 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -937,7 +937,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat executorService.shutdown(); } } - // TODO: do we need to write nn here again? Didnt we already read it from some file? + // Do we need to write nn here again? We already wrote it in getExactNN() if (outputPath != null) { writeExactNN(nn, outputPath); } else { diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index a4ea7068..ca590625 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -52,17 +52,16 @@ #'ndoc': (10000, 100000, 200000, 500000), #'ndoc': (2_000_000,), #'ndoc': (1_000_000,), -# "ndoc": (500_000,), -# "ndoc": (10_000,), - 'ndoc': (100_000, 500_000, 1_000_000, 2_000_000, 10_000_000), -# "maxConn": (32, 64, 96), + "ndoc": (500_000,), + #'ndoc': (50_000,), + "maxConn": (32, 64, 96), # "maxConn": (64,), - 'maxConn': (32,), -# "beamWidthIndex": (250, 500), -# "beamWidthIndex": (250,), - 'beamWidthIndex': (50,), -# "fanout": (20, 50, 100, 250), - "fanout": (20,), + #'maxConn': (32,), + "beamWidthIndex": (250, 500), + # "beamWidthIndex": (250,), + #'beamWidthIndex': (50,), + "fanout": (20, 50, 100, 250), + # "fanout": (50,), #'quantize': None, #'quantizeBits': (32, 7, 4), "numMergeWorker": (12,), @@ -76,8 +75,8 @@ #'quantize': (True,), "quantizeBits": ( 4, -# 7, -# 32, + 7, + 32, ), # "quantizeBits": (1,), # "overSample": (5,), # extra ratio of vectors to retrieve, for testing approximate scoring, e.g. quantized indices @@ -90,7 +89,7 @@ "queryStartIndex": (0,), # seek to this start vector before searching, to sample different vectors # "forceMerge": (True, False), #'niter': (10,), - "rerank": (False, True), + # "rerank": (False, True), } From 79c763540055e074af255eb0002115a58ade1eda Mon Sep 17 00:00:00 2001 From: Vigya Sharma Date: Fri, 8 Aug 2025 20:19:00 -0700 Subject: [PATCH 8/8] fix reranking bug and always print ndcg --- src/main/knn/KnnGraphTester.java | 8 ++++---- src/main/knn/KnnTesterUtils.java | 4 ++-- src/python/knnPerfTest.py | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/main/knn/KnnGraphTester.java b/src/main/knn/KnnGraphTester.java index 71bb99f9..e0160f3a 100644 --- a/src/main/knn/KnnGraphTester.java +++ b/src/main/knn/KnnGraphTester.java @@ -917,7 +917,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat StoredFields storedFields = reader.storedFields(); for (int i = 0; i < numQueryVectors; i++) { totalVisited += results[i].visitedCount(); - resultIds[i] = KnnTesterUtils.getResultIds(results[i].topDocs(), storedFields, this.topK); + resultIds[i] = KnnTesterUtils.getResultIds(results[i].topDocs(), storedFields); } log( "completed " @@ -1210,7 +1210,7 @@ public Void call() { .add(filterQuery, BooleanClause.Occur.FILTER) .build(); var topDocs = searcher.search(query, topK); - result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields(), topK); + result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields()); if ((queryOrd + 1) % 10 == 0) { log(" " + (queryOrd + 1)); } @@ -1284,7 +1284,7 @@ public Void call() { .add(filterQuery, BooleanClause.Occur.FILTER) .build(); var topDocs = searcher.search(query, topK); - result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields(), topK); + result[queryOrd] = knn.KnnTesterUtils.getResultIds(topDocs, reader.storedFields()); if ((queryOrd + 1) % 10 == 0) { log(" " + (queryOrd + 1)); } @@ -1318,7 +1318,7 @@ public Void call() { ParentJoinBenchmarkQuery parentJoinQuery = new ParentJoinBenchmarkQuery(query, null, topK); TopDocs topHits = ParentJoinBenchmarkQuery.runExactSearch(reader, parentJoinQuery); StoredFields storedFields = reader.storedFields(); - result[queryOrd] = KnnTesterUtils.getResultIds(topHits, storedFields, topK); + result[queryOrd] = KnnTesterUtils.getResultIds(topHits, storedFields); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/main/knn/KnnTesterUtils.java b/src/main/knn/KnnTesterUtils.java index 1934b40e..c9093ba2 100644 --- a/src/main/knn/KnnTesterUtils.java +++ b/src/main/knn/KnnTesterUtils.java @@ -31,8 +31,8 @@ public class KnnTesterUtils { /** Fetches values for the "id" field from search results */ - public static ResultIds[] getResultIds(TopDocs topDocs, StoredFields storedFields, int k) throws IOException { - ResultIds[] resultIds = new ResultIds[k]; + public static ResultIds[] getResultIds(TopDocs topDocs, StoredFields storedFields) throws IOException { + ResultIds[] resultIds = new ResultIds[topDocs.scoreDocs.length]; int i = 0; // TODO: switch to doc values for this id field? more efficent than stored fields // TODO: or, at least load the stored documents in index (Lucene docid) order to diff --git a/src/python/knnPerfTest.py b/src/python/knnPerfTest.py index ca590625..2bd4bcbe 100644 --- a/src/python/knnPerfTest.py +++ b/src/python/knnPerfTest.py @@ -286,8 +286,6 @@ def run_knn_benchmark(checkout, values): skip_headers.add("beamWidth") if "-rerank" not in this_cmd: skip_headers.add("rerank") - skip_headers.add("ndcg@10") - skip_headers.add("ndcg@K") print_fixed_width(all_results, skip_headers) print_chart(all_results)