2222
2323import com .apple .foundationdb .Database ;
2424import com .apple .foundationdb .Transaction ;
25+ import com .apple .foundationdb .async .AsyncUtil ;
26+ import com .apple .foundationdb .async .hnsw .Vector .HalfVector ;
2527import com .apple .foundationdb .async .rtree .RTree ;
2628import com .apple .foundationdb .test .TestDatabaseExtension ;
2729import com .apple .foundationdb .test .TestExecutors ;
3638import org .junit .jupiter .api .BeforeEach ;
3739import org .junit .jupiter .api .Tag ;
3840import org .junit .jupiter .api .Test ;
41+ import org .junit .jupiter .api .Timeout ;
3942import org .junit .jupiter .api .extension .RegisterExtension ;
4043import org .junit .jupiter .api .parallel .Execution ;
4144import org .junit .jupiter .api .parallel .ExecutionMode ;
45+ import org .junit .jupiter .params .ParameterizedTest ;
46+ import org .junit .jupiter .params .provider .ValueSource ;
4247import org .slf4j .Logger ;
4348import org .slf4j .LoggerFactory ;
4449
4550import javax .annotation .Nonnull ;
51+ import java .io .BufferedReader ;
4652import java .io .BufferedWriter ;
53+ import java .io .FileReader ;
4754import java .io .FileWriter ;
4855import java .io .IOException ;
4956import java .util .ArrayList ;
5057import java .util .Comparator ;
5158import java .util .List ;
5259import java .util .Map ;
60+ import java .util .NavigableSet ;
61+ import java .util .Objects ;
5362import java .util .Random ;
63+ import java .util .concurrent .CompletableFuture ;
64+ import java .util .concurrent .ConcurrentSkipListSet ;
5465import java .util .concurrent .TimeUnit ;
5566import java .util .concurrent .atomic .AtomicLong ;
67+ import java .util .concurrent .atomic .AtomicReference ;
68+ import java .util .function .Function ;
5669
5770/**
5871 * Tests testing insert/update/deletes of data into/in/from {@link RTree}s.
@@ -159,18 +172,20 @@ public void testBasicInsert() {
159172
160173 final TestOnReadListener onReadListener = new TestOnReadListener ();
161174
175+ final int dimensions = 128 ;
162176 final HNSW hnsw = new HNSW (rtSubspace .getSubspace (), TestExecutors .defaultThreadPool (),
163- HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (Metric .COSINE_METRIC ). setEfConstruction ( 34 ). setM (16 ).setMMax (16 ).setMMax0 (32 ).build (),
177+ HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (Metric .EUCLIDEAN_METRIC ). setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
164178 OnWriteListener .NOOP , onReadListener );
165179
166- for (int i = 0 ; i < 10000 ;) {
167- i += basicInsertBatch (hnsw , random , 100 , nextNodeIdAtomic , onReadListener );
180+ for (int i = 0 ; i < 1000 ;) {
181+ i += basicInsertBatch (100 , nextNodeIdAtomic , onReadListener ,
182+ tr -> hnsw .insert (tr , createNextPrimaryKey (nextNodeIdAtomic ), createRandomVector (random , dimensions )));
168183 }
169184
170185 onReadListener .reset ();
171186 final long beginTs = System .nanoTime ();
172187 final List <? extends NodeReferenceAndNode <?>> result =
173- db .run (tr -> hnsw .kNearestNeighborsSearch (tr , 10 , 20 , createRandomVector (random , 768 )).join ());
188+ db .run (tr -> hnsw .kNearestNeighborsSearch (tr , 10 , 100 , createRandomVector (random , dimensions )).join ());
174189 final long endTs = System .nanoTime ();
175190
176191 for (NodeReferenceAndNode <?> nodeReferenceAndNode : result ) {
@@ -184,14 +199,15 @@ public void testBasicInsert() {
184199 logger .info ("search transaction took elapsedTime={}ms" , TimeUnit .NANOSECONDS .toMillis (endTs - beginTs ));
185200 }
186201
187- private int basicInsertBatch (@ Nonnull final HNSW hnsw , @ Nonnull final Random random , final int batchSize ,
188- @ Nonnull final AtomicLong nextNodeIdAtomic , @ Nonnull final TestOnReadListener onReadListener ) {
202+ private int basicInsertBatch (final int batchSize ,
203+ @ Nonnull final AtomicLong nextNodeIdAtomic , @ Nonnull final TestOnReadListener onReadListener ,
204+ @ Nonnull final Function <Transaction , CompletableFuture <Void >> insertFunction ) {
189205 return db .run (tr -> {
190206 onReadListener .reset ();
191207 final long nextNodeId = nextNodeIdAtomic .get ();
192208 final long beginTs = System .nanoTime ();
193209 for (int i = 0 ; i < batchSize ; i ++) {
194- hnsw . insert (tr , createNextPrimaryKey ( nextNodeIdAtomic ), createRandomVector ( random , 768 ) ).join ();
210+ insertFunction . apply (tr ).join ();
195211 }
196212 final long endTs = System .nanoTime ();
197213 logger .info ("inserted batchSize={} records starting at nodeId={} took elapsedTime={}ms, readCounts={}, MSums={}" , batchSize , nextNodeId ,
@@ -200,6 +216,91 @@ private int basicInsertBatch(@Nonnull final HNSW hnsw, @Nonnull final Random ran
200216 });
201217 }
202218
219+ @ Test
220+ @ Timeout (value = 150 , unit = TimeUnit .MINUTES )
221+ public void testSIFTInsert10k () throws Exception {
222+ final Metric metric = Metric .EUCLIDEAN_METRIC ;
223+ final int k = 10 ;
224+ final AtomicLong nextNodeIdAtomic = new AtomicLong (0L );
225+
226+ final TestOnReadListener onReadListener = new TestOnReadListener ();
227+
228+ final HNSW hnsw = new HNSW (rtSubspace .getSubspace (), TestExecutors .defaultThreadPool (),
229+ HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (metric ).setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
230+ OnWriteListener .NOOP , onReadListener );
231+
232+ final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv" ;
233+ final int dimensions = 128 ;
234+
235+ final AtomicReference <HalfVector > queryVectorAtomic = new AtomicReference <>();
236+ final NavigableSet <NodeReferenceWithDistance > trueResults = new ConcurrentSkipListSet <>(
237+ Comparator .comparing (NodeReferenceWithDistance ::getDistance ));
238+
239+ try (BufferedReader br = new BufferedReader (new FileReader (tsvFile ))) {
240+ for (int i = 0 ; i < 10000 ;) {
241+ i += basicInsertBatch (100 , nextNodeIdAtomic , onReadListener ,
242+ tr -> {
243+ final String line ;
244+ try {
245+ line = br .readLine ();
246+ } catch (IOException e ) {
247+ throw new RuntimeException (e );
248+ }
249+
250+ final String [] values = Objects .requireNonNull (line ).split ("\t " );
251+ Assertions .assertEquals (dimensions , values .length );
252+ final Half [] halfs = new Half [dimensions ];
253+
254+ for (int c = 0 ; c < values .length ; c ++) {
255+ final String value = values [c ];
256+ halfs [c ] = HNSWHelpers .halfValueOf (Double .parseDouble (value ));
257+ }
258+ final Tuple currentPrimaryKey = createNextPrimaryKey (nextNodeIdAtomic );
259+ final HalfVector currentVector = new HalfVector (halfs );
260+ final HalfVector queryVector = queryVectorAtomic .get ();
261+ if (queryVector == null ) {
262+ queryVectorAtomic .set (currentVector );
263+ return AsyncUtil .DONE ;
264+ } else {
265+ final double currentDistance =
266+ Vector .comparativeDistance (metric , currentVector , queryVector );
267+ if (trueResults .size () < k || trueResults .last ().getDistance () > currentDistance ) {
268+ trueResults .add (
269+ new NodeReferenceWithDistance (currentPrimaryKey , currentVector ,
270+ Vector .comparativeDistance (metric , currentVector , queryVector )));
271+ }
272+ if (trueResults .size () > k ) {
273+ trueResults .remove (trueResults .last ());
274+ }
275+ return hnsw .insert (tr , currentPrimaryKey , currentVector );
276+ }
277+ });
278+ }
279+ }
280+
281+ onReadListener .reset ();
282+ final long beginTs = System .nanoTime ();
283+ final List <? extends NodeReferenceAndNode <?>> results =
284+ db .run (tr -> hnsw .kNearestNeighborsSearch (tr , k , 100 , queryVectorAtomic .get ()).join ());
285+ final long endTs = System .nanoTime ();
286+
287+ for (NodeReferenceAndNode <?> nodeReferenceAndNode : results ) {
288+ final NodeReferenceWithDistance nodeReferenceWithDistance = nodeReferenceAndNode .getNodeReferenceWithDistance ();
289+ logger .info ("retrieved result nodeId = {} at distance= {}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
290+ nodeReferenceWithDistance .getDistance ());
291+ }
292+
293+ for (final NodeReferenceWithDistance nodeReferenceWithDistance : trueResults ) {
294+ logger .info ("true result nodeId ={} at distance={}" , nodeReferenceWithDistance .getPrimaryKey ().getLong (0 ),
295+ nodeReferenceWithDistance .getDistance ());
296+ }
297+
298+ System .out .println (onReadListener .getNodeCountByLayer ());
299+ System .out .println (onReadListener .getBytesReadByLayer ());
300+
301+ logger .info ("search transaction took elapsedTime={}ms" , TimeUnit .NANOSECONDS .toMillis (endTs - beginTs ));
302+ }
303+
203304 @ Test
204305 public void testBasicInsertAndScanLayer () throws Exception {
205306 final Random random = new Random (0 );
@@ -224,17 +325,92 @@ public void testBasicInsertAndScanLayer() throws Exception {
224325 }
225326
226327 @ Test
227- public void testManyVectors () {
328+ public void testManyRandomVectors () {
228329 final Random random = new Random ();
229330 for (long l = 0L ; l < 3000000 ; l ++) {
230- final Vector . HalfVector randomVector = createRandomVector (random , 768 );
331+ final HalfVector randomVector = createRandomVector (random , 768 );
231332 final Tuple vectorTuple = StorageAdapter .tupleFromVector (randomVector );
232333 final Vector <Half > roundTripVector = StorageAdapter .vectorFromTuple (vectorTuple );
233334 Vector .comparativeDistance (Metric .EuclideanMetric .EUCLIDEAN_METRIC , randomVector , roundTripVector );
234335 Assertions .assertEquals (randomVector , roundTripVector );
235336 }
236337 }
237338
339+ @ Test
340+ @ Timeout (value = 150 , unit = TimeUnit .MINUTES )
341+ public void testSIFTVectors () throws Exception {
342+ final AtomicLong nextNodeIdAtomic = new AtomicLong (0L );
343+
344+ final TestOnReadListener onReadListener = new TestOnReadListener ();
345+
346+ final HNSW hnsw = new HNSW (rtSubspace .getSubspace (), TestExecutors .defaultThreadPool (),
347+ HNSW .DEFAULT_CONFIG .toBuilder ().setMetric (Metric .EUCLIDEAN_METRIC ).setM (32 ).setMMax (32 ).setMMax0 (64 ).build (),
348+ OnWriteListener .NOOP , onReadListener );
349+
350+
351+ final String tsvFile = "/Users/nseemann/Downloads/train-100k.tsv" ;
352+ final int dimensions = 128 ;
353+ final var referenceVector = createRandomVector (new Random (0 ), dimensions );
354+ long count = 0L ;
355+ double mean = 0.0d ;
356+ double mean2 = 0.0d ;
357+
358+ try (BufferedReader br = new BufferedReader (new FileReader (tsvFile ))) {
359+ for (int i = 0 ; i < 100_000 ; i ++) {
360+ final String line ;
361+ try {
362+ line = br .readLine ();
363+ } catch (IOException e ) {
364+ throw new RuntimeException (e );
365+ }
366+
367+ final String [] values = Objects .requireNonNull (line ).split ("\t " );
368+ Assertions .assertEquals (dimensions , values .length );
369+ final Half [] halfs = new Half [dimensions ];
370+ for (int c = 0 ; c < values .length ; c ++) {
371+ final String value = values [c ];
372+ halfs [c ] = HNSWHelpers .halfValueOf (Double .parseDouble (value ));
373+ }
374+ final HalfVector newVector = new HalfVector (halfs );
375+ final double distance = Vector .comparativeDistance (Metric .EUCLIDEAN_METRIC , referenceVector , newVector );
376+ count ++;
377+ final double delta = distance - mean ;
378+ mean += delta / count ;
379+ final double delta2 = distance - mean ;
380+ mean2 += delta * delta2 ;
381+ }
382+ }
383+ final double sampleVariance = mean2 / (count - 1 );
384+ final double standardDeviation = Math .sqrt (sampleVariance );
385+ logger .info ("mean={}, sample_variance={}, stddeviation={}, cv={}" , mean , sampleVariance , standardDeviation ,
386+ standardDeviation / mean );
387+ }
388+
389+
390+ @ ParameterizedTest
391+ @ ValueSource (ints = {2 , 3 , 10 , 100 , 768 })
392+ public void testManyVectorsStandardDeviation (final int dimensionality ) {
393+ final Random random = new Random ();
394+ final Metric metric = Metric .EuclideanMetric .EUCLIDEAN_METRIC ;
395+ long count = 0L ;
396+ double mean = 0.0d ;
397+ double mean2 = 0.0d ;
398+ for (long i = 0L ; i < 100000 ; i ++) {
399+ final HalfVector vector1 = createRandomVector (random , dimensionality );
400+ final HalfVector vector2 = createRandomVector (random , dimensionality );
401+ final double distance = Vector .comparativeDistance (metric , vector1 , vector2 );
402+ count = i + 1 ;
403+ final double delta = distance - mean ;
404+ mean += delta / count ;
405+ final double delta2 = distance - mean ;
406+ mean2 += delta * delta2 ;
407+ }
408+ final double sampleVariance = mean2 / (count - 1 );
409+ final double standardDeviation = Math .sqrt (sampleVariance );
410+ logger .info ("mean={}, sample_variance={}, stddeviation={}, cv={}" , mean , sampleVariance , standardDeviation ,
411+ standardDeviation / mean );
412+ }
413+
238414 private boolean dumpLayer (final HNSW hnsw , final int layer ) throws IOException {
239415 final String verticesFileName = "/Users/nseemann/Downloads/vertices-" + layer + ".csv" ;
240416 final String edgesFileName = "/Users/nseemann/Downloads/edges-" + layer + ".csv" ;
@@ -324,13 +500,13 @@ private static Tuple createNextPrimaryKey(@Nonnull final AtomicLong nextIdAtomic
324500 }
325501
326502 @ Nonnull
327- private Vector . HalfVector createRandomVector (@ Nonnull final Random random , final int dimensionality ) {
503+ private HalfVector createRandomVector (@ Nonnull final Random random , final int dimensionality ) {
328504 final Half [] components = new Half [dimensionality ];
329505 for (int d = 0 ; d < dimensionality ; d ++) {
330506 // don't ask
331507 components [d ] = HNSWHelpers .halfValueOf (random .nextDouble ());
332508 }
333- return new Vector . HalfVector (components );
509+ return new HalfVector (components );
334510 }
335511
336512 private static class TestOnReadListener implements OnReadListener {
0 commit comments