
Commit 992e202

roachtest/vecindex: implement a test of backfill and merge
Add a test phase that loads a dataset using a pool of workers and, at a
test-specified percentage of table population, kicks off a CREATE VECTOR
INDEX on the data. This allows us to test both backfill (rows loaded before
the index is created) and merge (rows loaded after the index build starts).
Times are reported for both, but are not yet used as pass criteria.

Informs: #154590

Release note: None
1 parent 0b5659b commit 992e202
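
A minimal, standalone sketch of the coordination pattern described above (not part of the commit): loader goroutines atomically count inserted rows, and the single batch that crosses the configured percentage threshold closes a channel that unblocks the goroutine issuing CREATE VECTOR INDEX. All identifiers here (totalRows, backfillPct, startIndex, and so on) are illustrative stand-ins; the real implementation is in vecindex.go below.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	const (
		totalRows   = 10_000
		batchSize   = 500
		workers     = 4
		backfillPct = 50 // start CREATE INDEX at 50% of the load
	)
	threshold := totalRows * backfillPct / 100

	var rowsInserted int32
	startIndex := make(chan struct{}) // closed exactly once, by the crossing batch
	indexDone := make(chan struct{})

	// Stand-in for the goroutine that runs CREATE VECTOR INDEX.
	go func() {
		<-startIndex
		fmt.Printf("CREATE VECTOR INDEX issued at %d rows\n", atomic.LoadInt32(&rowsInserted))
		close(indexDone)
	}()

	var wg sync.WaitGroup
	perWorker := totalRows / workers
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for done := 0; done < perWorker; done += batchSize {
				sz := min(batchSize, perWorker-done)
				// A real worker would INSERT sz rows here.
				ri := int(atomic.AddInt32(&rowsInserted, int32(sz)))
				// Only the batch that spans the threshold signals the creator.
				if ri > threshold-sz && ri <= threshold {
					close(startIndex)
				}
			}
		}()
	}
	wg.Wait()
	<-indexDone
	fmt.Printf("loaded %d rows total\n", atomic.LoadInt32(&rowsInserted))
}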

2 files changed: +235 −1 lines

pkg/cmd/roachtest/tests/BUILD.bazel

Lines changed: 2 additions & 0 deletions
@@ -312,6 +312,7 @@ go_library(
         "//pkg/util/syncutil",
         "//pkg/util/timeutil",
         "//pkg/util/uuid",
+        "//pkg/util/vector",
         "//pkg/workload",
         "//pkg/workload/debug",
         "//pkg/workload/histogram",
@@ -343,6 +344,7 @@ go_library(
         "@com_github_google_go_cmp//cmp",
         "@com_github_google_pprof//profile",
         "@com_github_ibm_sarama//:sarama",
+        "@com_github_jackc_pgconn//:pgconn",
         "@com_github_jackc_pgtype//:pgtype",
         "@com_github_jackc_pgx_v5//:pgx",
         "@com_github_jackc_pgx_v5//pgxpool",

pkg/cmd/roachtest/tests/vecindex.go

Lines changed: 233 additions & 1 deletion
@@ -7,9 +7,11 @@ package tests
 
 import (
 	"context"
+	"encoding/binary"
 	"fmt"
 	"math/rand"
 	"strings"
+	"sync/atomic"
 	"time"
 
 	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster"
@@ -19,8 +21,13 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
 	"github.com/cockroachdb/cockroach/pkg/roachprod"
 	"github.com/cockroachdb/cockroach/pkg/roachprod/install"
+	"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
 	"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecpb"
+	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
+	"github.com/cockroachdb/cockroach/pkg/util/vector"
 	"github.com/cockroachdb/cockroach/pkg/workload/vecann"
+	"github.com/cockroachdb/errors"
+	"github.com/jackc/pgconn"
 	"github.com/jackc/pgx/v5/pgxpool"
 	"github.com/stretchr/testify/require"
 )
@@ -89,6 +96,58 @@ func getOpClass(metric vecpb.DistanceMetric) string {
 	}
 }
 
+// makeCanonicalKey generates a 6-byte primary key from the category and dataset index.
+func makeCanonicalKey(category int, datasetIdx int) []byte {
+	key := make([]byte, 0, 6)
+	key = binary.BigEndian.AppendUint16(key, uint16(category))
+	key = binary.BigEndian.AppendUint32(key, uint32(datasetIdx))
+	return key
+}
+
+// backfillState represents the current state of index backfill
+type backfillState int32
+
+const (
+	statePreBackfill      backfillState = iota // Loading unindexed data
+	stateIndexCreating                         // CREATE INDEX issued, job not visible yet
+	stateBackfillRunning                       // Backfill job visible and running
+	stateBackfillComplete                      // Backfill done, still loading canonical rows
+	stateSteadyState                           // All done (used in Phase 2)
+)
+
+func (s backfillState) String() string {
+	switch s {
+	case statePreBackfill:
+		return "pre-backfill"
+	case stateIndexCreating:
+		return "index-creating"
+	case stateBackfillRunning:
+		return "backfill-running"
+	case stateBackfillComplete:
+		return "backfill-complete"
+	case stateSteadyState:
+		return "steady-state"
+	default:
+		return fmt.Sprintf("unknown(%d)", s)
+	}
+}
+
+// phaseString returns the inserted_phase value for the current state
+func (s backfillState) phaseString() string {
+	switch s {
+	case statePreBackfill:
+		return "initial"
+	case stateIndexCreating, stateBackfillRunning:
+		return "during-backfill"
+	case stateBackfillComplete:
+		return "post-backfill-canonical"
+	case stateSteadyState:
+		return "post-backfill"
+	default:
+		return "unknown"
+	}
+}
+
 func registerVectorIndex(r registry.Registry) {
 	configs := []vecIndexOptions{
 		// Standard - no prefix
@@ -215,5 +274,178 @@ func runVectorIndex(ctx context.Context, t test.Test, c cluster.Cluster, opts ve
 	defer pool.Close()
 	require.NoError(t, pool.Ping(ctx))
 
-	t.L().Printf("=== Test completed successfully ===")
+	t.L().Printf("Creating schema and loading data")
+	testBackfillAndMerge(ctx, t, c, pool, &loader.Data, &opts, metric)
+}
+
+func testBackfillAndMerge(
+	ctx context.Context,
+	t test.Test,
+	c cluster.Cluster,
+	pool *pgxpool.Pool,
+	data *vecann.Dataset,
+	opts *vecIndexOptions,
+	metric vecpb.DistanceMetric,
+) {
+	batchSize := opts.preBatchSz
+	numCategories := max(opts.prefixCount, 1)
+
+	db, err := c.ConnE(ctx, t.L(), 1)
+	require.NoError(t, err)
+	defer db.Close()
+
+	_, err = db.ExecContext(ctx, fmt.Sprintf(
+		`CREATE TABLE vecindex_test (
+			id BYTES,
+			category INT NOT NULL,
+			embedding VECTOR(%d) NOT NULL,
+			inserted_phase TEXT NOT NULL,
+			worker_id INT NOT NULL,
+			excluded BOOL DEFAULT false,
+			INDEX (excluded),
+			PRIMARY KEY (id)
+		)`, data.Dims))
+	require.NoError(t, err)
+	t.L().Printf("Table vecindex_test created")
+
+	// Shared state for workers
+	var state int32 = int32(statePreBackfill)
+	var rowsInserted int32
+	blockBackfill := make(chan struct{})
+	createIndexStartThresh := (data.TrainCount * opts.backfillPct) / 100
+
+	ci := t.NewGroup()
+
+	// Start a goroutine to create a vector index when we're signaled by one of the workers
+	ci.Go(func(ctx context.Context, l *logger.Logger) error {
+		// Wait until a worker signals us to start the CREATE INDEX
+		<-blockBackfill
+
+		atomic.StoreInt32(&state, int32(stateIndexCreating))
+		l.Printf("Executing CREATE VECTOR INDEX at %d rows", atomic.LoadInt32(&rowsInserted))
+
+		startCreateIndex := timeutil.Now()
+		opClass := getOpClass(metric)
+		var indexSQL string
+		if opts.prefixCount > 0 {
+			indexSQL = fmt.Sprintf("CREATE VECTOR INDEX vecidx ON vecindex_test (category, embedding%s)", opClass)
+		} else {
+			indexSQL = fmt.Sprintf("CREATE VECTOR INDEX vecidx ON vecindex_test (embedding%s)", opClass)
+		}
+
+		_, err := db.ExecContext(ctx, indexSQL)
+		if err != nil {
+			return errors.Wrapf(err, "Failed to create vector index")
+		} else {
+			dur := timeutil.Since(startCreateIndex).Truncate(time.Second)
+			rate := float64(data.TrainCount) / dur.Seconds()
+			l.Printf("CREATE VECTOR INDEX completed in %v (%.1f rows per second)", dur, rate)
+		}
+		return nil
+	})
+
+	t.L().Printf(
+		"Loading %d rows into %d categories (%d rows total)",
+		data.TrainCount,
+		numCategories,
+		data.TrainCount*numCategories,
+	)
+
+	var fileStart int
+	loadStart := timeutil.Now()
+	// Iterate through the data files in the data set
+	for {
+		filename := data.GetNextTrainFile()
+		hasMore, err := data.Next()
+		require.NoError(t, err)
+		if !hasMore {
+			dur := timeutil.Since(loadStart).Truncate(time.Second)
+			rate := float64(data.TrainCount) / dur.Seconds()
+			t.L().Printf("Data loaded in %v (%.1f rows per second)", dur, rate)
+			break
+		}
+		t.L().Printf("Loading data file: %s", filename)
+
+		// Create workers to load this data file and dispatch part of the file to each of them.
+		m := t.NewGroup()
+		countPerProc := (data.Train.Count / opts.workers) + 1
+		for worker := range opts.workers {
+			start := worker * countPerProc
+			end := min(start+countPerProc, data.Train.Count)
+			m.Go(func(ctx context.Context, l *logger.Logger) error {
+				conn, err := pool.Acquire(ctx)
+				require.NoError(t, err)
+				defer conn.Release()
+
+				for j := start; j < end; j += batchSize {
+					sz := min(j+batchSize, end) - j
+					ri := int(atomic.AddInt32(&rowsInserted, int32(sz)))
+					phaseStr := backfillState(atomic.LoadInt32(&state)).phaseString()
+					vectors := data.Train.Slice(j, sz)
+					err := insertVectors(ctx, conn, worker, numCategories, fileStart+j, phaseStr, vectors, false)
+					if err != nil {
+						return err
+					}
+					// If this is the batch that spanned the create index start threshold, signal the creator.
+					if ri > createIndexStartThresh-sz && ri <= createIndexStartThresh {
+						close(blockBackfill)
+					}
+				}
+				return nil
+			})
+		}
+		// Wait for this batch of loaders
+		m.Wait()
+	}
+
+	// Wait for create index to finish
+	ci.Wait()
+}
+
+func insertVectors(
+	ctx context.Context,
+	conn *pgxpool.Conn,
+	workerID int,
+	numCats int,
+	startIdx int,
+	phaseStr string,
+	vectors vector.Set,
+	excluded bool,
+) error {
+	args := make([]any, vectors.Count*numCats*5)
+	var queryBuilder strings.Builder
+	queryBuilder.Grow(100 + vectors.Count*numCats*34)
+	queryBuilder.WriteString("INSERT INTO vecindex_test " +
+		"(id, category, embedding, inserted_phase, worker_id, excluded) VALUES ")
+	rowNum := 0
+	for i := range vectors.Count {
+		for cat := range numCats {
+			if rowNum > 0 {
+				queryBuilder.WriteString(", ")
+			}
+			j := rowNum * 5
+			fmt.Fprintf(&queryBuilder, "($%d, $%d, $%d, $%d, $%d, %v)", j+1, j+2, j+3, j+4, j+5, excluded)
+			args[j] = makeCanonicalKey(cat, startIdx+i)
+			args[j+1] = cat
+			args[j+2] = vectors.At(i)
+			args[j+3] = phaseStr
+			args[j+4] = workerID
+			rowNum++
+		}
+	}
+	query := queryBuilder.String()
+
+	for {
+		_, err := conn.Exec(ctx, query, args...)
+
+		var pgErr *pgconn.PgError
+		if err != nil && errors.As(err, &pgErr) {
+			switch pgErr.Code {
+			case "40001", "40P01":
+				continue
+			}
+		}
+
+		return errors.Wrapf(err, "Failed to run: %s %v", query, args)
+	}
 }
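
As an illustration (not part of the patch) of the statement insertVectors builds: placeholders are laid out five per row, so row r binds args[5r] through args[5r+4]. For a batch of two vectors, one category, and excluded=false, the generated multi-row INSERT is:

INSERT INTO vecindex_test (id, category, embedding, inserted_phase, worker_id, excluded) VALUES ($1, $2, $3, $4, $5, false), ($6, $7, $8, $9, $10, false)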
