5959import java .util .function .Function ;
6060import java .util .stream .Collectors ;
6161
62+ import static com .apple .foundationdb .async .MoreAsyncUtil .forEach ;
63+ import static com .apple .foundationdb .async .MoreAsyncUtil .forLoop ;
64+
6265/**
6366 * TODO.
6467 */
@@ -70,6 +73,7 @@ public class HNSW {
7073
7174 public static final int MAX_CONCURRENT_NODE_READS = 16 ;
7275 public static final int MAX_CONCURRENT_NEIGHBOR_FETCHES = 3 ;
76+ public static final int MAX_CONCURRENT_SEARCHES = 10 ;
7377 @ Nonnull public static final Random DEFAULT_RANDOM = new Random (0L );
7478 @ Nonnull public static final Metric DEFAULT_METRIC = new Metric .EuclideanMetric ();
7579 public static final int DEFAULT_M = 16 ;
@@ -697,12 +701,17 @@ private <R extends NodeReference, N extends NodeReference, U> CompletableFuture<
697701 @ Nonnull final Iterable <R > nodeReferences ,
698702 @ Nonnull final Function <R , U > fetchBypassFunction ,
699703 @ Nonnull final BiFunction <R , Node <N >, U > biMapFunction ) {
700- return MoreAsyncUtil . forEach (nodeReferences ,
704+ return forEach (nodeReferences ,
701705 currentNeighborReference -> fetchNodeIfNecessaryAndApply (storageAdapter , readTransaction , layer ,
702706 currentNeighborReference , fetchBypassFunction , biMapFunction ), MAX_CONCURRENT_NODE_READS ,
703707 getExecutor ());
704708 }
705709
710+ @ Nonnull
711+ public CompletableFuture <Void > insert (@ Nonnull final Transaction transaction , @ Nonnull final NodeReferenceWithVector nodeReferenceWithVector ) {
712+ return insert (transaction , nodeReferenceWithVector .getPrimaryKey (), nodeReferenceWithVector .getVector ());
713+ }
714+
706715 @ Nonnull
707716 public CompletableFuture <Void > insert (@ Nonnull final Transaction transaction , @ Nonnull final Tuple newPrimaryKey ,
708717 @ Nonnull final Vector <Half > newVector ) {
@@ -720,9 +729,9 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
720729 new EntryNodeReference (newPrimaryKey , newVector , insertionLayer ), getOnWriteListener ());
721730 debug (l -> l .debug ("written entry node reference with key={} on layer={}" , newPrimaryKey , insertionLayer ));
722731 } else {
723- final int entryNodeLayer = entryNodeReference .getLayer ();
724- if (insertionLayer > entryNodeLayer ) {
725- writeLonelyNodes (transaction , newPrimaryKey , newVector , insertionLayer , entryNodeLayer );
732+ final int lMax = entryNodeReference .getLayer ();
733+ if (insertionLayer > lMax ) {
734+ writeLonelyNodes (transaction , newPrimaryKey , newVector , insertionLayer , lMax );
726735 StorageAdapter .writeEntryNodeReference (transaction , getSubspace (),
727736 new EntryNodeReference (newPrimaryKey , newVector , insertionLayer ), getOnWriteListener ());
728737 debug (l -> l .debug ("written entry node reference with key={} on layer={}" , newPrimaryKey , insertionLayer ));
@@ -757,13 +766,104 @@ public CompletableFuture<Void> insert(@Nonnull final Transaction transaction, @N
757766 }
758767
759768 @ Nonnull
760- private CompletableFuture <Void > insertIntoLayers (final @ Nonnull Transaction transaction ,
761- final @ Nonnull Tuple newPrimaryKey ,
762- final @ Nonnull Vector <Half > newVector ,
763- final NodeReferenceWithDistance nodeReference , final int lMax , final int insertionLayer ) {
764- debug (l -> {
765- l .debug ("nearest entry point at lMax={} is at key={}" , lMax , nodeReference .getPrimaryKey ());
766- });
769+ public CompletableFuture <Void > insertBatch (@ Nonnull final Transaction transaction ,
770+ @ Nonnull List <NodeReferenceWithVector > batch ) {
771+ final Metric metric = getConfig ().getMetric ();
772+
773+ // determine the layer each item should be inserted at
774+ final Random random = getConfig ().getRandom ();
775+ final List <NodeReferenceWithLayer > batchWithLayers = Lists .newArrayListWithCapacity (batch .size ());
776+ for (final NodeReferenceWithVector current : batch ) {
777+ batchWithLayers .add (new NodeReferenceWithLayer (current .getPrimaryKey (), current .getVector (),
778+ insertionLayer (random )));
779+ }
780+ // sort the layers in reverse order
781+ batchWithLayers .sort (Comparator .comparing (NodeReferenceWithLayer ::getL ).reversed ());
782+
783+ return StorageAdapter .fetchEntryNodeReference (transaction , getSubspace (), getOnReadListener ())
784+ .thenCompose (entryNodeReference -> {
785+ final int lMax = entryNodeReference == null ? -1 : entryNodeReference .getLayer ();
786+
787+ return forEach (batchWithLayers ,
788+ item -> {
789+ if (lMax == -1 ) {
790+ return CompletableFuture .completedFuture (null );
791+ }
792+
793+ final Vector <Half > itemVector = item .getVector ();
794+ final int itemL = item .getL ();
795+
796+ final NodeReferenceWithDistance initialNodeReference =
797+ new NodeReferenceWithDistance (entryNodeReference .getPrimaryKey (),
798+ entryNodeReference .getVector (),
799+ Vector .comparativeDistance (metric , entryNodeReference .getVector (), itemVector ));
800+
801+ return MoreAsyncUtil .forLoop (lMax , initialNodeReference ,
802+ layer -> layer > itemL ,
803+ layer -> layer - 1 ,
804+ (layer , previousNodeReference ) -> {
805+ final StorageAdapter <? extends NodeReference > storageAdapter = getStorageAdapterForLayer (layer );
806+ return greedySearchLayer (storageAdapter , transaction ,
807+ previousNodeReference , layer , itemVector );
808+ }, executor );
809+ }, MAX_CONCURRENT_SEARCHES , getExecutor ())
810+ .thenCompose (searchEntryReferences ->
811+ forLoop (0 , entryNodeReference ,
812+ index -> index < batchWithLayers .size (),
813+ index -> index + 1 ,
814+ (index , currentEntryNodeReference ) -> {
815+ final NodeReferenceWithLayer item = batchWithLayers .get (index );
816+ final Tuple itemPrimaryKey = item .getPrimaryKey ();
817+ final Vector <Half > itemVector = item .getVector ();
818+ final int itemL = item .getL ();
819+
820+ final EntryNodeReference newEntryNodeReference ;
821+ final int currentLMax ;
822+
823+ if (entryNodeReference == null ) {
824+ // this is the first node
825+ writeLonelyNodes (transaction , itemPrimaryKey , itemVector , itemL , -1 );
826+ newEntryNodeReference =
827+ new EntryNodeReference (itemPrimaryKey , itemVector , itemL );
828+ StorageAdapter .writeEntryNodeReference (transaction , getSubspace (),
829+ newEntryNodeReference , getOnWriteListener ());
830+ debug (l -> l .debug ("written entry node reference with key={} on layer={}" , itemPrimaryKey , itemL ));
831+
832+ return CompletableFuture .completedFuture (newEntryNodeReference );
833+ } else {
834+ currentLMax = currentEntryNodeReference .getLayer ();
835+ if (itemL > currentLMax ) {
836+ writeLonelyNodes (transaction , itemPrimaryKey , itemVector , itemL , lMax );
837+ newEntryNodeReference =
838+ new EntryNodeReference (itemPrimaryKey , itemVector , itemL );
839+ StorageAdapter .writeEntryNodeReference (transaction , getSubspace (),
840+ newEntryNodeReference , getOnWriteListener ());
841+ debug (l -> l .debug ("written entry node reference with key={} on layer={}" , itemPrimaryKey , itemL ));
842+ } else {
843+ newEntryNodeReference = entryNodeReference ;
844+ }
845+ }
846+
847+ debug (l -> l .debug ("entry node with key {} at layer {}" ,
848+ currentEntryNodeReference .getPrimaryKey (), currentLMax ));
849+
850+ final var currentSearchEntry =
851+ searchEntryReferences .get (index );
852+
853+ return insertIntoLayers (transaction , itemPrimaryKey , itemVector , currentSearchEntry ,
854+ lMax , itemL ).thenApply (ignored -> newEntryNodeReference );
855+ }, getExecutor ()));
856+ }).thenCompose (ignored -> AsyncUtil .DONE );
857+ }
858+
859+ @ Nonnull
860+ private CompletableFuture <Void > insertIntoLayers (@ Nonnull final Transaction transaction ,
861+ @ Nonnull final Tuple newPrimaryKey ,
862+ @ Nonnull final Vector <Half > newVector ,
863+ @ Nonnull final NodeReferenceWithDistance nodeReference ,
864+ final int lMax ,
865+ final int insertionLayer ) {
866+ debug (l -> l .debug ("nearest entry point at lMax={} is at key={}" , lMax , nodeReference .getPrimaryKey ()));
767867 return MoreAsyncUtil .<List <NodeReferenceWithDistance >>forLoop (Math .min (lMax , insertionLayer ), ImmutableList .of (nodeReference ),
768868 layer -> layer >= 0 ,
769869 layer -> layer - 1 ,
@@ -817,7 +917,7 @@ private <N extends NodeReference> CompletableFuture<List<NodeReferenceWithDistan
817917 }
818918
819919 final int currentMMax = layer == 0 ? getConfig ().getMMax0 () : getConfig ().getMMax ();
820- return MoreAsyncUtil . forEach (selectedNeighbors ,
920+ return forEach (selectedNeighbors ,
821921 selectedNeighbor -> {
822922 final Node <N > selectedNeighborNode = selectedNeighbor .getNode ();
823923 final NeighborsChangeSet <N > changeSet =
@@ -1110,4 +1210,43 @@ private void debug(@Nonnull final Consumer<Logger> loggerConsumer) {
11101210 loggerConsumer .accept (logger );
11111211 }
11121212 }
1213+
1214+ private static class NodeReferenceWithLayer extends NodeReferenceWithVector {
1215+ @ SuppressWarnings ("checkstyle:MemberName" )
1216+ private final int l ;
1217+
1218+ public NodeReferenceWithLayer (@ Nonnull final Tuple primaryKey , @ Nonnull final Vector <Half > vector ,
1219+ final int l ) {
1220+ super (primaryKey , vector );
1221+ this .l = l ;
1222+ }
1223+
1224+ public int getL () {
1225+ return l ;
1226+ }
1227+ }
1228+
1229+ private static class NodeReferenceWithSearchEntry extends NodeReferenceWithVector {
1230+ @ SuppressWarnings ("checkstyle:MemberName" )
1231+ private final int l ;
1232+ @ Nonnull
1233+ private final NodeReferenceWithDistance nodeReferenceWithDistance ;
1234+
1235+ public NodeReferenceWithSearchEntry (@ Nonnull final Tuple primaryKey , @ Nonnull final Vector <Half > vector ,
1236+ final int l ,
1237+ @ Nonnull final NodeReferenceWithDistance nodeReferenceWithDistance ) {
1238+ super (primaryKey , vector );
1239+ this .l = l ;
1240+ this .nodeReferenceWithDistance = nodeReferenceWithDistance ;
1241+ }
1242+
1243+ public int getL () {
1244+ return l ;
1245+ }
1246+
1247+ @ Nonnull
1248+ public NodeReferenceWithDistance getNodeReferenceWithDistance () {
1249+ return nodeReferenceWithDistance ;
1250+ }
1251+ }
11131252}
0 commit comments