@@ -7,9 +7,11 @@ package tests
77
88import (
99 "context"
10+ "encoding/binary"
1011 "fmt"
1112 "math/rand"
1213 "strings"
14+ "sync/atomic"
1315 "time"
1416
1517 "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster"
@@ -19,8 +21,13 @@ import (
1921 "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
2022 "github.com/cockroachdb/cockroach/pkg/roachprod"
2123 "github.com/cockroachdb/cockroach/pkg/roachprod/install"
24+ "github.com/cockroachdb/cockroach/pkg/roachprod/logger"
2225 "github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
26+ "github.com/cockroachdb/cockroach/pkg/util/timeutil"
27+ "github.com/cockroachdb/cockroach/pkg/util/vector"
2328 "github.com/cockroachdb/cockroach/pkg/workload/vecann"
29+ "github.com/cockroachdb/errors"
30+ "github.com/jackc/pgconn"
2431 "github.com/jackc/pgx/v5/pgxpool"
2532 "github.com/stretchr/testify/require"
2633)
@@ -89,6 +96,58 @@ func getOpClass(metric vecpb.DistanceMetric) string {
8996 }
9097}
9198
// makeCanonicalKey builds the 6-byte primary key for a dataset row: a 2-byte
// big-endian category followed by a 4-byte big-endian dataset index.
//
// Both inputs are narrowed to their field width (uint16 for category, uint32
// for datasetIdx), so values outside those ranges silently wrap and would
// produce colliding keys.
func makeCanonicalKey(category int, datasetIdx int) []byte {
	key := make([]byte, 0, 6)
	key = binary.BigEndian.AppendUint16(key, uint16(category))
	key = binary.BigEndian.AppendUint32(key, uint32(datasetIdx))
	return key
}
106+
// backfillState identifies which stage of the index backfill the test is in.
type backfillState int32

const (
	// statePreBackfill: loading unindexed data.
	statePreBackfill backfillState = iota
	// stateIndexCreating: CREATE INDEX issued, job not visible yet.
	stateIndexCreating
	// stateBackfillRunning: backfill job visible and running.
	stateBackfillRunning
	// stateBackfillComplete: backfill done, still loading canonical rows.
	stateBackfillComplete
	// stateSteadyState: all done (used in Phase 2).
	stateSteadyState
)

// backfillStateNames holds the human-readable name of every known state.
var backfillStateNames = map[backfillState]string{
	statePreBackfill:      "pre-backfill",
	stateIndexCreating:    "index-creating",
	stateBackfillRunning:  "backfill-running",
	stateBackfillComplete: "backfill-complete",
	stateSteadyState:      "steady-state",
}

// backfillPhaseNames holds the inserted_phase column value used for rows
// written while the test is in each known state.
var backfillPhaseNames = map[backfillState]string{
	statePreBackfill:      "initial",
	stateIndexCreating:    "during-backfill",
	stateBackfillRunning:  "during-backfill",
	stateBackfillComplete: "post-backfill-canonical",
	stateSteadyState:      "post-backfill",
}

// String returns the display name of the state, or "unknown(N)" for a value
// that is not one of the declared constants.
func (s backfillState) String() string {
	if name, known := backfillStateNames[s]; known {
		return name
	}
	return fmt.Sprintf("unknown(%d)", s)
}

// phaseString returns the inserted_phase value for the current state, or
// "unknown" for a value that is not one of the declared constants.
func (s backfillState) phaseString() string {
	if phase, known := backfillPhaseNames[s]; known {
		return phase
	}
	return "unknown"
}
150+
92151func registerVectorIndex (r registry.Registry ) {
93152 configs := []vecIndexOptions {
94153 // Standard - no prefix
@@ -215,5 +274,178 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
215274 defer pool .Close ()
216275 require .NoError (t , pool .Ping (ctx ))
217276
218- t .L ().Printf ("=== Test completed successfully ===" )
277+ t .L ().Printf ("Creating schema and loading data" )
278+ testBackfillAndMerge (ctx , t , c , pool , & loader .Data , & opts , metric )
279+ }
280+
281+ func testBackfillAndMerge (
282+ ctx context.Context ,
283+ t test.Test ,
284+ c cluster.Cluster ,
285+ pool * pgxpool.Pool ,
286+ data * vecann.Dataset ,
287+ opts * vecIndexOptions ,
288+ metric vecpb.DistanceMetric ,
289+ ) {
290+ batchSize := opts .preBatchSz
291+ numCategories := max (opts .prefixCount , 1 )
292+
293+ db , err := c .ConnE (ctx , t .L (), 1 )
294+ require .NoError (t , err )
295+ defer db .Close ()
296+
297+ _ , err = db .ExecContext (ctx , fmt .Sprintf (
298+ `CREATE TABLE vecindex_test (
299+ id BYTES,
300+ category INT NOT NULL,
301+ embedding VECTOR(%d) NOT NULL,
302+ inserted_phase TEXT NOT NULL,
303+ worker_id INT NOT NULL,
304+ excluded BOOL DEFAULT false,
305+ INDEX (excluded),
306+ PRIMARY KEY (id)
307+ )` , data .Dims ))
308+ require .NoError (t , err )
309+ t .L ().Printf ("Table vecindex_test created" )
310+
311+ // Shared state for workers
312+ var state int32 = int32 (statePreBackfill )
313+ var rowsInserted int32
314+ blockBackfill := make (chan struct {})
315+ createIndexStartThresh := (data .TrainCount * opts .backfillPct ) / 100
316+
317+ ci := t .NewGroup ()
318+
319+ // Start a goroutine to create a vector index when we're signaled by one of the workers
320+ ci .Go (func (ctx context.Context , l * logger.Logger ) error {
321+ // Wait until a worker signals us to start the CREATE INDEX
322+ <- blockBackfill
323+
324+ atomic .StoreInt32 (& state , int32 (stateIndexCreating ))
325+ l .Printf ("Executing CREATE VECTOR INDEX at %d rows" , atomic .LoadInt32 (& rowsInserted ))
326+
327+ startCreateIndex := timeutil .Now ()
328+ opClass := getOpClass (metric )
329+ var indexSQL string
330+ if opts .prefixCount > 0 {
331+ indexSQL = fmt .Sprintf ("CREATE VECTOR INDEX vecidx ON vecindex_test (category, embedding%s)" , opClass )
332+ } else {
333+ indexSQL = fmt .Sprintf ("CREATE VECTOR INDEX vecidx ON vecindex_test (embedding%s)" , opClass )
334+ }
335+
336+ _ , err := db .ExecContext (ctx , indexSQL )
337+ if err != nil {
338+ return errors .Wrapf (err , "Failed to create vector index" )
339+ } else {
340+ dur := timeutil .Since (startCreateIndex ).Truncate (time .Second )
341+ rate := float64 (data .TrainCount ) / dur .Seconds ()
342+ l .Printf ("CREATE VECTOR INDEX completed in %v (%.1f rows per second)" , dur , rate )
343+ }
344+ return nil
345+ })
346+
347+ t .L ().Printf (
348+ "Loading %d rows into %d categories (%d rows total)" ,
349+ data .TrainCount ,
350+ numCategories ,
351+ data .TrainCount * numCategories ,
352+ )
353+
354+ var fileStart int
355+ loadStart := timeutil .Now ()
356+ // Iterate through the data files in the data set
357+ for {
358+ filename := data .GetNextTrainFile ()
359+ hasMore , err := data .Next ()
360+ require .NoError (t , err )
361+ if ! hasMore {
362+ dur := timeutil .Since (loadStart ).Truncate (time .Second )
363+ rate := float64 (data .TrainCount ) / dur .Seconds ()
364+ t .L ().Printf ("Data loaded in %v (%.1f rows per second)" , dur , rate )
365+ break
366+ }
367+ t .L ().Printf ("Loading data file: %s" , filename )
368+
369+ // Create workers to load this data file and dispatch part of the file to each of them.
370+ m := t .NewGroup ()
371+ countPerProc := (data .Train .Count / opts .workers ) + 1
372+ for worker := range opts .workers {
373+ start := worker * countPerProc
374+ end := min (start + countPerProc , data .Train .Count )
375+ m .Go (func (ctx context.Context , l * logger.Logger ) error {
376+ conn , err := pool .Acquire (ctx )
377+ require .NoError (t , err )
378+ defer conn .Release ()
379+
380+ for j := start ; j < end ; j += batchSize {
381+ sz := min (j + batchSize , end ) - j
382+ ri := int (atomic .AddInt32 (& rowsInserted , int32 (sz )))
383+ phaseStr := backfillState (atomic .LoadInt32 (& state )).phaseString ()
384+ vectors := data .Train .Slice (j , sz )
385+ err := insertVectors (ctx , conn , worker , numCategories , fileStart + j , phaseStr , vectors , false )
386+ if err != nil {
387+ return err
388+ }
389+ // If this is the batch that spanned the create index start threshold, signal the creator.
390+ if ri > createIndexStartThresh - sz && ri <= createIndexStartThresh {
391+ close (blockBackfill )
392+ }
393+ }
394+ return nil
395+ })
396+ }
397+ // Wait for this batch of loaders
398+ m .Wait ()
399+ }
400+
401+ // Wait for create index to finish
402+ ci .Wait ()
403+ }
404+
405+ func insertVectors (
406+ ctx context.Context ,
407+ conn * pgxpool.Conn ,
408+ workerID int ,
409+ numCats int ,
410+ startIdx int ,
411+ phaseStr string ,
412+ vectors vector.Set ,
413+ excluded bool ,
414+ ) error {
415+ args := make ([]any , vectors .Count * numCats * 5 )
416+ var queryBuilder strings.Builder
417+ queryBuilder .Grow (100 + vectors .Count * numCats * 34 )
418+ queryBuilder .WriteString ("INSERT INTO vecindex_test " +
419+ "(id, category, embedding, inserted_phase, worker_id, excluded) VALUES " )
420+ rowNum := 0
421+ for i := range vectors .Count {
422+ for cat := range numCats {
423+ if rowNum > 0 {
424+ queryBuilder .WriteString (", " )
425+ }
426+ j := rowNum * 5
427+ fmt .Fprintf (& queryBuilder , "($%d, $%d, $%d, $%d, $%d, %v)" , j + 1 , j + 2 , j + 3 , j + 4 , j + 5 , excluded )
428+ args [j ] = makeCanonicalKey (cat , startIdx + i )
429+ args [j + 1 ] = cat
430+ args [j + 2 ] = vectors .At (i )
431+ args [j + 3 ] = phaseStr
432+ args [j + 4 ] = workerID
433+ rowNum ++
434+ }
435+ }
436+ query := queryBuilder .String ()
437+
438+ for {
439+ _ , err := conn .Exec (ctx , query , args ... )
440+
441+ var pgErr * pgconn.PgError
442+ if err != nil && errors .As (err , & pgErr ) {
443+ switch pgErr .Code {
444+ case "40001" , "40P01" :
445+ continue
446+ }
447+ }
448+
449+ return errors .Wrapf (err , "Failed to run: %s %v" , query , args )
450+ }
219451}
0 commit comments