diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index d25a243f2863..946325097036 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -586,6 +586,7 @@ ALL_TESTS = [ "//pkg/sql/privilege:privilege_test", "//pkg/sql/protoreflect:protoreflect_test", "//pkg/sql/querycache:querycache_test", + "//pkg/sql/queuefeed:queuefeed_test", "//pkg/sql/randgen:randgen_test", "//pkg/sql/regions:regions_test", "//pkg/sql/row:row_disallowed_imports_test", @@ -2228,6 +2229,9 @@ GO_TARGETS = [ "//pkg/sql/protoreflect:protoreflect_test", "//pkg/sql/querycache:querycache", "//pkg/sql/querycache:querycache_test", + "//pkg/sql/queuefeed/queuebase:queuebase", + "//pkg/sql/queuefeed:queuefeed", + "//pkg/sql/queuefeed:queuefeed_test", "//pkg/sql/randgen:randgen", "//pkg/sql/randgen:randgen_test", "//pkg/sql/rangeprober:range_prober", diff --git a/pkg/cmd/roachtest/tests/BUILD.bazel b/pkg/cmd/roachtest/tests/BUILD.bazel index c51a84bc1d85..1b2bc3980a91 100644 --- a/pkg/cmd/roachtest/tests/BUILD.bazel +++ b/pkg/cmd/roachtest/tests/BUILD.bazel @@ -167,6 +167,7 @@ go_library( "ptp.go", "query_comparison_util.go", "queue.go", + "queuefeed.go", "quit.go", "rapid_restart.go", "rebalance_load.go", @@ -365,6 +366,7 @@ go_library( "@org_golang_google_protobuf//proto", "@org_golang_x_exp//maps", "@org_golang_x_oauth2//clientcredentials", + "@org_golang_x_sync//errgroup", "@org_golang_x_text//cases", "@org_golang_x_text//language", ], diff --git a/pkg/cmd/roachtest/tests/queuefeed.go b/pkg/cmd/roachtest/tests/queuefeed.go new file mode 100644 index 000000000000..a2a2432d9968 --- /dev/null +++ b/pkg/cmd/roachtest/tests/queuefeed.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+ +package tests + +import ( + "context" + "database/sql" + "fmt" + "math/rand" + "strings" + "sync/atomic" + "time" + + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/option" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec" + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test" + "github.com/cockroachdb/cockroach/pkg/roachprod/install" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func registerQueuefeed(r registry.Registry) { + r.Add(registry.TestSpec{ + Name: "queuefeed", + Owner: registry.OwnerCDC, + Cluster: r.MakeClusterSpec(4, spec.WorkloadNode()), + Run: func(ctx context.Context, t test.Test, c cluster.Cluster) { + runQueuefeed(ctx, t, c) + }, + CompatibleClouds: registry.AllClouds, + Suites: registry.Suites(registry.Nightly), + }) +} + +func runQueuefeed(ctx context.Context, t test.Test, c cluster.Cluster) { + c.Start(ctx, t.L(), option.DefaultStartOpts(), install.MakeClusterSettings(), c.CRDBNodes()) + + db := c.Conn(ctx, t.L(), 1) + defer db.Close() + + _, err := db.ExecContext(ctx, "SET CLUSTER SETTING kv.rangefeed.enabled = true") + require.NoError(t, errors.Wrap(err, "enabling rangefeeds")) + + t.Status("initializing kv workload") + c.Run(ctx, option.WithNodes(c.WorkloadNode()), + "./cockroach workload init kv --splits=100 {pgurl:1}") + + var tableID int64 + err = db.QueryRowContext(ctx, "SELECT id FROM system.namespace WHERE name = 'kv' and \"parentSchemaID\" <> 0;").Scan(&tableID) + require.NoError(t, err) + + t.Status("creating kv_queue") + _, err = db.ExecContext(ctx, "SELECT crdb_internal.create_queue_feed('kv_queue', $1)", tableID) + require.NoError(t, err) + + t.Status("running queue feed queries") + + ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) + defer cancel() + + g, ctx := errgroup.WithContext(ctx) + + const numReaders = 10 + counters := make([]*atomic.Int64, numReaders) + for i := range counters { + counters[i] = &atomic.Int64{} + } + + g.Go(func() error { + return c.RunE(ctx, option.WithNodes(c.WorkloadNode()), + "./cockroach workload run kv --duration=10m {pgurl:1}") + }) + + g.Go(func() error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + lastCounts := make([]int64, numReaders) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + qps := make([]string, numReaders) + for i := 0; i < numReaders; i++ { + currentCount := counters[i].Load() + ratePerSec := currentCount - lastCounts[i] + qps[i] = fmt.Sprintf("%d", ratePerSec) + lastCounts[i] = currentCount + } + t.L().Printf("qps: %s", strings.Join(qps, ",")) + } + } + }) + + dbNodes := 1 // TODO fix bug that occurs with 3 + nodePool := make([]*sql.DB, numReaders) + for i := range dbNodes { + nodePool[i] = c.Conn(ctx, t.L(), i+1) + } + defer func() { + for i := range dbNodes { + _ = nodePool[i].Close() + } + }() + + for i := 0; i < numReaders; i++ { + readerIndex := i + g.Go(func() error { + // Stagger the readers a bit. This helps test re-distribution of + // partitions. + // TODO fix bug that occurs with jitter + // time.Sleep(time.Duration(rand.Intn(int(time.Minute)))) + + // Connect to a random node to simulate a tcp load balancer. 
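+			// Once connected, each reader drains batches of up to 10,000 rows in a
+			// loop and bumps a per-reader counter that the ticker goroutine above
+			// reports as QPS.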
+ conn, err := nodePool[rand.Intn(dbNodes)].Conn(ctx) + if err != nil { + return errors.Wrap(err, "getting connection for the queuefeed reader") + } + defer func() { _ = conn.Close() }() + + for ctx.Err() == nil { + var count int + err := conn.QueryRowContext(ctx, + "SELECT count(*) FROM crdb_internal.select_from_queue_feed('kv_queue', 10000)").Scan(&count) + if err != nil { + return err + } + counters[readerIndex].Add(int64(count)) + } + return ctx.Err() + }) + } + + err = g.Wait() + if err != nil && ctx.Err() == nil { + t.Fatal(err) + } +} diff --git a/pkg/cmd/roachtest/tests/registry.go b/pkg/cmd/roachtest/tests/registry.go index 400d25fa4234..cec911c3bc66 100644 --- a/pkg/cmd/roachtest/tests/registry.go +++ b/pkg/cmd/roachtest/tests/registry.go @@ -131,6 +131,7 @@ func RegisterTests(r registry.Registry) { registerPruneDanglingSnapshotsAndDisks(r) registerPTP(r) registerQueue(r) + registerQueuefeed(r) registerQuitTransfersLeases(r) registerRebalanceLoad(r) registerReplicaGC(r) diff --git a/pkg/kv/txn.go b/pkg/kv/txn.go index c1f6bd824293..c3978bfc8797 100644 --- a/pkg/kv/txn.go +++ b/pkg/kv/txn.go @@ -90,6 +90,9 @@ type Txn struct { // commitTriggers are run upon successful commit. commitTriggers []func(ctx context.Context) + // rollbackTriggers are run upon rollback/abort. + rollbackTriggers []func(ctx context.Context) + // mu holds fields that need to be synchronized for concurrent request execution. mu struct { syncutil.Mutex @@ -1093,6 +1096,16 @@ func (txn *Txn) AddCommitTrigger(trigger func(ctx context.Context)) { txn.commitTriggers = append(txn.commitTriggers, trigger) } +// AddRollbackTrigger adds a closure to be executed on rollback/abort +// of the transaction. +func (txn *Txn) AddRollbackTrigger(trigger func(ctx context.Context)) { + if txn.typ != RootTxn { + panic(errors.AssertionFailedf("AddRollbackTrigger() called on leaf txn")) + } + + txn.rollbackTriggers = append(txn.rollbackTriggers, trigger) +} + // endTxnReqAlloc is used to batch the heap allocations of an EndTxn request. type endTxnReqAlloc struct { req kvpb.EndTxnRequest @@ -1243,6 +1256,9 @@ func (txn *Txn) PrepareForRetry(ctx context.Context) error { // Reset commit triggers. These must be reconfigured by the client during the // next retry. txn.commitTriggers = nil + // Reset rollback triggers. These must be reconfigured by the client during the + // next retry. + txn.rollbackTriggers = nil txn.mu.Lock() defer txn.mu.Unlock() @@ -1383,9 +1399,17 @@ func (txn *Txn) Send( if pErr == nil { // Invoking the commit triggers here ensures they run even in the case when a // commit request is issued manually (not via Commit). - if et, ok := ba.GetArg(kvpb.EndTxn); ok && et.(*kvpb.EndTxnRequest).Commit { - for _, t := range txn.commitTriggers { - t(ctx) + if et, ok := ba.GetArg(kvpb.EndTxn); ok { + if et.(*kvpb.EndTxnRequest).Commit { + for _, t := range txn.commitTriggers { + t(ctx) + } + } else { + // Invoking the rollback triggers here ensures they run even in the case when a + // rollback request is issued manually (not via Rollback). 
+ for _, t := range txn.rollbackTriggers { + t(ctx) + } } } return br, nil diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 3d489bd7eac0..b7ff3731dc43 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -259,6 +259,7 @@ go_library( "//pkg/sql/physicalplan", "//pkg/sql/privilege", "//pkg/sql/querycache", + "//pkg/sql/queuefeed", "//pkg/sql/rangeprober", "//pkg/sql/regions", "//pkg/sql/rolemembershipcache", diff --git a/pkg/server/server_sql.go b/pkg/server/server_sql.go index 11cb4cc2cdbd..daa2f120a613 100644 --- a/pkg/server/server_sql.go +++ b/pkg/server/server_sql.go @@ -89,6 +89,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/optionalnodeliveness" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" "github.com/cockroachdb/cockroach/pkg/sql/querycache" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed" "github.com/cockroachdb/cockroach/pkg/sql/rangeprober" "github.com/cockroachdb/cockroach/pkg/sql/regions" "github.com/cockroachdb/cockroach/pkg/sql/rolemembershipcache" @@ -1062,6 +1063,7 @@ func newSQLServer(ctx context.Context, cfg sqlServerArgs) (*SQLServer, error) { TenantReadOnly: cfg.SQLConfig.TenantReadOnly, CidrLookup: cfg.BaseConfig.CidrLookup, LicenseEnforcer: cfg.SQLConfig.LicenseEnforcer, + QueueManager: queuefeed.NewManager(ctx, cfg.internalDB, cfg.rangeFeedFactory, cfg.rangeDescIteratorFactory, codec, leaseMgr, cfg.sqlLivenessProvider.CachedReader()), } if codec.ForSystemTenant() { @@ -1791,6 +1793,11 @@ func (s *SQLServer) preStart( s.startLicenseEnforcer(ctx, knobs) + // Close queue manager when the stopper stops. + stopper.AddCloser(stop.CloserFn(func() { + s.execCfg.QueueManager.Close() + })) + // Report a warning if the server is being shut down via the stopper // before it was gracefully drained. 
This warning may be innocuous // in tests where there is no use of the test server/cluster after diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 345ac4338c76..018afe3a1e7c 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -478,6 +478,8 @@ go_library( "//pkg/sql/privilege", "//pkg/sql/protoreflect", "//pkg/sql/querycache", + "//pkg/sql/queuefeed", + "//pkg/sql/queuefeed/queuebase", "//pkg/sql/regionliveness", "//pkg/sql/regions", "//pkg/sql/rolemembershipcache", diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 5d041d516b31..a8d8f3f656b9 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -54,6 +54,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirecancel" "github.com/cockroachdb/cockroach/pkg/sql/prep" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" "github.com/cockroachdb/cockroach/pkg/sql/regions" "github.com/cockroachdb/cockroach/pkg/sql/schemachanger/scerrors" "github.com/cockroachdb/cockroach/pkg/sql/sem/asof" @@ -65,6 +67,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sessionmutator" "github.com/cockroachdb/cockroach/pkg/sql/sessionphase" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" "github.com/cockroachdb/cockroach/pkg/sql/sqlstats" "github.com/cockroachdb/cockroach/pkg/sql/sqlstats/insights" "github.com/cockroachdb/cockroach/pkg/sql/sqlstats/persistedsqlstats" @@ -91,6 +94,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tochar" + "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/crlib/crtime" "github.com/cockroachdb/errors" "github.com/cockroachdb/logtags" @@ -1206,6 +1210,7 @@ func (s *Server) newConnExecutor( totalActiveTimeStopWatch: timeutil.NewStopWatch(), txnFingerprintIDCache: NewTxnFingerprintIDCache(ctx, s.cfg.Settings, &txnFingerprintIDCacheAcc), txnFingerprintIDAcc: &txnFingerprintIDCacheAcc, + queuefeedReaders: make(map[string]*queuefeed.Reader), } ex.rng.internal = rand.New(rand.NewSource(timeutil.Now().UnixNano())) @@ -1409,6 +1414,14 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) { ex.state.finishExternalTxn() } + // Close all queuefeed readers + for name, reader := range ex.queuefeedReaders { + if err := reader.Close(); err != nil { + log.Dev.Warningf(ctx, "error closing queuefeed reader %s: %v", name, err) + } + } + ex.queuefeedReaders = nil + ex.resetExtraTxnState(ctx, txnEvent{eventType: txnEvType}, payloadErr) if ex.hasCreatedTemporarySchema && !ex.server.cfg.TestingKnobs.DisableTempObjectsCleanupOnSessionExit { err := cleanupSessionTempObjects( @@ -1902,6 +1915,10 @@ type connExecutor struct { // PCR reader catalog, which is done by checking for the ReplicatedPCRVersion // field on the system database (which is set during tenant bootstrap). isPCRReaderCatalog bool + + // queuefeedReaders stores queuefeed readers created for this connection. + // Readers are closed when the connection closes. 
+ queuefeedReaders map[string]*queuefeed.Reader } // ctxHolder contains a connection's context and, while session tracing is @@ -3815,6 +3832,7 @@ func bufferedWritesIsAllowedForIsolationLevel( func (ex *connExecutor) initEvalCtx(ctx context.Context, evalCtx *extendedEvalContext, p *planner) { *evalCtx = extendedEvalContext{ Context: eval.Context{ + SessionCtx: ex.ctxHolder.ctx(), Planner: p, StreamManagerFactory: p, PrivilegedAccessor: p, @@ -3857,6 +3875,7 @@ func (ex *connExecutor) initEvalCtx(ctx context.Context, evalCtx *extendedEvalCo localSQLStats: ex.server.localSqlStats, indexUsageStats: ex.indexUsageStats, statementPreparer: ex, + QueueReaderProvider: ex, } evalCtx.copyFromExecCfg(ex.server.cfg) } @@ -3910,6 +3929,8 @@ func (ex *connExecutor) GetPCRReaderTimestamp() hlc.Timestamp { // Safe for concurrent use. func (ex *connExecutor) resetEvalCtx(evalCtx *extendedEvalContext, txn *kv.Txn, stmtTS time.Time) { newTxn := txn == nil || evalCtx.Txn != txn + // Keep the session context up to date (accounts for session tracing hijack). + evalCtx.SessionCtx = ex.ctxHolder.ctx() evalCtx.TxnState = ex.getTransactionState() evalCtx.TxnReadOnly = ex.state.readOnly.Load() evalCtx.TxnImplicit = ex.implicitTxn() @@ -4575,6 +4596,49 @@ func (ex *connExecutor) getCreatedSequencesAccessor() createdSequences { } } +// GetOrInitReader gets or creates a queuefeed reader for the given queue name. +// Readers are stored per-connection and closed when the connection closes. +func (ex *connExecutor) GetOrInitReader(ctx context.Context, name string) (queuebase.Reader, error) { + if reader, ok := ex.queuefeedReaders[name]; ok && reader.IsAlive() { + return reader, nil + } + + if ex.server.cfg.QueueManager == nil { + return nil, errors.New("queue manager not configured") + } + mgr := ex.server.cfg.QueueManager + + // Construct Session. + sessionID := ex.planner.extendedEvalCtx.SessionID + connectionIDBytes := sessionID.GetBytes() + connectionID, err := uuid.FromBytes(connectionIDBytes) + if err != nil { + return nil, errors.Wrapf(err, "converting session ID to UUID") + } + + var livenessID sqlliveness.SessionID + if ex.server.cfg.SQLLiveness != nil { + session, err := ex.server.cfg.SQLLiveness.Session(ex.Ctx()) + if err != nil { + return nil, errors.Wrapf(err, "getting sqlliveness session") + } + if session == nil { + return nil, errors.New("sqlliveness session is nil") + } + livenessID = session.ID() + } + + session := queuefeed.Session{ConnectionID: connectionID, LivenessID: livenessID} + + reader, err := mgr.CreateReaderForSession(ctx, name, session) + if err != nil { + return nil, errors.Wrapf(err, "creating reader for queue %s", name) + } + + ex.queuefeedReaders[name] = reader + return reader, nil +} + // sessionEventf logs a message to the session event log (if any). 
func (ex *connExecutor) sessionEventf(ctx context.Context, format string, args ...interface{}) { if log.ExpensiveLogEnabled(ctx, 2) { diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index 06e6a3cf216f..dc587401fc38 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -81,6 +81,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/physicalplan" plpgsqlparser "github.com/cockroachdb/cockroach/pkg/sql/plpgsql/parser" "github.com/cockroachdb/cockroach/pkg/sql/querycache" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" "github.com/cockroachdb/cockroach/pkg/sql/rolemembershipcache" "github.com/cockroachdb/cockroach/pkg/sql/rowenc" "github.com/cockroachdb/cockroach/pkg/sql/rowinfra" @@ -1851,6 +1853,12 @@ type ExecutorConfig struct { // LicenseEnforcer is used to enforce the license profiles. LicenseEnforcer *license.Enforcer + + QueueManager *queuefeed.Manager +} + +func (cfg *ExecutorConfig) GetQueueManager() queuebase.Manager { + return cfg.QueueManager } // UpdateVersionSystemSettingHook provides a callback that allows us diff --git a/pkg/sql/faketreeeval/BUILD.bazel b/pkg/sql/faketreeeval/BUILD.bazel index d8a4c36dade1..4f65a18a7150 100644 --- a/pkg/sql/faketreeeval/BUILD.bazel +++ b/pkg/sql/faketreeeval/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//pkg/sql/pgwire/pgerror", "//pkg/sql/pgwire/pgnotice", "//pkg/sql/privilege", + "//pkg/sql/queuefeed/queuebase", "//pkg/sql/sem/eval", "//pkg/sql/sem/tree", "//pkg/sql/sessiondata", diff --git a/pkg/sql/faketreeeval/evalctx.go b/pkg/sql/faketreeeval/evalctx.go index 5795d935ee71..4094fafb148e 100644 --- a/pkg/sql/faketreeeval/evalctx.go +++ b/pkg/sql/faketreeeval/evalctx.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" @@ -365,6 +366,11 @@ func (p *DummyEvalPlanner) ExtendHistoryRetention(ctx context.Context, id jobspb return errors.WithStack(errEvalPlanner) } +// GetQueueReaderProvider is part of the eval.Planner interface. +func (*DummyEvalPlanner) GetQueueReaderProvider() queuebase.ReaderProvider { + return nil +} + var _ eval.Planner = &DummyEvalPlanner{} var errEvalPlanner = pgerror.New(pgcode.ScalarOperationCannotRunWithoutFullSessionContext, diff --git a/pkg/sql/planner.go b/pkg/sql/planner.go index 2bd5d90bcb58..a15543327853 100644 --- a/pkg/sql/planner.go +++ b/pkg/sql/planner.go @@ -42,6 +42,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/prep" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/querycache" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" "github.com/cockroachdb/cockroach/pkg/sql/regions" "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/eval" @@ -117,6 +118,9 @@ type extendedEvalContext struct { // validateDbZoneConfig should the DB zone config on commit. validateDbZoneConfig *bool + + // QueueReaderProvider provides access to queuefeed readers for this session. + QueueReaderProvider queuebase.ReaderProvider } // copyFromExecCfg copies relevant fields from an ExecutorConfig. 
@@ -613,6 +617,11 @@ func (p *planner) Mon() *mon.BytesMonitor { return p.monitor } +// GetQueueReaderProvider is part of the eval.Planner interface. +func (p *planner) GetQueueReaderProvider() queuebase.ReaderProvider { + return p.extendedEvalCtx.QueueReaderProvider +} + // ExecCfg implements the PlanHookState interface. func (p *planner) ExecCfg() *ExecutorConfig { return p.extendedEvalCtx.ExecCfg diff --git a/pkg/sql/queuefeed/BUILD.bazel b/pkg/sql/queuefeed/BUILD.bazel new file mode 100644 index 000000000000..646a32eb6859 --- /dev/null +++ b/pkg/sql/queuefeed/BUILD.bazel @@ -0,0 +1,86 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "queuefeed", + srcs = [ + "assignments.go", + "manager.go", + "partition_cache.go", + "partitions.go", + "reader.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/sql/queuefeed", + visibility = ["//visibility:public"], + deps = [ + "//pkg/ccl/changefeedccl/changefeedbase", + "//pkg/keys", + "//pkg/kv/kvclient/rangecache", + "//pkg/kv/kvclient/rangefeed", + "//pkg/kv/kvpb", + "//pkg/roachpb", + "//pkg/sql/catalog", + "//pkg/sql/catalog/descpb", + "//pkg/sql/catalog/fetchpb", + "//pkg/sql/catalog/lease", + "//pkg/sql/isql", + "//pkg/sql/queuefeed/queuebase", + "//pkg/sql/row", + "//pkg/sql/rowenc", + "//pkg/sql/sem/tree", + "//pkg/sql/sessiondata", + "//pkg/sql/sqlliveness", + "//pkg/sql/types", + "//pkg/util", + "//pkg/util/hlc", + "//pkg/util/log", + "//pkg/util/rangedesc", + "//pkg/util/span", + "//pkg/util/syncutil", + "//pkg/util/uuid", + "@com_github_cockroachdb_errors//:errors", + ], +) + +go_test( + name = "queuefeed_test", + srcs = [ + "assignments_test.go", + "main_test.go", + "manager_test.go", + "partition_cache_test.go", + "partitions_test.go", + "reader_test.go", + "smoke_test.go", + ], + embed = [":queuefeed"], + shard_count = 4, + deps = [ + "//pkg/base", + "//pkg/kv/kvclient/rangefeed", + "//pkg/roachpb", + "//pkg/security/securityassets", + "//pkg/security/securitytest", + "//pkg/server", + "//pkg/sql", + "//pkg/sql/catalog/lease", + "//pkg/sql/isql", + "//pkg/sql/queuefeed/queuebase", + "//pkg/sql/sem/tree", + "//pkg/sql/sqlliveness", + "//pkg/sql/sqlliveness/slstorage", + "//pkg/testutils", + "//pkg/testutils/serverutils", + "//pkg/testutils/sqlutils", + "//pkg/testutils/testcluster", + "//pkg/util/hlc", + "//pkg/util/leaktest", + "//pkg/util/log", + "//pkg/util/randutil", + "//pkg/util/rangedesc", + "//pkg/util/uuid", + "@com_github_cockroachdb_errors//:errors", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_x_sync//errgroup", + ], +) diff --git a/pkg/sql/queuefeed/assignments.go b/pkg/sql/queuefeed/assignments.go new file mode 100644 index 000000000000..ca302392637c --- /dev/null +++ b/pkg/sql/queuefeed/assignments.go @@ -0,0 +1,356 @@ +package queuefeed + +import ( + "context" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" +) + +type Assignment struct { + // Version is unique per process level and can be used to efficiently detect + // assignment changes. + Version int64 + Session Session + // Partitions is the list of partitions assigned to the session. It is sorted + // by ID. 
+	Partitions []Partition
+}
+
+func (a *Assignment) Spans() []roachpb.Span {
+	sg := roachpb.SpanGroup{}
+	for _, partition := range a.Partitions {
+		sg.Add(partition.Span)
+	}
+	return sg.Slice()
+}
+
+type PartitionAssignments struct {
+	db             isql.DB
+	partitionTable *partitionTable
+
+	refresh struct {
+		// Lock ordering: refresh may be locked before mu.
+		lastRefresh time.Time
+		syncutil.Mutex
+	}
+
+	mu struct {
+		syncutil.Mutex
+		cache partitionCache
+	}
+}
+
+func NewPartitionAssignments(db isql.DB, queueName string) (*PartitionAssignments, error) {
+	pa := &PartitionAssignments{
+		db:             db,
+		partitionTable: &partitionTable{queueName: queueName},
+	}
+
+	var partitions []Partition
+	err := db.Txn(context.Background(), func(ctx context.Context, txn isql.Txn) error {
+		var err error
+		partitions, err = pa.partitionTable.ListPartitions(ctx, txn)
+		return err
+	})
+	if err != nil {
+		return nil, errors.Wrap(err, "unable to load initial partitions")
+	}
+
+	pa.mu.cache.Init(partitions)
+	pa.refresh.lastRefresh = time.Now()
+
+	return pa, nil
+}
+
+func (p *PartitionAssignments) maybeRefreshCache() error {
+	// TODO handle deletions
+	// TODO add a version mechanism to avoid races between write-through updates and refreshes
+	// TODO use a rangefeed instead of polling
+
+	p.refresh.Lock()
+	defer p.refresh.Unlock()
+
+	if time.Since(p.refresh.lastRefresh) < 5*time.Second {
+		return nil
+	}
+
+	var partitions []Partition
+	err := p.db.Txn(context.Background(), func(ctx context.Context, txn isql.Txn) error {
+		var err error
+		partitions, err = p.partitionTable.ListPartitions(ctx, txn)
+		return err
+	})
+	if err != nil {
+		return err
+	}
+
+	updates := make(map[int64]Partition)
+	for _, partition := range partitions {
+		updates[partition.ID] = partition
+	}
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.mu.cache.Update(updates)
+	p.refresh.lastRefresh = time.Now()
+	return nil
+}
+
+// RegisterSession registers a new session. The session may be assigned zero
+// partitions if there are no unassigned partitions. If it is assigned no
+// partitions, the caller can periodically call RefreshAssignment to claim
+// partitions if they become available.
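+//
+// Internally, RegisterSession loops over the cached partition state: it first
+// tries to claim an unassigned partition and otherwise marks itself as the
+// successor of an already-claimed one, retrying until an attempt commits or
+// there is nothing left to plan.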
+func (p *PartitionAssignments) RegisterSession(
+	ctx context.Context, session Session,
+) (*Assignment, error) {
+	if err := p.maybeRefreshCache(); err != nil {
+		return nil, errors.Wrap(err, "refreshing partition cache")
+	}
+
+	var err error
+	var done bool
+	for !done {
+		tryClaim, trySteal := func() (Partition, Partition) {
+			p.mu.Lock()
+			defer p.mu.Unlock()
+			return p.mu.cache.planRegister(session, p.mu.cache)
+		}()
+		switch {
+		case !tryClaim.Empty():
+			err, done = p.tryClaim(session, tryClaim)
+			if err != nil {
+				return nil, errors.Wrap(err, "claiming partition")
+			}
+		case !trySteal.Empty():
+			err, done = p.trySteal(session, trySteal)
+			if err != nil {
+				return nil, errors.Wrap(err, "stealing partition")
+			}
+		default:
+			done = true
+		}
+	}
+	return p.constructAssignment(session), nil
+}
+
+func (p *PartitionAssignments) tryClaim(session Session, toClaim Partition) (error, bool) {
+	var updates map[int64]Partition
+	var done bool
+	err := p.db.Txn(context.Background(), func(ctx context.Context, txn isql.Txn) error {
+		done, updates = false, nil
+
+		var err error
+		updates, err = p.anyStale(ctx, txn, []Partition{toClaim})
+		if err != nil || len(updates) != 0 {
+			return err
+		}
+
+		updates = make(map[int64]Partition)
+		toClaim.Session = session
+		updates[toClaim.ID] = toClaim
+		if err := p.partitionTable.UpdatePartition(ctx, txn, toClaim); err != nil {
+			return err
+		}
+
+		done = true
+		return nil
+	})
+	if err != nil {
+		return err, false
+	}
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.mu.cache.Update(updates)
+	return nil, done
+}
+
+func (p *PartitionAssignments) trySteal(session Session, toSteal Partition) (error, bool) {
+	var updates map[int64]Partition
+	var done bool
+	err := p.db.Txn(context.Background(), func(ctx context.Context, txn isql.Txn) error {
+		done, updates = false, nil
+
+		var err error
+		updates, err = p.anyStale(ctx, txn, []Partition{toSteal})
+		if err != nil || len(updates) != 0 {
+			return err
+		}
+
+		updates = make(map[int64]Partition)
+		toSteal.Successor = session
+		updates[toSteal.ID] = toSteal
+		if err := p.partitionTable.UpdatePartition(ctx, txn, toSteal); err != nil {
+			return err
+		}
+
+		done = true
+		return nil
+	})
+	if err != nil {
+		return err, false
+	}
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.mu.cache.Update(updates)
+	return nil, done
+}
+
+// RefreshAssignment refreshes the assignment for the given session. It returns
+// nil if the assignment has not changed.
+//
+// If the session is caught up (i.e. it has processed up to a recent timestamp
+// for all assigned partitions), then it may be assigned new partitions.
+//
+// If a partition has a successor session, then calling RefreshAssignment will
+// return an assignment that does not include that partition.
+func (p *PartitionAssignments) RefreshAssignment( + ctx context.Context, assignment *Assignment, caughtUp bool, +) (*Assignment, error) { + if err := p.maybeRefreshCache(); err != nil { + return nil, errors.Wrap(err, "refreshing partition cache") + } + + var done bool + var err error + for !done { + tryRelease, tryClaim, trySteal := func() ([]Partition, Partition, Partition) { + p.mu.Lock() + defer p.mu.Unlock() + return p.mu.cache.planAssignment(assignment.Session, caughtUp, p.mu.cache) + }() + switch { + case len(tryRelease) != 0: + err, done = p.tryRelease(assignment.Session, tryRelease) + if err != nil { + return nil, errors.Wrap(err, "releasing partition") + } + case !tryClaim.Empty(): + err, done = p.tryClaim(assignment.Session, tryClaim) + if err != nil { + return nil, errors.Wrap(err, "claiming partition") + } + case !trySteal.Empty(): + err, done = p.trySteal(assignment.Session, trySteal) + if err != nil { + return nil, errors.Wrap(err, "stealing partition") + } + default: + stale := func() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.mu.cache.isStale(assignment) + }() + if !stale { + return nil, nil + } + done = true + } + } + if err != nil { + return nil, err + } + return p.constructAssignment(assignment.Session), nil +} + +// anyStale checks if any of the provided partitions have become stale by +// comparing them with the current state in the database. Returns a map of +// partition ID to the updated partition that can be applied to the cache. +func (p *PartitionAssignments) anyStale( + ctx context.Context, txn isql.Txn, partitions []Partition, +) (map[int64]Partition, error) { + if len(partitions) == 0 { + return make(map[int64]Partition), nil + } + + // Extract partition IDs + partitionIDs := make([]int64, len(partitions)) + for i, partition := range partitions { + partitionIDs[i] = partition.ID + } + + // Fetch current state from database + currentPartitions, err := p.partitionTable.FetchPartitions(ctx, txn, partitionIDs) + if err != nil { + return nil, err + } + + // Compare cached vs current state and collect stale partitions + stalePartitions := make(map[int64]Partition) + for _, cachedPartition := range partitions { + currentPartition := currentPartitions[cachedPartition.ID] + + // If partition was deleted from database, mark it as empty in updates + if currentPartition.Empty() { + stalePartitions[cachedPartition.ID] = Partition{} + } else if !cachedPartition.Equal(currentPartition) { + // If partition has changed, include the updated version + stalePartitions[cachedPartition.ID] = currentPartition + } + } + + return stalePartitions, nil +} + +func (p *PartitionAssignments) tryRelease(session Session, toRelease []Partition) (error, bool) { + var updates map[int64]Partition + var done bool + err := p.db.Txn(context.Background(), func(ctx context.Context, txn isql.Txn) error { + done, updates = false, nil + + var err error + updates, err = p.anyStale(ctx, txn, toRelease) + if err != nil || len(updates) != 0 { + return err + } + + updates = make(map[int64]Partition) + for _, partition := range toRelease { + partition.Session = partition.Successor + updates[partition.ID] = partition + if err := p.partitionTable.UpdatePartition(ctx, txn, partition); err != nil { + return err + } + } + + done = true + return nil + }) + if err != nil { + return err, false + } + + p.mu.Lock() + defer p.mu.Unlock() + p.mu.cache.Update(updates) + return nil, done +} + +func (p *PartitionAssignments) constructAssignment(session Session) *Assignment { + p.mu.Lock() + defer p.mu.Unlock() + 
return p.mu.cache.constructAssignment(session) +} + +func (p *PartitionAssignments) UnregisterSession(ctx context.Context, session Session) error { + var updates map[int64]Partition + err := p.db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + var err error + updates, err = p.partitionTable.UnregisterSession(ctx, txn, session) + return err + }) + if err != nil { + return err + } + + p.mu.Lock() + defer p.mu.Unlock() + p.mu.cache.Update(updates) + + return nil +} diff --git a/pkg/sql/queuefeed/assignments_test.go b/pkg/sql/queuefeed/assignments_test.go new file mode 100644 index 000000000000..9655f46f00db --- /dev/null +++ b/pkg/sql/queuefeed/assignments_test.go @@ -0,0 +1,66 @@ +package queuefeed_test + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/sql" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/require" +) + +func TestPartitionAssignments(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer s.Stopper().Stop(ctx) + + tdb := sqlutils.MakeSQLRunner(sqlDB) + tdb.Exec(t, "CREATE TABLE test_table (id INT PRIMARY KEY, data TEXT)") + + var tableDescID int64 + tdb.QueryRow(t, "SELECT id FROM system.namespace WHERE name = 'test_table'").Scan(&tableDescID) + + // Create queue using QueueManager + manager := queuefeed.NewTestManager(t, s.ApplicationLayer()) + defer manager.Close() + queueName := "test_queue" + err := manager.CreateQueue(ctx, queueName, tableDescID) + require.NoError(t, err) + + pa, err := queuefeed.NewPartitionAssignments(s.ExecutorConfig().(sql.ExecutorConfig).InternalDB, queueName) + require.NoError(t, err) + + session := queuefeed.Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: sqlliveness.SessionID("1"), + } + + assignment, err := pa.RegisterSession(ctx, session) + require.NoError(t, err) + require.Len(t, assignment.Partitions, 1) + require.Equal(t, session, assignment.Partitions[0].Session, "partition: %+v", assignment.Partitions[0]) + + tdb.CheckQueryResults(t, + "SELECT sql_liveness_session, user_session FROM defaultdb.queue_partition_"+queueName, + [][]string{{"1", session.ConnectionID.String()}}) + + newAssignment, err := pa.RefreshAssignment(context.Background(), assignment, true) + require.NoError(t, err) + require.Nil(t, newAssignment) + + require.NoError(t, pa.UnregisterSession(ctx, session)) + + tdb.CheckQueryResults(t, + "SELECT sql_liveness_session, user_session FROM defaultdb.queue_partition_"+queueName, + [][]string{{"NULL", "NULL"}}) +} diff --git a/pkg/sql/queuefeed/main_test.go b/pkg/sql/queuefeed/main_test.go new file mode 100644 index 000000000000..a0065c2e5693 --- /dev/null +++ b/pkg/sql/queuefeed/main_test.go @@ -0,0 +1,26 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. 
+
+package queuefeed_test
+
+import (
+	"os"
+	"testing"
+
+	"github.com/cockroachdb/cockroach/pkg/security/securityassets"
+	"github.com/cockroachdb/cockroach/pkg/security/securitytest"
+	"github.com/cockroachdb/cockroach/pkg/server"
+	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
+	"github.com/cockroachdb/cockroach/pkg/testutils/testcluster"
+	"github.com/cockroachdb/cockroach/pkg/util/randutil"
+)
+
+func TestMain(m *testing.M) {
+	securityassets.SetLoader(securitytest.EmbeddedAssets)
+	randutil.SeedForTests()
+	serverutils.InitTestServerFactory(server.TestServerFactory)
+	serverutils.InitTestClusterFactory(testcluster.TestClusterFactory)
+	os.Exit(m.Run())
+}
diff --git a/pkg/sql/queuefeed/manager.go b/pkg/sql/queuefeed/manager.go
new file mode 100644
index 000000000000..0c87823c9af3
--- /dev/null
+++ b/pkg/sql/queuefeed/manager.go
@@ -0,0 +1,462 @@
+// Package queuefeed implements partitioned queue feeds over SQL tables.
+package queuefeed
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/keys"
+	"github.com/cockroachdb/cockroach/pkg/kv/kvclient/rangecache"
+	"github.com/cockroachdb/cockroach/pkg/kv/kvclient/rangefeed"
+	"github.com/cockroachdb/cockroach/pkg/roachpb"
+	"github.com/cockroachdb/cockroach/pkg/sql/catalog"
+	"github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb"
+	"github.com/cockroachdb/cockroach/pkg/sql/catalog/lease"
+	"github.com/cockroachdb/cockroach/pkg/sql/isql"
+	"github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase"
+	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
+	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
+	"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness"
+	"github.com/cockroachdb/cockroach/pkg/util/hlc"
+	"github.com/cockroachdb/cockroach/pkg/util/log"
+	"github.com/cockroachdb/cockroach/pkg/util/rangedesc"
+	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
+	"github.com/cockroachdb/errors"
+)
+
+// Manager creates the queue metadata tables, hands out readers for queue
+// feeds, and watches the queue partition tables for claims held by dead
+// sessions.
+type Manager struct {
+	executor          isql.DB
+	rff               *rangefeed.Factory
+	rdi               rangedesc.IteratorFactory
+	rc                *rangecache.RangeCache
+	codec             keys.SQLCodec
+	leaseMgr          *lease.Manager
+	sqlLivenessReader sqlliveness.Reader
+
+	mu struct {
+		syncutil.Mutex
+		queueAssignment map[string]*PartitionAssignments
+	}
+
+	// watchCtx and watchCancel are used to control the watchForDeadSessions goroutine.
+ watchCtx context.Context + watchCancel context.CancelFunc + watchWg sync.WaitGroup +} + +func NewManager( + ctx context.Context, + executor isql.DB, + rff *rangefeed.Factory, + rdi rangedesc.IteratorFactory, + codec keys.SQLCodec, + leaseMgr *lease.Manager, + sqlLivenessReader sqlliveness.Reader, +) *Manager { + // setup rangefeed on partitions table (/poll) + // handle handoff from one server to another + watchCtx, watchCancel := context.WithCancel(ctx) + m := &Manager{ + executor: executor, + rff: rff, + rdi: rdi, + codec: codec, + leaseMgr: leaseMgr, + sqlLivenessReader: sqlLivenessReader, + watchCtx: watchCtx, + watchCancel: watchCancel, + } + m.mu.queueAssignment = make(map[string]*PartitionAssignments) + + m.watchWg.Add(1) + go func() { + defer m.watchWg.Done() + m.watchForDeadSessions(watchCtx) + }() + + return m +} + +const createQueueCursorTableSQL = ` +CREATE TABLE IF NOT EXISTS defaultdb.queue_cursor_%s ( + partition_id INT8 PRIMARY KEY, + updated_at TIMESTAMPTZ, + cursor bytea +)` + +const createQueueTableSQL = ` +CREATE TABLE IF NOT EXISTS defaultdb.queue_feeds ( + queue_feed_name STRING PRIMARY KEY, + table_desc_id INT8 NOT NULL +)` + +const insertQueueFeedSQL = ` +INSERT INTO defaultdb.queue_feeds (queue_feed_name, table_desc_id) VALUES ($1, $2) +` + +const fetchQueueFeedSQL = ` +SELECT table_desc_id FROM defaultdb.queue_feeds WHERE queue_feed_name = $1 +` + +const updateCheckpointSQL = ` +UPSERT INTO defaultdb.queue_cursor_%s (partition_id, updated_at, cursor) +VALUES ($1, now(), $2) +` + +const readCheckpointSQL = ` +SELECT cursor FROM defaultdb.queue_cursor_%s WHERE partition_id = $1 +` + +// should take a txn +func (m *Manager) CreateQueue(ctx context.Context, queueName string, tableDescID int64) error { + return m.CreateQueueFromCursor(ctx, queueName, tableDescID, hlc.Timestamp{}) +} +func (m *Manager) CreateQueueFromCursor(ctx context.Context, queueName string, tableDescID int64, cursor hlc.Timestamp) error { + err := m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + _, err := txn.Exec(ctx, "create_q", txn.KV(), createQueueTableSQL) + if err != nil { + return err + } + + pt := &partitionTable{queueName: queueName} + err = pt.CreateSchema(ctx, txn) + if err != nil { + return err + } + + _, err = txn.Exec(ctx, "create_qc", txn.KV(), fmt.Sprintf(createQueueCursorTableSQL, queueName)) + if err != nil { + return err + } + return nil + }) + if err != nil { + return errors.Wrapf(err, "creating queue tables for %s", queueName) + } + + return m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + // TODO(queuefeed): figure out how we want to integrate with schema changes. + descriptor, err := m.leaseMgr.Acquire(ctx, lease.TimestampToReadTimestamp(txn.KV().ReadTimestamp()), descpb.ID(tableDescID)) + if err != nil { + return err + } + tableDesc := descriptor.Underlying().(catalog.TableDescriptor) + defer descriptor.Release(ctx) + + _, err = txn.Exec(ctx, "insert_q", txn.KV(), insertQueueFeedSQL, queueName, tableDescID) + if err != nil { + return err + } + + pt := &partitionTable{queueName: queueName} + + // Make a partition for each range of the table's primary span, covering the span of that range. 
+ primaryIndexPrefix := m.codec.IndexPrefix(uint32(tableDesc.GetID()), uint32(tableDesc.GetPrimaryIndexID())) + primaryKeySpan := roachpb.Span{ + Key: primaryIndexPrefix, + EndKey: primaryIndexPrefix.PrefixEnd(), + } + + spans, err := m.splitOnRanges(ctx, primaryKeySpan) + if err != nil { + return err + } + + partitionID := int64(1) + for _, span := range spans { + partition := Partition{ + ID: partitionID, + Span: span, + } + + if err := pt.InsertPartition(ctx, txn, partition); err != nil { + return errors.Wrapf(err, "inserting partition %d for range", partitionID) + } + + checkpointTS := txn.KV().ReadTimestamp() + if !cursor.IsEmpty() { + checkpointTS = cursor + } + // checkpoint the partition at the transaction timestamp + err = m.WriteCheckpoint(ctx, queueName, partitionID, checkpointTS) + if err != nil { + return errors.Wrapf(err, "writing checkpoint for partition %d", partitionID) + } + + partitionID++ + } + + return nil + }) +} + +func (m *Manager) newReaderLocked( + ctx context.Context, name string, session Session, +) (*Reader, error) { + var tableDescID int64 + + // TODO: this ctx on the other hand should be stmt scoped + err := m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + _, err := txn.Exec(ctx, "create_q", txn.KV(), createQueueTableSQL) + if err != nil { + return err + } + + vals, err := txn.QueryRowEx(ctx, "fetch_q", txn.KV(), + sessiondata.NodeUserSessionDataOverride, fetchQueueFeedSQL, name) + if err != nil { + return err + } + if len(vals) == 0 { + return errors.Errorf("queue feed not found") + } + tableDescID = int64(tree.MustBeDInt(vals[0])) + return nil + }) + if err != nil { + return nil, err + } + + assigner, ok := m.mu.queueAssignment[name] + if !ok { + var err error + assigner, err = NewPartitionAssignments(m.executor, name) + if err != nil { + return nil, err + } + m.mu.queueAssignment[name] = assigner + } + + fmt.Printf("get or init reader for queue %s with table desc id: %d\n", name, tableDescID) + return NewReader(ctx, m.executor, m, m.rff, m.codec, m.leaseMgr, session, assigner, name) +} + +func (m *Manager) reassessAssignments(ctx context.Context, name string) (bool, error) { + return false, nil +} + +// CreateReaderForSession creates a new reader for the given queue name and session. 
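+// It verifies that the queue feed is registered in defaultdb.queue_feeds and
+// shares a single PartitionAssignments instance per queue across all readers.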
+func (m *Manager) CreateReaderForSession( + ctx context.Context, name string, session Session, +) (*Reader, error) { + m.mu.Lock() + defer m.mu.Unlock() + return m.newReaderLocked(ctx, name, session) +} + +func (m *Manager) WriteCheckpoint( + ctx context.Context, queueName string, partitionID int64, ts hlc.Timestamp, +) error { + // Serialize the timestamp as bytes + cursorBytes, err := ts.Marshal() + if err != nil { + return errors.Wrap(err, "marshaling checkpoint timestamp") + } + + sql := fmt.Sprintf(updateCheckpointSQL, queueName) + return m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + _, err := txn.Exec(ctx, "write_checkpoint", txn.KV(), sql, partitionID, cursorBytes) + return err + }) +} + +func (m *Manager) ReadCheckpoint( + ctx context.Context, queueName string, partitionID int64, +) (hlc.Timestamp, error) { + var ts hlc.Timestamp + sql := fmt.Sprintf(readCheckpointSQL, queueName) + + err := m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + row, err := txn.QueryRowEx(ctx, "read_checkpoint", txn.KV(), + sessiondata.NodeUserSessionDataOverride, sql, partitionID) + if err != nil { + return err + } + if row == nil { + return nil + } + + cursorBytes := []byte(*row[0].(*tree.DBytes)) + if err := ts.Unmarshal(cursorBytes); err != nil { + return errors.Wrap(err, "unmarshaling checkpoint timestamp") + } + return nil + }) + + return ts, err +} + +func (m *Manager) splitOnRanges(ctx context.Context, span roachpb.Span) ([]roachpb.Span, error) { + const pageSize = 100 + rdi, err := m.rdi.NewLazyIterator(ctx, span, pageSize) + if err != nil { + return nil, err + } + + var spans []roachpb.Span + remainingSpan := span + + for ; rdi.Valid(); rdi.Next() { + rangeDesc := rdi.CurRangeDescriptor() + rangeSpan := roachpb.Span{Key: rangeDesc.StartKey.AsRawKey(), EndKey: rangeDesc.EndKey.AsRawKey()} + subspan := remainingSpan.Intersect(rangeSpan) + if !subspan.Valid() { + return nil, errors.AssertionFailedf("%s not in %s of %s", rangeSpan, remainingSpan, span) + } + spans = append(spans, subspan) + remainingSpan.Key = subspan.EndKey + } + + if err := rdi.Error(); err != nil { + return nil, err + } + + if remainingSpan.Valid() { + spans = append(spans, remainingSpan) + } + + return spans, nil +} + +// A loop that looks for partitions that are assigned to sql liveness sessions +// that are no longer alive and removes all of their partition claims. (see the +// IsAlive method in the sqlliveness packages) +func (m *Manager) watchForDeadSessions(ctx context.Context) { + // Check for dead sessions every 10 seconds. + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := m.checkAndClearDeadSessions(ctx); err != nil { + log.Dev.Warningf(ctx, "error checking for dead sessions: %v", err) + } + } + } +} + +const listQueueFeedsSQL = `SELECT queue_feed_name FROM defaultdb.queue_feeds` + +// checkAndClearDeadSessions checks all partitions across all queues for dead sessions +// and clears their claims. +func (m *Manager) checkAndClearDeadSessions(ctx context.Context) error { + // Get all queue names. 
+ var queueNames []string + err := m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + rows, err := txn.QueryBuffered(ctx, "list-queue-feeds", txn.KV(), listQueueFeedsSQL) + if err != nil { + return err + } + queueNames = make([]string, 0, len(rows)) + for _, row := range rows { + queueNames = append(queueNames, string(tree.MustBeDString(row[0]))) + } + return nil + }) + if err != nil { + return errors.Wrap(err, "listing queue feeds") + } + + // Check each queue for dead sessions. + for _, queueName := range queueNames { + if err := m.checkQueueForDeadSessions(ctx, queueName); err != nil { + log.Dev.Warningf(ctx, "error checking queue %s for dead sessions: %v", queueName, err) + // Continue checking other queues even if one fails. + } + } + + return nil +} + +// checkQueueForDeadSessions checks all partitions in a queue for dead sessions +// and clears their claims. +func (m *Manager) checkQueueForDeadSessions(ctx context.Context, queueName string) error { + pt := &partitionTable{queueName: queueName} + var partitionsToUpdate []Partition + + err := m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + partitions, err := pt.ListPartitions(ctx, txn) + if err != nil { + return err + } + + for _, partition := range partitions { + needsUpdate := false + updatedPartition := partition + + // Check if the Session is assigned to a dead session. + if partition.Session.LivenessID != "" { + alive, err := m.sqlLivenessReader.IsAlive(ctx, partition.Session.LivenessID) + if err != nil { + // If we can't determine liveness, err on the side of caution and don't clear. + log.Dev.Warningf(ctx, "error checking liveness for session %s: %v", partition.Session.LivenessID, err) + continue + } + if !alive { + // Session is dead. Clear the claim. + // If there's a successor, promote it to Session. + if !partition.Successor.Empty() { + updatedPartition.Session = partition.Successor + updatedPartition.Successor = Session{} + } else { + updatedPartition.Session = Session{} + } + needsUpdate = true + } + } + + // Check if the Successor is assigned to a dead session. + if partition.Successor.LivenessID != "" { + alive, err := m.sqlLivenessReader.IsAlive(ctx, partition.Successor.LivenessID) + if err != nil { + log.Dev.Warningf(ctx, "error checking liveness for successor session %s: %v", partition.Successor.LivenessID, err) + continue + } + if !alive { + // Successor session is dead. Clear it. + updatedPartition.Successor = Session{} + needsUpdate = true + } + } + + if needsUpdate { + partitionsToUpdate = append(partitionsToUpdate, updatedPartition) + } + } + + return nil + }) + if err != nil { + return errors.Wrapf(err, "listing partitions for queue %s", queueName) + } + + // Update partitions that need to be cleared. + if len(partitionsToUpdate) > 0 { + return m.executor.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + for _, partition := range partitionsToUpdate { + if err := pt.UpdatePartition(ctx, txn, partition); err != nil { + return errors.Wrapf(err, "updating partition %d for queue %s", partition.ID, queueName) + } + fmt.Printf("pruning dead sessions: updated partition %d for queue %s\n", partition.ID, queueName) + } + return nil + }) + } + + return nil +} + +// Close stops the Manager and waits for all background goroutines to exit. 
+func (m *Manager) Close() { + m.watchCancel() + m.watchWg.Wait() +} + +var _ queuebase.Manager = &Manager{} diff --git a/pkg/sql/queuefeed/manager_test.go b/pkg/sql/queuefeed/manager_test.go new file mode 100644 index 000000000000..2013626f6ef8 --- /dev/null +++ b/pkg/sql/queuefeed/manager_test.go @@ -0,0 +1,242 @@ +package queuefeed + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/rangefeed" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/lease" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slstorage" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/rangedesc" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func NewTestManager(t *testing.T, a serverutils.ApplicationLayerInterface) *Manager { + db := a.InternalDB().(isql.DB) + m := NewManager(context.Background(), db, a.RangeFeedFactory().(*rangefeed.Factory), a.RangeDescIteratorFactory().(rangedesc.IteratorFactory), a.Codec(), a.LeaseManager().(*lease.Manager), nil) + require.NotNil(t, m.codec) + t.Cleanup(m.Close) + return m +} + +func TestFeedCreation(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + // expect an error when trying to read from a queue that doesn't exist + qm := NewTestManager(t, srv.ApplicationLayer()) + defer qm.Close() + _, err := qm.CreateReaderForSession(context.Background(), "test", Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: "", + }) + require.ErrorContains(t, err, "queue feed not found") + + // expect no error when creating a queue + db := sqlutils.MakeSQLRunner(conn) + db.Exec(t, `CREATE TABLE t (a string)`) + // get table id + var tableID int64 + db.QueryRow(t, "SELECT id FROM system.namespace where name = 't'").Scan(&tableID) + require.NoError(t, qm.CreateQueue(context.Background(), "test", tableID)) + + // now we can read from the queue + reader, err := qm.CreateReaderForSession(context.Background(), "test", Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: "", + }) + require.NoError(t, err) + require.NotNil(t, reader) + _ = reader.Close() +} + +func TestQueuefeedCtxCancel(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(conn) + db.Exec(t, `CREATE TABLE t (a string)`) + db.Exec(t, `SELECT crdb_internal.create_queue_feed('hi', 't')`) + + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + _, err := db.DB.QueryContext(ctx, `SELECT crdb_internal.select_from_queue_feed('hi', 1)`) + require.Error(t, err) +} + +func TestWatchForDeadSessions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + // Create a 
fake SQL liveness storage for testing + fakeStorage := slstorage.NewFakeStorage() + + // Create a manager with the fake storage + db := srv.ApplicationLayer().InternalDB().(isql.DB) + m := NewManager( + ctx, db, srv.ApplicationLayer().RangeFeedFactory().(*rangefeed.Factory), + srv.ApplicationLayer().RangeDescIteratorFactory().(rangedesc.IteratorFactory), srv.ApplicationLayer().Codec(), srv.ApplicationLayer().LeaseManager().(*lease.Manager), + fakeStorage, + ) + + // Create a queue + sqlDB := sqlutils.MakeSQLRunner(conn) + sqlDB.Exec(t, `CREATE TABLE t (a string PRIMARY KEY)`) + var tableID int64 + sqlDB.QueryRow(t, "SELECT id FROM system.namespace where name = 't'").Scan(&tableID) + + // Create multiple ranges BEFORE creating the queue to ensure we have enough partitions + sqlDB.Exec(t, `INSERT INTO t (a) SELECT generate_series(1, 100)`) + sqlDB.Exec(t, `ALTER TABLE t SPLIT AT VALUES ('10'), ('20'), ('30'), ('40'), ('50'), ('60'), ('70'), ('80'), ('90')`) + sqlDB.Exec(t, `ALTER TABLE t SCATTER`) + + // Now create the queue - it will create partitions for all the ranges + require.NoError(t, m.CreateQueue(ctx, "test", tableID)) + + // Get partitions + pt := &partitionTable{queueName: "test"} + var partitions []Partition + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + var err error + partitions, err = pt.ListPartitions(ctx, txn) + return err + }) + require.NoError(t, err) + require.GreaterOrEqual(t, len(partitions), 4, "should have at least 4 partitions for this test") + + // Create two sessions with liveness IDs + deadSessionID := sqlliveness.SessionID("dead-session") + aliveSessionID := sqlliveness.SessionID("alive-session") + successorSessionID := sqlliveness.SessionID("successor-session") + + deadSession := Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: deadSessionID, + } + aliveSession := Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: aliveSessionID, + } + successorSession := Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: successorSessionID, + } + + // Mark sessions as alive in fake storage + clock := srv.ApplicationLayer().Clock() + expiration := clock.Now().Add(10*time.Second.Nanoseconds(), 0) + require.NoError(t, fakeStorage.Insert(ctx, deadSessionID, expiration)) + require.NoError(t, fakeStorage.Insert(ctx, aliveSessionID, expiration)) + require.NoError(t, fakeStorage.Insert(ctx, successorSessionID, expiration)) + + // Assign some partitions to the dead session + deadPartition := partitions[0] + deadPartition.Session = deadSession + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.UpdatePartition(ctx, txn, deadPartition) + }) + require.NoError(t, err) + + // Assign a partition with a dead session and a successor + deadWithSuccessorPartition := partitions[1] + deadWithSuccessorPartition.Session = deadSession + deadWithSuccessorPartition.Successor = successorSession + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.UpdatePartition(ctx, txn, deadWithSuccessorPartition) + }) + require.NoError(t, err) + + // Assign a partition with a dead successor + deadSuccessorPartition := partitions[2] + deadSuccessorPartition.Session = aliveSession + deadSuccessorPartition.Successor = deadSession + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.UpdatePartition(ctx, txn, deadSuccessorPartition) + }) + require.NoError(t, err) + + // Assign a partition to an alive session (should remain unchanged) + alivePartition := partitions[3] + alivePartition.Session = aliveSession + 
err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.UpdatePartition(ctx, txn, alivePartition) + }) + require.NoError(t, err) + + // Mark the dead session as dead by deleting it from fake storage + require.NoError(t, fakeStorage.Delete(ctx, deadSessionID)) + + // Check for dead sessions + require.NoError(t, m.checkQueueForDeadSessions(ctx, "test")) + + // Verify results + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + updatedPartitions, err := pt.ListPartitions(ctx, txn) + require.NoError(t, err) + + // Find the partitions we updated + var foundDeadPartition, foundDeadWithSuccessorPartition, foundDeadSuccessorPartition, foundAlivePartition *Partition + for i := range updatedPartitions { + p := &updatedPartitions[i] + if p.ID == deadPartition.ID { + foundDeadPartition = p + } else if p.ID == deadWithSuccessorPartition.ID { + foundDeadWithSuccessorPartition = p + } else if p.ID == deadSuccessorPartition.ID { + foundDeadSuccessorPartition = p + } else if p.ID == alivePartition.ID { + foundAlivePartition = p + } + } + + // Dead session partition should be cleared + require.NotNil(t, foundDeadPartition, "should find dead partition") + assert.True(t, foundDeadPartition.Session.Empty(), "dead session partition should be cleared") + assert.True(t, foundDeadPartition.Successor.Empty(), "dead session partition should have no successor") + + // Dead session with successor should promote successor to session + require.NotNil(t, foundDeadWithSuccessorPartition, "should find dead with successor partition") + assert.Equal(t, successorSession, foundDeadWithSuccessorPartition.Session, "successor should be promoted to session") + assert.True(t, foundDeadWithSuccessorPartition.Successor.Empty(), "successor should be cleared") + + // Dead successor should be cleared + require.NotNil(t, foundDeadSuccessorPartition, "should find dead successor partition") + assert.Equal(t, aliveSession, foundDeadSuccessorPartition.Session, "alive session should remain") + assert.True(t, foundDeadSuccessorPartition.Successor.Empty(), "dead successor should be cleared") + + // Alive session partition should remain unchanged + require.NotNil(t, foundAlivePartition, "should find alive partition") + assert.Equal(t, aliveSession, foundAlivePartition.Session, "alive session should remain unchanged") + assert.True(t, foundAlivePartition.Successor.Empty(), "alive partition should have no successor") + + return nil + }) + require.NoError(t, err) + + // Close the manager to wait for the watchForDeadSessions goroutine to exit + m.Close() +} diff --git a/pkg/sql/queuefeed/partition_cache.go b/pkg/sql/queuefeed/partition_cache.go new file mode 100644 index 000000000000..774ce0711c95 --- /dev/null +++ b/pkg/sql/queuefeed/partition_cache.go @@ -0,0 +1,343 @@ +package queuefeed + +import ( + "fmt" + "math/rand" + "sort" + "strings" +) + +type partitionCache struct { + // partitions is a stale cache of the current state of the partition's + // table. Any assignment decisions and updates should be made using + // transactions. + partitions map[int64]Partition + + // assignmentIndex is a map of sessions to assigned partitions. + assignmentIndex map[Session]map[int64]struct{} + + // successorIndex is a map of successor sessions to partitions. 
+ successorIndex map[Session]map[int64]struct{} + + sessions map[Session]struct{} +} + +func (p *partitionCache) DebugString() string { + var result strings.Builder + + result.WriteString("PartitionCache Debug:\n") + result.WriteString("===================\n\n") + + // Print partitions + result.WriteString("Partitions:\n") + if len(p.partitions) == 0 { + result.WriteString(" (none)\n") + } else { + for id, partition := range p.partitions { + result.WriteString(fmt.Sprintf(" ID: %d", id)) + if !partition.Session.Empty() { + result.WriteString(fmt.Sprintf(" | Session: %s", partition.Session.ConnectionID.String()[:8])) + } else { + result.WriteString(" | Session: (unassigned)") + } + result.WriteString("\n") + } + } + + // Print assignment index + result.WriteString("\nAssignment Index (session -> partitions):\n") + if len(p.assignmentIndex) == 0 { + result.WriteString(" (none)\n") + } else { + for session, partitions := range p.assignmentIndex { + result.WriteString(fmt.Sprintf(" %s: [", session.ConnectionID.String()[:8])) + partitionIDs := make([]int64, 0, len(partitions)) + for id := range partitions { + partitionIDs = append(partitionIDs, id) + } + sort.Slice(partitionIDs, func(i, j int) bool { + return partitionIDs[i] < partitionIDs[j] + }) + for i, id := range partitionIDs { + if i > 0 { + result.WriteString(", ") + } + result.WriteString(fmt.Sprintf("%d", id)) + } + result.WriteString("]\n") + } + } + + // Print successor index + result.WriteString("\nSuccessor Index (successor session -> partitions):\n") + if len(p.successorIndex) == 0 { + result.WriteString(" (none)\n") + } else { + for session, partitions := range p.successorIndex { + result.WriteString(fmt.Sprintf(" %s: [", session.ConnectionID.String()[:8])) + partitionIDs := make([]int64, 0, len(partitions)) + for id := range partitions { + partitionIDs = append(partitionIDs, id) + } + sort.Slice(partitionIDs, func(i, j int) bool { + return partitionIDs[i] < partitionIDs[j] + }) + for i, id := range partitionIDs { + if i > 0 { + result.WriteString(", ") + } + result.WriteString(fmt.Sprintf("%d", id)) + } + result.WriteString("]\n") + } + } + + return result.String() +} + +func (p *partitionCache) Init(partitions []Partition) { + p.partitions = make(map[int64]Partition) + p.assignmentIndex = make(map[Session]map[int64]struct{}) + p.successorIndex = make(map[Session]map[int64]struct{}) + + for _, partition := range partitions { + p.addPartition(partition) + } +} + +func (p *partitionCache) Update(partitions map[int64]Partition) { + // TODO(queuefeed): When we introduce rangefeeds we probably need to add mvcc + // version to Partition to make sure updates from sql statements are kept + // coherent with updates from the rangefeed. 
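+	// For illustration, a hypothetical caller that saw partition 7 change owner
+	// and partition 9 disappear would pass (newOwner and sp are placeholders):
+	//
+	//	cache.Update(map[int64]Partition{
+	//		7: {ID: 7, Session: newOwner, Span: sp}, // added or updated below
+	//		9: {},                                   // zero value => removed
+	//	})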
+ for id, newPartition := range partitions { + oldPartition := p.partitions[id] + switch { + case newPartition.Empty(): + p.removePartition(id) + case oldPartition.Empty(): + p.addPartition(newPartition) + default: + p.updatePartition(oldPartition, newPartition) + } + } +} + +func (p *partitionCache) removePartition(partitionID int64) { + partition, exists := p.partitions[partitionID] + if !exists { + return + } + + delete(p.partitions, partitionID) + + // Remove from session index + if !partition.Session.Empty() { + if sessions, ok := p.assignmentIndex[partition.Session]; ok { + delete(sessions, partitionID) + if len(sessions) == 0 { + delete(p.assignmentIndex, partition.Session) + } + } + } + + // Remove from successor index + if !partition.Successor.Empty() { + if successors, ok := p.successorIndex[partition.Successor]; ok { + delete(successors, partitionID) + if len(successors) == 0 { + delete(p.successorIndex, partition.Successor) + } + } + } +} + +func (p *partitionCache) addPartition(partition Partition) { + // Add to main partition map + p.partitions[partition.ID] = partition + + // Add to session index and partition index for assigned session + if !partition.Session.Empty() { + if _, ok := p.assignmentIndex[partition.Session]; !ok { + p.assignmentIndex[partition.Session] = make(map[int64]struct{}) + } + p.assignmentIndex[partition.Session][partition.ID] = struct{}{} + } + + // Add to successor index for successor session + if !partition.Successor.Empty() { + if _, ok := p.successorIndex[partition.Successor]; !ok { + p.successorIndex[partition.Successor] = make(map[int64]struct{}) + } + p.successorIndex[partition.Successor][partition.ID] = struct{}{} + } +} + +func (p *partitionCache) updatePartition(oldPartition, newPartition Partition) { + // Update main partition map + p.partitions[newPartition.ID] = newPartition + + // Remove old session assignments + if !oldPartition.Session.Empty() { + if sessions, ok := p.assignmentIndex[oldPartition.Session]; ok { + delete(sessions, oldPartition.ID) + if len(sessions) == 0 { + delete(p.assignmentIndex, oldPartition.Session) + } + } + } + + // Add new session assignments + if !newPartition.Session.Empty() { + if _, ok := p.assignmentIndex[newPartition.Session]; !ok { + p.assignmentIndex[newPartition.Session] = make(map[int64]struct{}) + } + p.assignmentIndex[newPartition.Session][newPartition.ID] = struct{}{} + } + + // Remove old successor assignments + if !oldPartition.Successor.Empty() { + if successors, ok := p.successorIndex[oldPartition.Successor]; ok { + delete(successors, oldPartition.ID) + if len(successors) == 0 { + delete(p.successorIndex, oldPartition.Successor) + } + } + } + + // Add new successor assignments + if !newPartition.Successor.Empty() { + if _, ok := p.successorIndex[newPartition.Successor]; !ok { + p.successorIndex[newPartition.Successor] = make(map[int64]struct{}) + } + p.successorIndex[newPartition.Successor][newPartition.ID] = struct{}{} + } +} + +func (p *partitionCache) isStale(assignment *Assignment) bool { + cachedAssignment := p.assignmentIndex[assignment.Session] + if len(assignment.Partitions) != len(cachedAssignment) { + return true + } + for _, partition := range assignment.Partitions { + if _, ok := cachedAssignment[partition.ID]; !ok { + return true + } + } + return false +} + +func (p *partitionCache) constructAssignment(session Session) *Assignment { + assignment := &Assignment{ + Session: session, + Partitions: make([]Partition, 0, len(p.assignmentIndex[session])), + } + for partitionID := range 
p.assignmentIndex[session] { + assignment.Partitions = append(assignment.Partitions, p.partitions[partitionID]) + } + sort.Slice(assignment.Partitions, func(i, j int) bool { + return assignment.Partitions[i].ID < assignment.Partitions[j].ID + }) + return assignment +} + +func (p *partitionCache) planRegister( + session Session, cache partitionCache, +) (tryClaim Partition, trySteal Partition) { + // Check to see if there is an an unassigned partition that can be claimed. + for _, partition := range cache.partitions { + if partition.Session.Empty() { + return partition, Partition{} + } + } + maxPartitions := (len(p.partitions) + len(p.assignmentIndex) - 1) / len(p.assignmentIndex) + return Partition{}, p.planTheft(1, maxPartitions) +} + +func (p *partitionCache) planAssignment( + session Session, caughtUp bool, cache partitionCache, +) (tryRelease []Partition, tryClaim Partition, trySteal Partition) { + + for partitionId := range p.assignmentIndex[session] { + partition := p.partitions[partitionId] + if !partition.Successor.Empty() { + tryRelease = append(tryRelease, partition) + } + } + if len(tryRelease) != 0 { + return tryRelease, Partition{}, Partition{} + } + + // If we aren't caught up, we should not try to claim any new partitions. + if !caughtUp { + return nil, Partition{}, Partition{} + } + + // Check to see if there is an an unassigned partition that can be claimed. + for _, partition := range p.partitions { + // TODO(jeffswenson): we should really try to claim a random partition to + // avoid contention. + if partition.Session.Empty() { + return nil, partition, Partition{} + } + } + + if p.successorIndex[session] != nil { + // If the session is trying to steal already, do not steal another session. + return nil, Partition{}, Partition{} + } + + // maxPartitions is the maximum number of partitions we would expect to be + // assigned to this session. + maxPartitions := len(p.partitions) + if len(p.assignmentIndex) != 0 { + maxPartitions = (len(p.partitions) + len(p.assignmentIndex) - 1) / len(p.assignmentIndex) + } + assignedPartitions := len(p.assignmentIndex[session]) + if maxPartitions <= assignedPartitions { + return nil, Partition{}, Partition{} + } + + // NOTE: planTheft may return an empty partition. E.g. consider the case + // where there are two sessions and three partitions. In that case the + // maximum partition assignment is 2, but one partition will end up with + // only 1 assignment. It will consider stealing even though the partitions + // are balanced. + // + // We prioritize stealing sessions from any client that has more than the + // maximum expected number of partitions. But we are willing to steal from + // any client that has two more partitions than this client currently has. + // Stealing from someone with less than the maximum expected number of is + // needed to handle distributions like: + // a -> 3 partitions + // b -> 3 partitions + // c -> 1 partition + return nil, Partition{}, p.planTheft(assignedPartitions+1, maxPartitions) +} + +// planTheft selects a partition from a session that has more partitions +// assigned to it than the richTreshold. E.g. `richTreshold` is 1, any session +// with 2 or more partitions is a candidate for work stealing. 
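+// For example, with the distribution above (a -> 3, b -> 3, c -> 1, i.e. 7
+// partitions across 3 sessions), session c ends up calling planTheft(2, 3):
+// no session owns more than 3 partitions, so there are no "rich" candidates,
+// but a and b each own more than 2, so one of their partitions is picked at
+// random.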
+func (p *partitionCache) planTheft(minumExpected, maximumExpected int) Partition { + richCandidates, eligibleCandidates := []Partition{}, []Partition{} + for _, session := range p.assignmentIndex { + assignedPartitions := len(session) + if maximumExpected < assignedPartitions { + for partitionID := range session { + richCandidates = append(richCandidates, p.partitions[partitionID]) + } + } + if minumExpected < assignedPartitions { + for partitionID := range session { + eligibleCandidates = append(eligibleCandidates, p.partitions[partitionID]) + } + } + } + + if len(richCandidates) != 0 { + return richCandidates[rand.Intn(len(richCandidates))] + } + if len(eligibleCandidates) != 0 { + return eligibleCandidates[rand.Intn(len(eligibleCandidates))] + } + return Partition{} +} diff --git a/pkg/sql/queuefeed/partition_cache_test.go b/pkg/sql/queuefeed/partition_cache_test.go new file mode 100644 index 000000000000..4973e4f4d345 --- /dev/null +++ b/pkg/sql/queuefeed/partition_cache_test.go @@ -0,0 +1,211 @@ +package queuefeed + +import ( + "math/rand" + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/require" +) + +type assignmentSimulator struct { + t *testing.T + sessions []Session + cache partitionCache +} + +func newAssignmentSimulator(t *testing.T, partitionCount int) *assignmentSimulator { + partitions := make([]Partition, partitionCount) + for i := range partitions { + partitions[i] = Partition{ + ID: int64(i + 1), + Session: Session{}, // Unassigned + Successor: Session{}, + } + } + sim := &assignmentSimulator{t: t} + sim.cache.Init(partitions) + return sim +} + +// refreshAssignment returns true if refreshing the assignment took any action. +func (a *assignmentSimulator) refreshAssignment(session Session) bool { + tryRelease, tryClaim, trySecede := a.cache.planAssignment(session, true, a.cache) + + updates := make(map[int64]Partition) + + // This is simulating the production implementation. The prod implementation + // would use a txn to apply these changes to the DB, then update the cache + // with the latest version of the rows. + for _, partition := range tryRelease { + updates[partition.ID] = Partition{ + ID: partition.ID, + Session: partition.Successor, + Successor: Session{}, + Span: partition.Span, + } + } + if !tryClaim.Empty() { + updates[tryClaim.ID] = Partition{ + ID: tryClaim.ID, + Session: session, + Successor: Session{}, + Span: tryClaim.Span, + } + } + if !trySecede.Empty() { + updates[trySecede.ID] = Partition{ + ID: trySecede.ID, + Session: trySecede.Session, + Successor: session, + Span: trySecede.Span, + } + } + + a.cache.Update(updates) + + return len(updates) != 0 +} + +func (a *assignmentSimulator) createSession() Session { + session := Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: sqlliveness.SessionID(uuid.MakeV4().String()), + } + + a.sessions = append(a.sessions, session) + + // This is simulating the production implementation. The prod implementation + // would use a txn to apply these changes to the DB, then update the cache + // with the latest version of the rows. 
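+	// Roughly, that production flow would look like the following sketch, where
+	// pt is the queue's partitionTable and db an isql.DB (names are
+	// placeholders):
+	//
+	//	err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error {
+	//		return pt.UpdatePartition(ctx, txn, claimed)
+	//	})
+	//	// ...then re-read the affected rows and...
+	//	cache.Update(updates)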
+ tryClaim, trySecede := a.cache.planRegister(session, a.cache) + + updates := make(map[int64]Partition) + if !tryClaim.Empty() { + tryClaim.Session = session + updates[tryClaim.ID] = tryClaim + } + if !trySecede.Empty() { + trySecede.Successor = session + updates[trySecede.ID] = trySecede + } + a.cache.Update(updates) + + return session +} + +func (a *assignmentSimulator) removeSession(session Session) { + // Remove session from sessions list + for i, s := range a.sessions { + if s == session { + a.sessions = append(a.sessions[:i], a.sessions[i+1:]...) + break + } + } + + assignment := a.cache.constructAssignment(session) + updates := make(map[int64]Partition) + + for _, partition := range assignment.Partitions { + // For each partition assigned to this session, make the successor (if any) + // the owner. + updates[partition.ID] = Partition{ + ID: partition.ID, + Session: partition.Successor, + Successor: Session{}, // Clear successor + Span: partition.Span, + } + } + // Any partitions where this session is the successor should have + // the successor cleared. + for id := range a.cache.partitions { + if a.cache.partitions[id].Successor != session { + continue + } + partition := a.cache.partitions[id] + updates[partition.ID] = Partition{ + ID: partition.ID, + Session: partition.Session, // Keep current owner + Successor: Session{}, // Clear successor + Span: partition.Span, + } + } + + a.cache.Update(updates) +} + +func (a *assignmentSimulator) runToStable() { + maxIterations := 100000 // Prevent infinite loops + + for i := 0; i < maxIterations; i++ { + actionTaken := false + + // Process each session + for _, session := range a.sessions { + if a.refreshAssignment(session) { + actionTaken = true + } + } + + // If no action was taken in this round, we're stable + if !actionTaken { + return + } + } + + a.t.Fatalf("runToStable exceeded maximum iterations (%d sessions, %d partitions): %s ", len(a.sessions), len(a.cache.partitions), a.cache.DebugString()) +} + +func TestPartitionCacheSimple(t *testing.T) { + sim := newAssignmentSimulator(t, 2) + + // Create two sessions. + session1 := sim.createSession() + sim.runToStable() + session2 := sim.createSession() + sim.runToStable() + + // Each session should have one partition. + assignment1 := sim.cache.constructAssignment(session1) + assignment2 := sim.cache.constructAssignment(session2) + require.Len(t, assignment1.Partitions, 1) + require.Len(t, assignment2.Partitions, 1) + + // After removing one session, the other session should have both partitions. + sim.removeSession(session1) + sim.runToStable() + assignment2 = sim.cache.constructAssignment(session2) + require.Len(t, assignment2.Partitions, 2) +} + +func TestPartitionCacheRandom(t *testing.T) { + partitions := rand.Intn(1000) + 1 + sessions := make([]Session, rand.Intn(100)+1) + + sim := newAssignmentSimulator(t, partitions) + + for i := range sessions { + if rand.Int()%2 == 0 { + sim.runToStable() + } + sessions[i] = sim.createSession() + } + sim.runToStable() + + t.Logf("%d partitions, %d sessions", partitions, len(sessions)) + t.Log(sim.cache.DebugString()) + + // Verify all partitions are assigned + for _, partition := range sim.cache.partitions { + require.False(t, partition.Session.Empty()) + } + + // Verify that no session has more than ceil(partitions / len(sessions)) + // partitions. 
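+	// (In integer arithmetic that ceiling is (p + n - 1) / n; e.g. 7 partitions
+	// over 3 sessions allows at most (7 + 3 - 1) / 3 = 3 per session.)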
+ maxPerSession := (partitions + len(sessions) - 1) / len(sessions) + for _, session := range sessions { + assignment := sim.cache.constructAssignment(session) + require.LessOrEqual(t, len(assignment.Partitions), maxPerSession) + } +} diff --git a/pkg/sql/queuefeed/partitions.go b/pkg/sql/queuefeed/partitions.go new file mode 100644 index 000000000000..8f2b24afd4b8 --- /dev/null +++ b/pkg/sql/queuefeed/partitions.go @@ -0,0 +1,347 @@ +package queuefeed + +import ( + "context" + "fmt" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/cockroachdb/errors" +) + +type Partition struct { + // ID is the `partition_id` column in the queue partition table. + ID int64 + // Session is the `user_session` and `sql_liveness_session` assigned to this + // partition. + Session Session + // Successor is the `user_session_successor` and + // `sql_liveness_session_successor` assigned to the partition. + Successor Session + // Span is decoded from the `partition_spec` column. + Span roachpb.Span +} + +func PartitionFromDatums(row tree.Datums) (Partition, error) { + var session, successor Session + if !(row[1] == tree.DNull || row[2] == tree.DNull) { + session = Session{ + LivenessID: sqlliveness.SessionID(tree.MustBeDBytes(row[1])), + ConnectionID: tree.MustBeDUuid(row[2]).UUID, + } + } + if !(row[3] == tree.DNull || row[4] == tree.DNull) { + successor = Session{ + LivenessID: sqlliveness.SessionID(tree.MustBeDBytes(row[3])), + ConnectionID: tree.MustBeDUuid(row[4]).UUID, + } + } + + var span roachpb.Span + if row[5] != tree.DNull { + var err error + span, err = decodeSpan([]byte(*row[5].(*tree.DBytes))) + if err != nil { + return Partition{}, err + } + } + + return Partition{ + ID: int64(tree.MustBeDInt(row[0])), + Session: session, + Successor: successor, + Span: span, + }, nil +} + +type partitionTable struct { + queueName string +} + +func (p *partitionTable) CreateSchema(ctx context.Context, txn isql.Txn) error { + _, err := txn.Exec(ctx, "create-partition-table", txn.KV(), + fmt.Sprintf(`CREATE TABLE IF NOT EXISTS defaultdb.queue_partition_%s ( + partition_id BIGSERIAL PRIMARY KEY, + sql_liveness_session BYTES, + user_session UUID, + sql_liveness_session_successor BYTES, + user_session_successor UUID, + partition_spec BYTES + )`, p.queueName)) + return err +} + +func (p *partitionTable) ListPartitions(ctx context.Context, txn isql.Txn) ([]Partition, error) { + rows, err := txn.QueryBuffered(ctx, "list-partitions", txn.KV(), fmt.Sprintf(` + SELECT + partition_id, + sql_liveness_session, + user_session, + sql_liveness_session_successor, + user_session_successor, + partition_spec + FROM defaultdb.queue_partition_%s`, p.queueName)) + if err != nil { + return nil, err + } + partitions := make([]Partition, len(rows)) + for i, row := range rows { + var err error + partitions[i], err = PartitionFromDatums(row) + if err != nil { + return nil, err + } + } + return partitions, nil +} + +// FetchPartitions fetches all of the partitions with the given IDs. The len of +// the returned map is eqaul to the number of unique partitionIDs passed in. If +// a partition id is not found, it will be present in the map with a zero-value +// Partition. 
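+// For example, if only partitions 1 and 3 exist, FetchPartitions(ctx, txn,
+// []int64{1, 2, 3}) returns a map of length 3 in which result[2].Empty()
+// reports true; TestFetchPartitions exercises exactly this case.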
+func (p *partitionTable) FetchPartitions( + ctx context.Context, txn isql.Txn, partitionIDs []int64, +) (map[int64]Partition, error) { + if len(partitionIDs) == 0 { + return make(map[int64]Partition), nil + } + + // Initialize result map with zero-value partitions for all unique IDs + result := make(map[int64]Partition) + for _, id := range partitionIDs { + result[id] = Partition{} // Zero-value partition as placeholder + } + + datumArray := tree.NewDArray(types.Int) + for _, id := range partitionIDs { + if err := datumArray.Append(tree.NewDInt(tree.DInt(id))); err != nil { + return nil, err + } + } + + rows, err := txn.QueryBuffered(ctx, "fetch-partitions", txn.KV(), fmt.Sprintf(` + SELECT + partition_id, + sql_liveness_session, + user_session, + sql_liveness_session_successor, + user_session_successor, + partition_spec + FROM defaultdb.queue_partition_%s + WHERE partition_id = ANY($1)`, p.queueName), datumArray) + if err != nil { + return nil, err + } + + // Process found partitions + for _, row := range rows { + partition, err := PartitionFromDatums(row) + if err != nil { + return nil, err + } + result[partition.ID] = partition + } + + return result, nil +} + +// Get retrieves a single partition by ID. Returns an error if the partition +// is not found. +func (p *partitionTable) Get( + ctx context.Context, txn isql.Txn, partitionID int64, +) (Partition, error) { + row, err := txn.QueryRow(ctx, "get-partition", txn.KV(), + fmt.Sprintf(` + SELECT + partition_id, + sql_liveness_session, + user_session, + sql_liveness_session_successor, + user_session_successor, + partition_spec + FROM defaultdb.queue_partition_%s + WHERE partition_id = $1`, p.queueName), partitionID) + if err != nil { + return Partition{}, err + } + + if row == nil { + return Partition{}, errors.Newf("no partition found with id %d", partitionID) + } + + partition, err := PartitionFromDatums(row) + if err != nil { + return Partition{}, err + } + + return partition, nil +} + +func (p *partitionTable) InsertPartition( + ctx context.Context, txn isql.Txn, partition Partition, +) error { + var sessionLivenessID, sessionConnectionID interface{} + var successorLivenessID, successorConnectionID interface{} + + if !partition.Session.Empty() { + sessionLivenessID = []byte(partition.Session.LivenessID) + sessionConnectionID = partition.Session.ConnectionID + } else { + sessionLivenessID = nil + sessionConnectionID = nil + } + + if !partition.Successor.Empty() { + successorLivenessID = []byte(partition.Successor.LivenessID) + successorConnectionID = partition.Successor.ConnectionID + } else { + successorLivenessID = nil + successorConnectionID = nil + } + + spanBytes := encodeSpan(partition.Span) + + _, err := txn.Exec(ctx, "insert-partition", txn.KV(), + fmt.Sprintf(`INSERT INTO defaultdb.queue_partition_%s + (partition_id, sql_liveness_session, user_session, sql_liveness_session_successor, user_session_successor, partition_spec) + VALUES ($1, $2, $3, $4, $5, $6)`, p.queueName), + partition.ID, sessionLivenessID, sessionConnectionID, + successorLivenessID, successorConnectionID, spanBytes) + + return err +} + +func (p *partitionTable) UpdatePartition( + ctx context.Context, txn isql.Txn, partition Partition, +) error { + var sessionLivenessID, sessionConnectionID interface{} + var successorLivenessID, successorConnectionID interface{} + + if !partition.Session.Empty() { + sessionLivenessID = []byte(partition.Session.LivenessID) + sessionConnectionID = partition.Session.ConnectionID + } else { + sessionLivenessID = nil + 
sessionConnectionID = nil + } + + if !partition.Successor.Empty() { + successorLivenessID = []byte(partition.Successor.LivenessID) + successorConnectionID = partition.Successor.ConnectionID + } else { + successorLivenessID = nil + successorConnectionID = nil + } + + spanBytes := encodeSpan(partition.Span) + + _, err := txn.Exec(ctx, "update-partition", txn.KV(), + fmt.Sprintf(`UPDATE defaultdb.queue_partition_%s + SET sql_liveness_session = $2, + user_session = $3, + sql_liveness_session_successor = $4, + user_session_successor = $5, + partition_spec = $6 + WHERE partition_id = $1`, p.queueName), + partition.ID, sessionLivenessID, sessionConnectionID, + successorLivenessID, successorConnectionID, spanBytes) + + return err +} + +// UnregisterSession removes the given session from all assignments and +// partition claims, it returns the updated partitions. +func (p *partitionTable) UnregisterSession( + ctx context.Context, txn isql.Txn, session Session, +) (updates map[int64]Partition, err error) { + sessionLivenessID := []byte(session.LivenessID) + sessionConnectionID := session.ConnectionID + + rows, err := txn.QueryBuffered(ctx, "unregister-session", txn.KV(), fmt.Sprintf(` + UPDATE defaultdb.queue_partition_%s + SET + sql_liveness_session = CASE + WHEN sql_liveness_session = $1 AND user_session = $2 THEN sql_liveness_session_successor + ELSE sql_liveness_session + END, + user_session = CASE + WHEN sql_liveness_session = $1 AND user_session = $2 THEN user_session_successor + ELSE user_session + END, + sql_liveness_session_successor = CASE + WHEN sql_liveness_session = $1 AND user_session = $2 THEN NULL + WHEN sql_liveness_session_successor = $1 AND user_session_successor = $2 THEN NULL + ELSE sql_liveness_session_successor + END, + user_session_successor = CASE + WHEN sql_liveness_session = $1 AND user_session = $2 THEN NULL + WHEN sql_liveness_session_successor = $1 AND user_session_successor = $2 THEN NULL + ELSE user_session_successor + END + WHERE (sql_liveness_session = $1 AND user_session = $2) + OR (sql_liveness_session_successor = $1 AND user_session_successor = $2) + RETURNING partition_id, sql_liveness_session, user_session, + sql_liveness_session_successor, user_session_successor, partition_spec`, p.queueName), sessionLivenessID, sessionConnectionID) + if err != nil { + return nil, err + } + + updates = make(map[int64]Partition) + for _, row := range rows { + partition, err := PartitionFromDatums(row) + if err != nil { + return nil, err + } + updates[partition.ID] = partition + } + + return updates, nil +} + +func (p Partition) Empty() bool { + return p.ID == 0 +} + +// Equal returns true if two partitions are equal in all fields. +func (p Partition) Equal(other Partition) bool { + return p.ID == other.ID && + p.Session == other.Session && + p.Successor == other.Successor && + p.Span.Equal(other.Span) +} + +type Session struct { + // ConnectionID is the ID of the underlying connection. + ConnectionID uuid.UUID + // LivenessID is the session ID for the server. Its used to identify sessions + // that belong to dead sql servers. 
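+	// Together the two fields identify a single queue consumer: ConnectionID
+	// distinguishes consumers that share a SQL server, while LivenessID lets the
+	// manager notice consumers whose SQL server has died.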
+ LivenessID sqlliveness.SessionID +} + +func (s Session) Empty() bool { + return s.ConnectionID == uuid.Nil && s.LivenessID == "" +} + +func decodeSpan(data []byte) (roachpb.Span, error) { + var span roachpb.Span + if err := span.Unmarshal(data); err != nil { + return roachpb.Span{}, err + } + return span, nil +} + +func encodeSpan(span roachpb.Span) []byte { + data, err := span.Marshal() + if err != nil { + return nil + } + return data +} + +func TestNewPartitionsTable(queueName string) *partitionTable { + return &partitionTable{queueName: queueName} +} diff --git a/pkg/sql/queuefeed/partitions_test.go b/pkg/sql/queuefeed/partitions_test.go new file mode 100644 index 000000000000..9dffd1fa989c --- /dev/null +++ b/pkg/sql/queuefeed/partitions_test.go @@ -0,0 +1,306 @@ +package queuefeed + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/require" +) + +func TestListPartitions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := srv.ApplicationLayer().InternalDB().(isql.DB) + sqlRunner := sqlutils.MakeSQLRunner(sqlDB) + queueName := "test" + + pt := &partitionTable{queueName: queueName} + + // Create table + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.CreateSchema(ctx, txn) + }) + require.NoError(t, err) + + // Test empty + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + partitions, err := pt.ListPartitions(ctx, txn) + require.NoError(t, err) + require.Empty(t, partitions) + return nil + }) + require.NoError(t, err) + + // Insert one partition + sessionID := uuid.MakeV4() + connectionID := uuid.MakeV4() + span := &roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("z")} + spanBytes, _ := span.Marshal() + + sqlRunner.Exec(t, ` + INSERT INTO defaultdb.queue_partition_`+queueName+` + (partition_id, sql_liveness_session, user_session, partition_spec) + VALUES (1, $1, $2, $3)`, sessionID, connectionID, spanBytes) + + // Test with data + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + partitions, err := pt.ListPartitions(ctx, txn) + require.NoError(t, err) + require.Len(t, partitions, 1) + require.Equal(t, int64(1), partitions[0].ID) + require.Equal(t, connectionID, partitions[0].Session.ConnectionID) + return nil + }) + require.NoError(t, err) +} + +func TestUpdatePartition(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := srv.ApplicationLayer().InternalDB().(isql.DB) + sqlRunner := sqlutils.MakeSQLRunner(sqlDB) + queueName := "test" + + pt := &partitionTable{queueName: queueName} + + // Create table + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.CreateSchema(ctx, txn) + }) + require.NoError(t, err) + + // Insert initial partition + sqlRunner.Exec(t, ` + INSERT 
INTO defaultdb.queue_partition_`+queueName+` (partition_id) VALUES (1)`) + + // Update the partition + newSessionID := uuid.MakeV4() + newConnectionID := uuid.MakeV4() + span := roachpb.Span{Key: roachpb.Key("new"), EndKey: roachpb.Key("span")} + + partition := Partition{ + ID: 1, + Session: Session{ + LivenessID: sqlliveness.SessionID(newSessionID.GetBytes()), + ConnectionID: newConnectionID, + }, + Span: span, + } + + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.UpdatePartition(ctx, txn, partition) + }) + require.NoError(t, err) + + // Verify update worked + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + partitions, err := pt.ListPartitions(ctx, txn) + require.NoError(t, err) + require.Len(t, partitions, 1) + require.Equal(t, newConnectionID, partitions[0].Session.ConnectionID) + require.Equal(t, span.Key, partitions[0].Span.Key) + return nil + }) + require.NoError(t, err) +} + +func TestInsertPartition(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := srv.ApplicationLayer().InternalDB().(isql.DB) + queueName := "test" + + pt := &partitionTable{queueName: queueName} + + // Create table + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.CreateSchema(ctx, txn) + }) + require.NoError(t, err) + + // Insert partition + sessionID := uuid.MakeV4() + connectionID := uuid.MakeV4() + span := roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("z")} + + partition := Partition{ + ID: 1, + Session: Session{ + LivenessID: sqlliveness.SessionID(sessionID.GetBytes()), + ConnectionID: connectionID, + }, + Span: span, + } + + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.InsertPartition(ctx, txn, partition) + }) + require.NoError(t, err) + + // Verify insertion worked + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + partitions, err := pt.ListPartitions(ctx, txn) + require.NoError(t, err) + require.Len(t, partitions, 1) + require.Equal(t, int64(1), partitions[0].ID) + require.Equal(t, connectionID, partitions[0].Session.ConnectionID) + require.Equal(t, span.Key, partitions[0].Span.Key) + require.Equal(t, span.EndKey, partitions[0].Span.EndKey) + return nil + }) + require.NoError(t, err) +} + +func TestFetchPartitions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := srv.ApplicationLayer().InternalDB().(isql.DB) + sqlRunner := sqlutils.MakeSQLRunner(sqlDB) + queueName := "test" + pt := &partitionTable{queueName: queueName} + + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.CreateSchema(ctx, txn) + }) + require.NoError(t, err) + + // Insert some test data + sqlRunner.Exec(t, `INSERT INTO defaultdb.queue_partition_test (partition_id) VALUES (1), (3)`) + + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + result, err := pt.FetchPartitions(ctx, txn, []int64{1, 2, 3}) + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, int64(1), result[1].ID) + require.True(t, result[2].Empty()) + require.Equal(t, int64(3), result[3].ID) + return nil + }) + require.NoError(t, err) +} + +func TestGetPartition(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + 
ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := srv.ApplicationLayer().InternalDB().(isql.DB) + sqlRunner := sqlutils.MakeSQLRunner(sqlDB) + pt := &partitionTable{queueName: "test"} + + // Create table + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.CreateSchema(ctx, txn) + }) + require.NoError(t, err) + + // Test not found + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + _, err := pt.Get(ctx, txn, 999) + require.Error(t, err) + return nil + }) + require.NoError(t, err) + + // Insert test partition + connectionID := uuid.MakeV4() + span := roachpb.Span{Key: roachpb.Key("test"), EndKey: roachpb.Key("testend")} + spanBytes, _ := span.Marshal() + sqlRunner.Exec(t, `INSERT INTO defaultdb.queue_partition_test (partition_id, sql_liveness_session, user_session, partition_spec) VALUES (1, $1, $2, $3)`, []byte("test-session"), connectionID, spanBytes) + + // Test found + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + partition, err := pt.Get(ctx, txn, 1) + require.NoError(t, err) + require.Equal(t, int64(1), partition.ID) + require.Equal(t, connectionID, partition.Session.ConnectionID) + return nil + }) + require.NoError(t, err) +} + +func TestUnregisterSession(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := srv.ApplicationLayer().InternalDB().(isql.DB) + sqlRunner := sqlutils.MakeSQLRunner(sqlDB) + pt := &partitionTable{queueName: "test"} + + // Create table + err := db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return pt.CreateSchema(ctx, txn) + }) + require.NoError(t, err) + + // Create sessions + session1 := Session{LivenessID: "session1", ConnectionID: uuid.MakeV4()} + session2 := Session{LivenessID: "session2", ConnectionID: uuid.MakeV4()} + + // Insert partition with session1 as owner, session2 as successor + span1 := roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("b")} + span1Bytes, _ := span1.Marshal() + sqlRunner.Exec(t, `INSERT INTO defaultdb.queue_partition_test (partition_id, sql_liveness_session, user_session, sql_liveness_session_successor, user_session_successor, partition_spec) VALUES (1, $1, $2, $3, $4, $5)`, + []byte(session1.LivenessID), session1.ConnectionID, []byte(session2.LivenessID), session2.ConnectionID, span1Bytes) + + // Insert partition with session1 as owner, no successor + span2 := roachpb.Span{Key: roachpb.Key("c"), EndKey: roachpb.Key("d")} + span2Bytes, _ := span2.Marshal() + sqlRunner.Exec(t, `INSERT INTO defaultdb.queue_partition_test (partition_id, sql_liveness_session, user_session, partition_spec) VALUES (2, $1, $2, $3)`, + []byte(session1.LivenessID), session1.ConnectionID, span2Bytes) + + // Unregister session1 + err = db.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + updates, err := pt.UnregisterSession(ctx, txn, session1) + require.NoError(t, err) + require.Len(t, updates, 2) + + // Partition 1: session2 should now be the owner + partition1 := updates[1] + require.Equal(t, session2.ConnectionID, partition1.Session.ConnectionID) + require.True(t, partition1.Successor.Empty()) + + // Partition 2: should be unassigned (no successor to promote) + partition2 := updates[2] + require.True(t, partition2.Session.Empty()) + require.True(t, partition2.Successor.Empty()) + return nil + }) + 
require.NoError(t, err) +} diff --git a/pkg/sql/queuefeed/queuebase/BUILD.bazel b/pkg/sql/queuefeed/queuebase/BUILD.bazel new file mode 100644 index 000000000000..afcb02c98b49 --- /dev/null +++ b/pkg/sql/queuefeed/queuebase/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "queuebase", + srcs = ["queuebase.go"], + importpath = "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/sem/tree", + "//pkg/util/hlc", + ], +) diff --git a/pkg/sql/queuefeed/queuebase/queuebase.go b/pkg/sql/queuefeed/queuebase/queuebase.go new file mode 100644 index 000000000000..da319c0f51b7 --- /dev/null +++ b/pkg/sql/queuefeed/queuebase/queuebase.go @@ -0,0 +1,25 @@ +package queuebase + +import ( + "context" + + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/util/hlc" +) + +type Manager interface { + CreateQueue(ctx context.Context, name string, tableID int64) error + CreateQueueFromCursor(ctx context.Context, name string, tableID int64, cursor hlc.Timestamp) error +} + +// Implemented by the conn executor in reality +type ReaderProvider interface { + GetOrInitReader(ctx context.Context, name string) (Reader, error) +} + +type Reader interface { + GetRows(ctx context.Context, limit int) ([]tree.Datums, error) + ConfirmReceipt(ctx context.Context) + RollbackBatch(ctx context.Context) + Close() error +} diff --git a/pkg/sql/queuefeed/reader.go b/pkg/sql/queuefeed/reader.go new file mode 100644 index 000000000000..35c0d360c765 --- /dev/null +++ b/pkg/sql/queuefeed/reader.go @@ -0,0 +1,629 @@ +package queuefeed + +import ( + "context" + "fmt" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/cockroachdb/cockroach/pkg/ccl/changefeedccl/changefeedbase" + "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/kv/kvclient/rangefeed" + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/sql/catalog" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/descpb" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/fetchpb" + "github.com/cockroachdb/cockroach/pkg/sql/catalog/lease" + "github.com/cockroachdb/cockroach/pkg/sql/isql" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" + "github.com/cockroachdb/cockroach/pkg/sql/row" + "github.com/cockroachdb/cockroach/pkg/sql/rowenc" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/util" + "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/span" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" +) + +const maxBufSize = 1000 + +type readerState int + +const ( + readerStateBatching readerState = iota + readerStateHasUncommittedBatch + readerStateCheckingForReassignment + readerStateDead +) + +// bufferedEvent represents either a data row or a checkpoint timestamp +// in the reader's buffer. Exactly one of row or resolved will be set. +type bufferedEvent struct { + // row is set for data events. nil for checkpoint events. + row tree.Datums + // resolved is set for checkpoint events. Empty for data events. + resolved hlc.Timestamp +} + +// has rangefeed on data. reads from it. 
handles handoff +// state machine around handing out batches and handing stuff off +type Reader struct { + executor isql.DB + rff *rangefeed.Factory + mgr *Manager + name string + assigner *PartitionAssignments + + // stuff for decoding data. this is ripped from rowfetcher_cache.go in changefeeds + codec keys.SQLCodec + leaseMgr *lease.Manager + + mu struct { + syncutil.Mutex + state readerState + buf []bufferedEvent + inflightBuffer []bufferedEvent + poppedWakeup *sync.Cond + pushedWakeup *sync.Cond + } + + // TODO: handle the case where an assignment can change. + session Session + assignment *Assignment + rangefeed *rangefeed.RangeFeed + + cancel context.CancelCauseFunc + goroCtx context.Context + isShutdown atomic.Bool +} + +func NewReader( + ctx context.Context, + executor isql.DB, + mgr *Manager, + rff *rangefeed.Factory, + codec keys.SQLCodec, + leaseMgr *lease.Manager, + session Session, + assigner *PartitionAssignments, + name string, +) (*Reader, error) { + r := &Reader{ + executor: executor, + mgr: mgr, + codec: codec, + leaseMgr: leaseMgr, + name: name, + rff: rff, + // stored so we can use it in methods using a different context than the main goro ie GetRows and ConfirmReceipt + goroCtx: ctx, + assigner: assigner, + session: session, + } + r.mu.state = readerStateBatching + r.mu.buf = make([]bufferedEvent, 0, maxBufSize) + r.mu.poppedWakeup = sync.NewCond(&r.mu.Mutex) + r.mu.pushedWakeup = sync.NewCond(&r.mu.Mutex) + + ctx, cancel := context.WithCancelCause(ctx) + r.cancel = func(cause error) { + fmt.Printf("canceling with cause: %+v\n", cause) + cancel(cause) + r.mu.poppedWakeup.Broadcast() + r.mu.pushedWakeup.Broadcast() + } + + assignment, err := r.waitForAssignment(ctx, session) + if err != nil { + return nil, errors.Wrap(err, "waiting for assignment") + } + if err := r.setupRangefeed(ctx, assignment); err != nil { + return nil, errors.Wrap(err, "setting up rangefeed") + } + go r.run(ctx) + return r, nil +} + +var ErrNoPartitionsAssigned = errors.New("no partitions assigned to reader: todo support this case by polling for assignment") + +func (r *Reader) waitForAssignment(ctx context.Context, session Session) (*Assignment, error) { + // We can rapidly poll this because the assigner has an in-memory cache of + // assignments. + // + // TODO: should this retry loop be in RegisterSession instead? + timer := time.NewTicker(100 * time.Millisecond) + defer timer.Stop() + for { + assignment, err := r.assigner.RegisterSession(ctx, session) + if err != nil { + return nil, errors.Wrap(err, "registering session for reader") + } + if len(assignment.Partitions) != 0 { + return assignment, nil + } + + select { + case <-ctx.Done(): + return nil, errors.Wrap(ctx.Err(), "waiting for assignment") + case <-timer.C: + // continue + } + } +} + +func (r *Reader) setupRangefeed(ctx context.Context, assignment *Assignment) error { + defer func() { + fmt.Println("setupRangefeed done") + }() + + // TODO: handle the case where there are no partitions in the assignment. In + // that case we should poll `RefreshAssignment` until we get one. This would + // only occur if every assignment was handed out already. 
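+	// A sketch of that polling, mirroring waitForAssignment above (the wait and
+	// error handling are placeholders):
+	//
+	//	for len(assignment.Partitions) == 0 {
+	//		// ...wait on a ticker or ctx.Done()...
+	//		next, err := r.assigner.RefreshAssignment(ctx, assignment, true /* caughtUp */)
+	//		if err != nil {
+	//			return err
+	//		}
+	//		if next != nil {
+	//			assignment = next
+	//		}
+	//	}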
+ if len(assignment.Partitions) == 0 { + return errors.Wrap(ErrNoPartitionsAssigned, "setting up rangefeed") + } + + onValue := func(ctx context.Context, value *kvpb.RangeFeedValue) { + fmt.Printf("onValue: %+v\n", value) + r.mu.Lock() + defer r.mu.Unlock() + + // wait for rows to be read before adding more, if necessary + for ctx.Err() == nil && len(r.mu.buf) > maxBufSize { + r.mu.poppedWakeup.Wait() + } + + if !value.Value.IsPresent() { + // not handling diffs/deletes rn + return + } + datums, err := r.decodeRangefeedValue(ctx, value) + if err != nil { + r.cancel(errors.Wrapf(err, "decoding rangefeed value: %+v", value)) + return + } + r.mu.buf = append(r.mu.buf, bufferedEvent{row: datums}) + r.mu.pushedWakeup.Broadcast() + fmt.Printf("onValue done with buf len: %d\n", len(r.mu.buf)) + } + // setup rangefeed on data + opts := []rangefeed.Option{ + rangefeed.WithPProfLabel("queuefeed.reader", fmt.Sprintf("name=%s", r.name)), + // rangefeed.WithMemoryMonitor(w.mon), + rangefeed.WithOnCheckpoint(func(ctx context.Context, checkpoint *kvpb.RangeFeedCheckpoint) { + // This can happen when done catching up; ignore it. + if checkpoint.ResolvedTS.IsEmpty() { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Wait for rows to be read before adding more, if necessary. + for ctx.Err() == nil && len(r.mu.buf) > maxBufSize { + r.mu.poppedWakeup.Wait() + } + + if ctx.Err() != nil { + return + } + + r.mu.buf = append(r.mu.buf, bufferedEvent{resolved: checkpoint.ResolvedTS}) + }), + rangefeed.WithOnInternalError(func(ctx context.Context, err error) { r.cancel(err) }), + rangefeed.WithConsumerID(42), + rangefeed.WithInvoker(func(fn func() error) error { return fn() }), + rangefeed.WithFiltering(false), + } + + frontier, err := span.MakeFrontier(assignment.Spans()...) + if err != nil { + return errors.Wrap(err, "creating frontier") + } + + for _, partition := range assignment.Partitions { + checkpointTS, err := r.mgr.ReadCheckpoint(ctx, r.name, partition.ID) + if err != nil { + return errors.Wrapf(err, "reading checkpoint for partition %d", partition.ID) + } + if !checkpointTS.IsEmpty() { + _, err := frontier.Forward(partition.Span, checkpointTS) + if err != nil { + return errors.Wrapf(err, "advancing frontier for partition %d to checkpoint %s", partition.ID, checkpointTS) + } + } else { + return errors.Errorf("checkpoint is empty for partition %d", partition.ID) + } + } + + if frontier.Frontier().IsEmpty() { + return errors.New("frontier is empty") + } + + rf := r.rff.New( + fmt.Sprintf("queuefeed.reader.name=%s", r.name), frontier.Frontier(), onValue, opts..., + ) + + if err := rf.StartFromFrontier(ctx, frontier); err != nil { + return errors.Wrap(err, "starting rangefeed") + } + + r.rangefeed = rf + r.assignment = assignment + return nil +} + +// - [x] setup rangefeed on data +// - [X] handle only watching my partitions +// - [X] after each batch, ask mgr if i need to assignments +// - [X] buffer rows in the background before being asked for them +// - [ ] checkpoint frontier if our frontier has advanced and we confirmed receipt +// - [X] gonna need some way to clean stuff up on conn_executor.close() + +// TODO: this loop isnt doing much anymore. 
if we dont need it for anything else, let's remove it +func (r *Reader) run(ctx context.Context) { + defer func() { + fmt.Println("run done") + r.isShutdown.Store(true) + }() + + for { + select { + case <-ctx.Done(): + fmt.Printf("run: ctx done: %s; cause: %s\n", ctx.Err(), context.Cause(ctx)) + return + } + } +} + +func (r *Reader) GetRows(ctx context.Context, limit int) ([]tree.Datums, error) { + fmt.Printf("GetRows start\n") + + if r.isShutdown.Load() { + return nil, errors.New("reader is shutting down") + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.mu.state != readerStateBatching { + return nil, errors.New("reader not idle") + } + if len(r.mu.inflightBuffer) > 0 { + return nil, errors.AssertionFailedf("getrows called with nonempty inflight buffer") + } + + // Helper to count data events (not checkpoints) in buffer + hasDataEvents := func() bool { + for _, event := range r.mu.buf { + if event.resolved.IsEmpty() { + return true + } + } + return false + } + + // Wait until we have at least one data event (not just checkpoints) + if !hasDataEvents() { + // shut down the reader if this ctx (which is distinct from the goro ctx) is canceled + defer context.AfterFunc(ctx, func() { + r.cancel(errors.Wrapf(ctx.Err(), "GetRows canceled")) + })() + for ctx.Err() == nil && r.goroCtx.Err() == nil && !hasDataEvents() { + r.mu.pushedWakeup.Wait() + } + if ctx.Err() != nil { + return nil, errors.Wrapf(ctx.Err(), "GetRows canceled") + } + } + + // Find the position of the (limit+1)th data event (not checkpoint) + // We'll take everything up to that point, which gives us up to `limit` data rows + // plus any checkpoints that came before/between them. + bufferEndIdx := len(r.mu.buf) + + // Optimization: if the entire buffer is smaller than limit, take it all + if len(r.mu.buf) > limit { + dataCount := 0 + for i, event := range r.mu.buf { + if event.resolved.IsEmpty() { + dataCount++ + if dataCount > limit { + bufferEndIdx = i + break + } + } + } + } + + r.mu.inflightBuffer = append(r.mu.inflightBuffer, r.mu.buf[0:bufferEndIdx]...) + r.mu.buf = r.mu.buf[bufferEndIdx:] + + r.mu.state = readerStateHasUncommittedBatch + r.mu.poppedWakeup.Broadcast() + + // Here we filter to return only data events to the user. + result := make([]tree.Datums, 0, limit) + for _, event := range r.mu.inflightBuffer { + if event.resolved.IsEmpty() { + result = append(result, event.row) + } + } + + return result, nil +} + +// ConfirmReceipt is called when we commit a transaction that reads from the queue. +// We will checkpoint if we have checkpoint events in our inflightBuffer. +func (r *Reader) ConfirmReceipt(ctx context.Context) { + if r.isShutdown.Load() { + return + } + + var checkpointToWrite hlc.Timestamp + func() { + r.mu.Lock() + defer r.mu.Unlock() + + // Find the last checkpoint in inflightBuffer + for _, event := range r.mu.inflightBuffer { + if !event.resolved.IsEmpty() { + checkpointToWrite = event.resolved + } + } + + r.mu.inflightBuffer = r.mu.inflightBuffer[:0] + r.mu.state = readerStateCheckingForReassignment + }() + + // Persist the checkpoint if we have one. + if !checkpointToWrite.IsEmpty() { + for _, partition := range r.assignment.Partitions { + if err := r.mgr.WriteCheckpoint(ctx, r.name, partition.ID, checkpointToWrite); err != nil { + fmt.Printf("error writing checkpoint for partition %d: %s\n", partition.ID, err) + // TODO: decide how to handle checkpoint write errors. Since the txn + // has already committed, I don't think we can really fail at this point. 
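+				// (Presumably the failure mode here is only re-delivery: the
+				// checkpoint just advances the reader's restart point, so if this
+				// write is lost a later reader replays rows since the last
+				// persisted checkpoint rather than losing data.)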
+ } + } + } + + select { + case <-ctx.Done(): + return + case <-r.goroCtx.Done(): + return + default: + // TODO only set caughtUp to true if our frontier is near the current time. + newAssignment, err := r.assigner.RefreshAssignment(ctx, r.assignment, true /*=caughtUp*/) + if err != nil { + r.cancel(errors.Wrap(err, "refreshing assignment")) + return + } + if newAssignment != nil { + if err := r.updateAssignment(newAssignment); err != nil { + r.cancel(errors.Wrap(err, "updating assignment")) + return + } + } + } + func() { + r.mu.Lock() + defer r.mu.Unlock() + r.mu.state = readerStateBatching + }() +} + +func (r *Reader) RollbackBatch(ctx context.Context) { + if r.isShutdown.Load() { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + newBuf := make([]bufferedEvent, 0, len(r.mu.inflightBuffer)+len(r.mu.buf)) + newBuf = append(newBuf, r.mu.inflightBuffer...) + newBuf = append(newBuf, r.mu.buf...) + r.mu.buf = newBuf + r.mu.inflightBuffer = r.mu.inflightBuffer[:0] + + r.mu.state = readerStateBatching +} + +func (r *Reader) IsAlive() bool { + return !r.isShutdown.Load() +} + +func (r *Reader) Close() error { + err := r.assigner.UnregisterSession(r.goroCtx, r.session) + r.cancel(errors.New("reader closing")) + r.rangefeed.Close() + return err +} + +func (r *Reader) updateAssignment(assignment *Assignment) error { + defer func() { + fmt.Printf("updateAssignment done with assignment: %+v\n", assignment) + }() + + r.rangefeed.Close() + r.assignment = assignment + + func() { + r.mu.Lock() + defer r.mu.Unlock() + r.mu.buf = r.mu.buf[:0] + }() + + if err := r.setupRangefeed(r.goroCtx, assignment); err != nil { + return errors.Wrapf(err, "setting up rangefeed for new assignment: %+v", assignment) + } + return nil +} + +func (r *Reader) checkForReassignment(ctx context.Context) error { + defer func() { + fmt.Println("checkForReassignment done") + }() + + r.mu.Lock() + defer r.mu.Unlock() + + if r.mu.state != readerStateCheckingForReassignment { + return errors.AssertionFailedf("reader not in checking for reassignment state") + } + + change, err := r.mgr.reassessAssignments(ctx, r.name) + if err != nil { + return errors.Wrap(err, "reassessing assignments") + } + if change { + fmt.Println("TODO: reassignment detected. 
lets do something about it") + } + r.mu.state = readerStateBatching + return nil +} + +func (r *Reader) decodeRangefeedValue( + ctx context.Context, rfv *kvpb.RangeFeedValue, +) (tree.Datums, error) { + partialKey := rfv.Key + partialKey, err := r.codec.StripTenantPrefix(partialKey) + if err != nil { + return nil, errors.Wrapf(err, "stripping tenant prefix: %s", keys.PrettyPrint(nil, partialKey)) + } + + _, tableID, _, err := rowenc.DecodePartialTableIDIndexID(partialKey) + if err != nil { + return nil, errors.Wrapf(err, "decoding partial table id index id: %s", keys.PrettyPrint(nil, partialKey)) + } + + familyID, err := keys.DecodeFamilyKey(partialKey) + if err != nil { + return nil, errors.Wrapf(err, "decoding family key: %s", keys.PrettyPrint(nil, partialKey)) + } + + tableDesc, err := r.fetchTableDesc(ctx, tableID, rfv.Value.Timestamp) + if err != nil { + return nil, errors.Wrapf(err, "fetching table descriptor: %s", keys.PrettyPrint(nil, partialKey)) + } + familyDesc, err := catalog.MustFindFamilyByID(tableDesc, descpb.FamilyID(familyID)) + if err != nil { + return nil, errors.Wrapf(err, "fetching family descriptor: %s", keys.PrettyPrint(nil, partialKey)) + } + cols, err := getRelevantColumnsForFamily(tableDesc, familyDesc) + if err != nil { + return nil, errors.Wrapf(err, "getting relevant columns for family: %s", keys.PrettyPrint(nil, partialKey)) + } + + var spec fetchpb.IndexFetchSpec + if err := rowenc.InitIndexFetchSpec(&spec, r.codec, tableDesc, tableDesc.GetPrimaryIndex(), cols); err != nil { + return nil, errors.Wrapf(err, "initializing index fetch spec: %s", keys.PrettyPrint(nil, partialKey)) + } + rf := row.Fetcher{} + if err := rf.Init(ctx, row.FetcherInitArgs{ + Spec: &spec, + WillUseKVProvider: true, + TraceKV: true, + TraceKVEvery: &util.EveryN{N: 1}, + }); err != nil { + return nil, errors.Wrapf(err, "initializing row fetcher: %s", keys.PrettyPrint(nil, partialKey)) + } + kvProvider := row.KVProvider{KVs: []roachpb.KeyValue{{Key: rfv.Key, Value: rfv.Value}}} + if err := rf.ConsumeKVProvider(ctx, &kvProvider); err != nil { + return nil, errors.Wrapf(err, "consuming kv provider: %s", keys.PrettyPrint(nil, partialKey)) + } + encDatums, _, err := rf.NextRow(ctx) + if err != nil { + return nil, errors.Wrapf(err, "fetching next row: %s", keys.PrettyPrint(nil, partialKey)) + } + _ = encDatums + + datums := make(tree.Datums, len(cols)) + for i, colID := range cols { + col, err := catalog.MustFindColumnByID(tableDesc, colID) + if err != nil { + return nil, errors.Wrapf(err, "finding column by id: %d", colID) + } + ed := encDatums[i] + if err := ed.EnsureDecoded(col.ColumnDesc().Type, &tree.DatumAlloc{}); err != nil { + return nil, errors.Wrapf(err, "error decoding column %q as type %s", col.ColumnDesc().Name, col.ColumnDesc().Type.String()) + } + datums[i] = ed.Datum + } + return datums, nil +} + +func (r *Reader) fetchTableDesc( + ctx context.Context, tableID descpb.ID, ts hlc.Timestamp, +) (catalog.TableDescriptor, error) { + // Retrieve the target TableDescriptor from the lease manager. No caching + // is attempted because the lease manager does its own caching. + desc, err := r.leaseMgr.Acquire(ctx, lease.TimestampToReadTimestamp(ts), tableID) + if err != nil { + // Manager can return all kinds of errors during chaos, but based on + // its usage, none of them should ever be terminal. 
+ return nil, changefeedbase.MarkRetryableError(err) + } + tableDesc := desc.Underlying().(catalog.TableDescriptor) + // Immediately release the lease, since we only need it for the exact + // timestamp requested. + desc.Release(ctx) + if tableDesc.MaybeRequiresTypeHydration() { + return nil, errors.AssertionFailedf("type hydration not supported yet") + } + return tableDesc, nil +} + +var _ queuebase.Reader = &Reader{} + +func getRelevantColumnsForFamily( + tableDesc catalog.TableDescriptor, familyDesc *descpb.ColumnFamilyDescriptor, +) ([]descpb.ColumnID, error) { + cols := tableDesc.GetPrimaryIndex().CollectKeyColumnIDs() + for _, colID := range familyDesc.ColumnIDs { + cols.Add(colID) + } + + // Maintain the ordering of tableDesc.PublicColumns(), which is + // matches the order of columns in the SQL table. + idx := 0 + result := make([]descpb.ColumnID, cols.Len()) + visibleColumns := tableDesc.PublicColumns() + if tableDesc.GetDeclarativeSchemaChangerState() != nil { + hasMergedIndex := catalog.HasDeclarativeMergedPrimaryIndex(tableDesc) + visibleColumns = make([]catalog.Column, 0, cols.Len()) + for _, col := range tableDesc.AllColumns() { + if col.Adding() { + continue + } + if tableDesc.GetDeclarativeSchemaChangerState() == nil && !col.Public() { + continue + } + if col.Dropped() && (!col.WriteAndDeleteOnly() || hasMergedIndex) { + continue + } + visibleColumns = append(visibleColumns, col) + } + // Recover the order of the original columns. + slices.SortStableFunc(visibleColumns, func(a, b catalog.Column) int { + return int(a.GetPGAttributeNum()) - int(b.GetPGAttributeNum()) + }) + } + for _, col := range visibleColumns { + colID := col.GetID() + if cols.Contains(colID) { + result[idx] = colID + idx++ + } + } + + // Some columns in familyDesc.ColumnIDs may not be public, so + // result may contain fewer columns than cols. + result = result[:idx] + return result, nil +} diff --git a/pkg/sql/queuefeed/reader_test.go b/pkg/sql/queuefeed/reader_test.go new file mode 100644 index 000000000000..bab1f7f025b3 --- /dev/null +++ b/pkg/sql/queuefeed/reader_test.go @@ -0,0 +1,239 @@ +package queuefeed + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/hlc" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/uuid" + "github.com/stretchr/testify/require" +) + +func TestReaderBasic(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(conn) + db.Exec(t, `CREATE TABLE t (a STRING, b INT)`) + + var tableID int64 + db.QueryRow(t, "SELECT id FROM system.namespace WHERE name = 't'").Scan(&tableID) + + qm := NewTestManager(t, srv.ApplicationLayer()) + defer qm.Close() + require.NoError(t, qm.CreateQueue(ctx, "test_queue", tableID)) + + // These should be readable as long as they were written after the queue was created. 
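+	// (The queue's partition checkpoints appear to be seeded at creation time,
+	// and setupRangefeed refuses to start from an empty checkpoint, so earlier
+	// writes are not replayed; CreateQueueFromCursor exists for starting from an
+	// explicit timestamp instead.)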
+ db.Exec(t, `INSERT INTO t VALUES ('row1', 10), ('row2', 20), ('row3', 30)`) + + reader, err := qm.CreateReaderForSession(ctx, "test_queue", Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: sqlliveness.SessionID("1"), + }) + require.NoError(t, err) + defer func() { _ = reader.Close() }() + + rows := pollForRows(t, ctx, reader, 3) + + requireRow(t, rows[0], "row1", 10) + requireRow(t, rows[1], "row2", 20) + requireRow(t, rows[2], "row3", 30) + reader.ConfirmReceipt(ctx) +} + +func TestReaderRollback(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(conn) + db.Exec(t, `CREATE TABLE t (a STRING, b INT)`) + + var tableID int64 + db.QueryRow(t, "SELECT id FROM system.namespace WHERE name = 't'").Scan(&tableID) + + qm := NewTestManager(t, srv.ApplicationLayer()) + defer qm.Close() + require.NoError(t, qm.CreateQueue(ctx, "rollback_test", tableID)) + + reader, err := qm.CreateReaderForSession(ctx, "rollback_test", Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: sqlliveness.SessionID("1"), + }) + require.NoError(t, err) + defer func() { _ = reader.Close() }() + + db.Exec(t, `INSERT INTO t VALUES ('row1', 100), ('row2', 200)`) + + rows1 := pollForRows(t, ctx, reader, 2) + + requireRow(t, rows1[0], "row1", 100) + requireRow(t, rows1[1], "row2", 200) + + reader.RollbackBatch(ctx) + + rows2, err := reader.GetRows(ctx, 10) + require.NoError(t, err) + require.Len(t, rows2, 2, "should get same 2 rows after rollback") + + requireRow(t, rows2[0], "row1", 100) + requireRow(t, rows2[1], "row2", 200) + + reader.ConfirmReceipt(ctx) + + db.Exec(t, `INSERT INTO t VALUES ('row3', 300), ('row4', 400)`) + + // Verify we got the NEW data (row3, row4), NOT the old data (row1, row2). + rows3 := pollForRows(t, ctx, reader, 2) + + requireRow(t, rows3[0], "row3", 300) + requireRow(t, rows3[1], "row4", 400) + + reader.ConfirmReceipt(ctx) +} + +func TestCheckpointRestoration(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(conn) + db.Exec(t, `CREATE TABLE t (a STRING, b INT)`) + + var tableID int64 + db.QueryRow(t, "SELECT id FROM system.namespace WHERE name = 't'").Scan(&tableID) + + qm := NewTestManager(t, srv.ApplicationLayer()) + defer qm.Close() + require.NoError(t, qm.CreateQueue(ctx, "checkpoint_test", tableID)) + + session := Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: sqlliveness.SessionID("1"), + } + reader1, err := qm.CreateReaderForSession(ctx, "checkpoint_test", session) + require.NoError(t, err) + + db.Exec(t, `INSERT INTO t VALUES ('batch1_row1', 1), ('batch1_row2', 2)`) + + // Sleep to let the rangefeed checkpoint advance past the data timestamps. + // This is really ugly but with 3 seconds the test failed in 100% of runs. + time.Sleep(10 * time.Second) + + _ = pollForRows(t, ctx, reader1, 2) + + reader1.ConfirmReceipt(ctx) + require.NoError(t, reader1.Close()) + + db.Exec(t, `INSERT INTO t VALUES ('batch2_row1', 3), ('batch2_row2', 4)`) + + reader2, err := qm.CreateReaderForSession(ctx, "checkpoint_test", session) + require.NoError(t, err) + defer func() { _ = reader2.Close() }() + + rows2 := pollForRows(t, ctx, reader2, 2) + + // Verify we got ONLY the new data, not the old data. 
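+	// reader2 reuses the same session, so it is expected to resume from the
+	// checkpoint persisted when reader1 confirmed receipt and closed.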
+ // Check that none of the rows are from batch1. + for _, row := range rows2 { + val := getString(row[0]) + require.NotContains(t, val, "batch1", "should not see batch1 data after checkpoint") + require.Contains(t, val, "batch2", "should see batch2 data") + } +} + +func TestCreateQueueFeedFromCursor(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, conn, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(conn) + db.Exec(t, `CREATE TABLE t (a STRING, b INT)`) + + var tableID int64 + db.QueryRow(t, "SELECT id FROM system.namespace WHERE name = 't'").Scan(&tableID) + + // Insert first batch of data (should NOT be read). + db.Exec(t, `INSERT INTO t VALUES ('batch1_row1', 10), ('batch1_row2', 20)`) + + // Get cursor timestamp after first batch. + var cursorStr string + db.QueryRow(t, "SELECT cluster_logical_timestamp()").Scan(&cursorStr) + cursor, err := hlc.ParseHLC(cursorStr) + + // Insert second batch of data (should be read). + db.Exec(t, `INSERT INTO t VALUES ('batch2_row1', 30), ('batch2_row2', 40)`) + + qm := NewTestManager(t, srv.ApplicationLayer()) + defer qm.Close() + require.NoError(t, qm.CreateQueueFromCursor(ctx, "cursor_test", tableID, cursor)) + + reader, err := qm.CreateReaderForSession(ctx, "cursor_test", Session{ + ConnectionID: uuid.MakeV4(), + LivenessID: sqlliveness.SessionID("1"), + }) + require.NoError(t, err) + defer func() { _ = reader.Close() }() + + // Should only get the second batch. + rows := pollForRows(t, ctx, reader, 2) + requireRow(t, rows[0], "batch2_row1", 30) + requireRow(t, rows[1], "batch2_row2", 40) + + reader.ConfirmReceipt(ctx) +} + +// pollForRows waits for the reader to return expectedCount rows. +func pollForRows( + t *testing.T, ctx context.Context, reader queuebase.Reader, expectedCount int, +) []tree.Datums { + var rows []tree.Datums + require.Eventually(t, func() bool { + var err error + rows, err = reader.GetRows(ctx, 10) + require.NoError(t, err) + if len(rows) < expectedCount { + reader.RollbackBatch(ctx) + } + return len(rows) == expectedCount + }, 5*time.Second, 50*time.Millisecond, "expected %d rows", expectedCount) + return rows +} + +// getString extracts a string from a tree.Datum. +func getString(d tree.Datum) string { + return string(*d.(*tree.DString)) +} + +// getInt extracts an int64 from a tree.Datum. +func getInt(d tree.Datum) int64 { + return int64(*d.(*tree.DInt)) +} + +// requireRow asserts that a row matches the expected string and int values. 
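+// It assumes the (a STRING, b INT) row shape used by the tables in these tests.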
+func requireRow(t *testing.T, row tree.Datums, expectedStr string, expectedInt int64) { + require.Equal(t, expectedStr, getString(row[0])) + require.Equal(t, expectedInt, getInt(row[1])) +} diff --git a/pkg/sql/queuefeed/smoke_test.go b/pkg/sql/queuefeed/smoke_test.go new file mode 100644 index 000000000000..baf43c0cc8d2 --- /dev/null +++ b/pkg/sql/queuefeed/smoke_test.go @@ -0,0 +1,170 @@ +package queuefeed + +import ( + "context" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestQueuefeedSmoketest(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(sqlDB) + db.Exec(t, `CREATE TABLE t (k string primary key)`) + _, err := srv.SystemLayer().SQLConn(t).Exec(`SET CLUSTER SETTING kv.rangefeed.enabled = true`) + require.NoError(t, err) + + db.Exec(t, `SELECT crdb_internal.create_queue_feed('test_queue', 't')`) + + // TODO improve this test once creating the queue sets an accurate cursor. We + // should be able to read an expected set of rows. + ctx, cancel := context.WithCancel(ctx) + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { + for i := 0; ctx.Err() == nil; i++ { + t.Log("inserting row", i) + db.Exec(t, `INSERT INTO t VALUES ($1::STRING)`, i) + time.Sleep(100 * time.Millisecond) + } + return nil + }) + + conn, err := srv.SQLConn(t).Conn(context.Background()) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + // Try to read from the queue until we observe some data. The queue doesn't + // currently track the frontier, so we need to keep inserting data because + // there is a race between inserting and reading from the queue. + found := 0 + for found < 1 { + t.Log("reading from queue feed") + + cursor, err := conn.QueryContext(ctx, "SELECT * FROM crdb_internal.select_from_queue_feed('test_queue', 1)") + require.NoError(t, err) + + for cursor.Next() { + var k string + require.NoError(t, cursor.Scan(&k)) + found++ + } + + require.NoError(t, cursor.Err()) + require.NoError(t, cursor.Close()) + } + + cancel() + require.NoError(t, group.Wait()) +} + +func TestQueuefeedSmoketestMultipleReaders(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // TODO(jeffswenson): rewrite this test to use normal sessions. 
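+	// The test splits the table across several ranges, spawns a random number
+	// of concurrent readers on the same queue, and then waits until every
+	// reader has observed at least one row and no partition is left unassigned.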
+ + ctx := context.Background() + srv, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) + defer srv.Stopper().Stop(ctx) + + db := sqlutils.MakeSQLRunner(sqlDB) + _, err := srv.SystemLayer().SQLConn(t).Exec(`SET CLUSTER SETTING kv.rangefeed.enabled = true`) + require.NoError(t, err) + + // Create table with composite primary key and split it + db.Exec(t, `CREATE TABLE t (k1 INT, k2 INT, v string, PRIMARY KEY (k1, k2))`) + db.Exec(t, `ALTER TABLE t SPLIT AT VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9)`) + + db.Exec(t, `SELECT crdb_internal.create_queue_feed('t_queue', 't')`) + + ctx, cancel := context.WithCancel(ctx) + group, ctx := errgroup.WithContext(ctx) + + group.Go(func() error { + for i := 0; ctx.Err() == nil; i++ { + _, err := sqlDB.ExecContext(ctx, `INSERT INTO t VALUES ($1, $2)`, i%10, rand.Int()) + if err != nil { + return errors.Wrap(err, "inserting a row") + } + } + return ctx.Err() + }) + + numWriters := rand.Intn(10) + 1 // create [1, 10] writers + rowsSeen := make([]atomic.Int64, numWriters) + t.Logf("spawning %d readers", numWriters) + for i := range numWriters { + group.Go(func() error { + time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) + + conn, err := srv.SQLConn(t).Conn(ctx) + if err != nil { + return err + } + + for ctx.Err() == nil { + cursor, err := conn.QueryContext(ctx, `SELECT * FROM crdb_internal.select_from_queue_feed('t_queue', 1000)`) + if err != nil { + return err + } + for cursor.Next() { + var e string + if err := cursor.Scan(&e); err != nil { + return errors.Wrap(err, "scanning queue feed row") + } + rowsSeen[i].Add(1) + } + require.NoError(t, cursor.Close()) + } + + return ctx.Err() + }) + } + + // Wait for every reader to observe rows and every partition to be assigned. + testutils.SucceedsSoon(t, func() error { + for _, row := range db.QueryStr(t, "SELECT partition_id, user_session, user_session_successor FROM defaultdb.queue_partition_t_queue") { + t.Logf("partition row: %v", row) + } + + seen := make([]int64, numWriters) + for i := range numWriters { + seen[i] = rowsSeen[i].Load() + } + t.Logf("row counts %v", seen) + + for i := range numWriters { + if seen[i] == 0 { + return errors.Newf("reader %d has not seen any rows yet", i) + } + } + + var unassignedPartitions int + db.QueryRow(t, "SELECT COUNT(*) FROM defaultdb.queue_partition_t_queue WHERE user_session IS NULL").Scan(&unassignedPartitions) + if unassignedPartitions != 0 { + return errors.Newf("%d unassigned partitions remain", unassignedPartitions) + } + + return nil + }) + + cancel() + _ = group.Wait() +} diff --git a/pkg/sql/sem/builtins/BUILD.bazel b/pkg/sql/sem/builtins/BUILD.bazel index 06c3e60d572e..98598ebe3ba2 100644 --- a/pkg/sql/sem/builtins/BUILD.bazel +++ b/pkg/sql/sem/builtins/BUILD.bazel @@ -78,6 +78,7 @@ go_library( "//pkg/sql/pgwire/pgnotice", "//pkg/sql/privilege", "//pkg/sql/protoreflect", + "//pkg/sql/queuefeed/queuebase", "//pkg/sql/rowenc", "//pkg/sql/rowenc/keyside", "//pkg/sql/rowenc/valueside", diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go index 6cfd5c22ece8..0a94e1bd0ae2 100644 --- a/pkg/sql/sem/builtins/builtins.go +++ b/pkg/sql/sem/builtins/builtins.go @@ -55,6 +55,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" "github.com/cockroachdb/cockroach/pkg/sql/privilege" "github.com/cockroachdb/cockroach/pkg/sql/protoreflect" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" "github.com/cockroachdb/cockroach/pkg/sql/rowenc" 
"github.com/cockroachdb/cockroach/pkg/sql/rowenc/keyside" "github.com/cockroachdb/cockroach/pkg/sql/sem/asof" @@ -4642,6 +4643,104 @@ value if you rely on the HLC for accuracy.`, } }()...), + "crdb_internal.create_queue_feed": makeBuiltin(defProps(), tree.Overload{ + Types: tree.ParamTypes{ + {Name: "queue_name", Typ: types.String}, + {Name: "table_name", Typ: types.String}, + }, + Volatility: volatility.Volatile, + ReturnType: tree.FixedReturnType(types.Void), + Fn: func(ctx context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + qn := args[0].(*tree.DString) + tableName := tree.MustBeDString(args[1]) + dOid, err := eval.ParseDOid(ctx, evalCtx, string(tableName), types.RegClass) + if err != nil { + return nil, err + } + + qm := getQueueManager(evalCtx) + if err := qm.CreateQueue(ctx, string(*qn), int64(dOid.Oid)); err != nil { + return nil, err + } + return tree.DVoidDatum, nil + }, + }), + + "crdb_internal.create_queue_feed_from_cursor": makeBuiltin(defProps(), tree.Overload{ + Types: tree.ParamTypes{ + {Name: "queue_name", Typ: types.String}, + {Name: "table_name", Typ: types.String}, + {Name: "cursor", Typ: types.Decimal}, + }, + Volatility: volatility.Volatile, + ReturnType: tree.FixedReturnType(types.Void), + Fn: func(ctx context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + qn := args[0].(*tree.DString) + + tableName := tree.MustBeDString(args[1]) + dOid, err := eval.ParseDOid(ctx, evalCtx, string(tableName), types.RegClass) + if err != nil { + return nil, err + } + + cursorDecimal := tree.MustBeDDecimal(args[2]) + cursor, err := hlc.DecimalToHLC(&cursorDecimal.Decimal) + if err != nil { + return nil, errors.Wrap(err, "converting cursor decimal to HLC timestamp") + } + + qm := getQueueManager(evalCtx) + if err := qm.CreateQueueFromCursor(ctx, string(*qn), int64(dOid.Oid), cursor); err != nil { + return nil, err + } + return tree.DVoidDatum, nil + }, + }), + + "crdb_internal.select_array_from_queue_feed": makeBuiltin(defProps(), tree.Overload{ + Types: tree.ParamTypes{ + {Name: "queue_name", Typ: types.String}, + {Name: "limit", Typ: types.Int}, + }, + Volatility: volatility.Volatile, + ReturnType: tree.FixedReturnType(types.MakeArray(types.Json)), + Fn: func(ctx context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + qn := args[0].(*tree.DString) + qr, err := getQueueReaderProvider(evalCtx).GetOrInitReader(evalCtx.SessionCtx, string(*qn)) + if err != nil { + return nil, errors.Wrapf(err, "get or init reader for queue %s", string(*qn)) + } + + ret := tree.NewDArray(types.Json) + + rowResult, err := qr.GetRows(ctx, int(tree.MustBeDInt(args[1]))) + if err != nil { + return nil, err + } + // attach commit hook to txn to confirm receipt + // or something... 
todo on rollback/abort + evalCtx.Txn.AddCommitTrigger(func(ctx context.Context) { + qr.ConfirmReceipt(ctx) + }) + + for _, row := range rowResult { + obj := json.NewObjectBuilder(len(row)) + for i, d := range row { + fmt.Printf("d: %#+v\n", d) + j, err := tree.AsJSON(d, evalCtx.SessionData().DataConversionConfig, evalCtx.GetLocation()) + if err != nil { + return nil, err + } + obj.Add(fmt.Sprintf("f%d", i+1), j) + } + if err := ret.Append(tree.NewDJSON(obj.Build())); err != nil { + return nil, err + } + } + return ret, nil + }, + }), + "crdb_internal.json_to_pb": makeBuiltin( jsonProps(), tree.Overload{ @@ -12846,3 +12945,11 @@ func exprSliceToStrSlice(exprs []tree.Expr) []string { } var nilRegionsError = errors.AssertionFailedf("evalCtx.Regions is nil") + +func getQueueManager(evalCtx *eval.Context) queuebase.Manager { + return evalCtx.Planner.ExecutorConfig().(interface{ GetQueueManager() queuebase.Manager }).GetQueueManager() +} + +func getQueueReaderProvider(evalCtx *eval.Context) queuebase.ReaderProvider { + return evalCtx.Planner.GetQueueReaderProvider() +} diff --git a/pkg/sql/sem/builtins/fixed_oids.go b/pkg/sql/sem/builtins/fixed_oids.go index 793e0381f80b..dd1aa3f035c2 100644 --- a/pkg/sql/sem/builtins/fixed_oids.go +++ b/pkg/sql/sem/builtins/fixed_oids.go @@ -2863,6 +2863,10 @@ var builtinOidsArray = []string{ 2908: `crdb_internal.inject_hint(statement_fingerprint: string, donor_sql: string) -> int`, 2909: `crdb_internal.clear_statement_hints_cache() -> void`, 2910: `crdb_internal.await_statement_hints_cache() -> void`, + 2911: `crdb_internal.create_queue_feed(queue_name: string, table_name: string) -> void`, + 2912: `crdb_internal.select_from_queue_feed(queue_name: string, limit: int) -> jsonb`, + 2913: `crdb_internal.select_array_from_queue_feed(queue_name: string, limit: int) -> jsonb[]`, + 2914: `crdb_internal.create_queue_feed_from_cursor(queue_name: string, table_name: string, cursor: decimal) -> void`, } var builtinOidsBySignature map[string]oid.Oid diff --git a/pkg/sql/sem/builtins/generator_builtins.go b/pkg/sql/sem/builtins/generator_builtins.go index 5c45bbc6ad7d..536579d57e8a 100644 --- a/pkg/sql/sem/builtins/generator_builtins.go +++ b/pkg/sql/sem/builtins/generator_builtins.go @@ -9,6 +9,7 @@ import ( "bytes" "context" gojson "encoding/json" + "fmt" "math/rand" "sort" "strconv" @@ -141,6 +142,19 @@ var generators = map[string]builtinDefinition{ volatility.Stable, ), ), + "crdb_internal.select_from_queue_feed": makeBuiltin( + genProps(), + makeGeneratorOverload( + tree.ParamTypes{ + {Name: "queue_name", Typ: types.String}, + {Name: "limit", Typ: types.Int}, + }, + queueFeedGeneratorType, + makeQueueFeedGenerator, + "Returns rows from a queue feed", + volatility.Volatile, + ), + ), "crdb_internal.scan": makeBuiltin( tree.FunctionProperties{ Category: builtinconstants.CategoryGenerator, @@ -4350,3 +4364,79 @@ func (g *txnDiagnosticsRequestGenerator) Values() (tree.Datums, error) { // Close implements the eval.ValueGenerator interface. 
func (g *txnDiagnosticsRequestGenerator) Close(ctx context.Context) { } + +type queueFeedGenerator struct { + queueName string + limit int + evalCtx *eval.Context + rows []tree.Datums + rowIdx int +} + +var queueFeedGeneratorType = types.Jsonb + +func makeQueueFeedGenerator( + ctx context.Context, evalCtx *eval.Context, args tree.Datums, +) (eval.ValueGenerator, error) { + queueName := string(tree.MustBeDString(args[0])) + limit := int(tree.MustBeDInt(args[1])) + return &queueFeedGenerator{ + queueName: queueName, + limit: limit, + evalCtx: evalCtx, + rowIdx: -1, + }, nil +} + +// ResolvedType implements the eval.ValueGenerator interface. +func (g *queueFeedGenerator) ResolvedType() *types.T { + return queueFeedGeneratorType +} + +// Start implements the eval.ValueGenerator interface. +func (g *queueFeedGenerator) Start(ctx context.Context, txn *kv.Txn) error { + qr, err := getQueueReaderProvider(g.evalCtx).GetOrInitReader(g.evalCtx.SessionCtx, g.queueName) + if err != nil { + return err + } + + // Attach commit hook to txn to confirm receipt on successful commit. + txn.AddCommitTrigger(func(ctx context.Context) { + qr.ConfirmReceipt(ctx) + }) + // On rollback, we don't confirm receipt since the transaction didn't commit + // and the rows shouldn't be considered consumed. + txn.AddRollbackTrigger(func(ctx context.Context) { + qr.RollbackBatch(ctx) + }) + + rows, err := qr.GetRows(ctx, g.limit) + if err != nil { + return err + } + g.rows = rows + return nil +} + +// Next implements the eval.ValueGenerator interface. +func (g *queueFeedGenerator) Next(ctx context.Context) (bool, error) { + g.rowIdx++ + return g.rowIdx < len(g.rows), nil +} + +// Values implements the eval.ValueGenerator interface. +func (g *queueFeedGenerator) Values() (tree.Datums, error) { + row := g.rows[g.rowIdx] + obj := json.NewObjectBuilder(len(row)) + for i, d := range row { + j, err := tree.AsJSON(d, g.evalCtx.SessionData().DataConversionConfig, g.evalCtx.GetLocation()) + if err != nil { + return nil, err + } + obj.Add(fmt.Sprintf("f%d", i+1), j) + } + return tree.Datums{tree.NewDJSON(obj.Build())}, nil +} + +// Close implements the eval.ValueGenerator interface. +func (g *queueFeedGenerator) Close(ctx context.Context) {} diff --git a/pkg/sql/sem/eval/BUILD.bazel b/pkg/sql/sem/eval/BUILD.bazel index e007a41e4805..df59d0f62dfb 100644 --- a/pkg/sql/sem/eval/BUILD.bazel +++ b/pkg/sql/sem/eval/BUILD.bazel @@ -65,6 +65,7 @@ go_library( "//pkg/sql/pgwire/pgnotice", "//pkg/sql/pgwire/pgwirecancel", "//pkg/sql/privilege", + "//pkg/sql/queuefeed/queuebase", "//pkg/sql/sem/builtins/builtinsregistry", "//pkg/sql/sem/cast", "//pkg/sql/sem/catid", diff --git a/pkg/sql/sem/eval/context.go b/pkg/sql/sem/eval/context.go index 38109af21cb5..bcd33dc58ade 100644 --- a/pkg/sql/sem/eval/context.go +++ b/pkg/sql/sem/eval/context.go @@ -67,6 +67,13 @@ var ErrNilTxnInClusterContext = errors.New("nil txn in cluster context") // more fields from the sql package. Through that extendedEvalContext, this // struct now generally used by planNodes. type Context struct { + // SessionCtx is the session-lifetime context for the current SQL connection. + // It reflects the session's context (or the session tracing context if + // tracing is enabled) and is not cancelled at statement end. Prefer the + // statement/transaction-scoped ctx passed to functions when work must obey + // statement/txn cancellation; use SessionCtx only for work that must outlive + // the statement/txn but remain tied to the SQL session. 
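+	// A sketch of the intended pattern (illustrative; `provider` stands in for
+	// the session's queuebase.ReaderProvider):
+	//
+	//	reader, err := provider.GetOrInitReader(evalCtx.SessionCtx, queueName)
+	//
+	// rather than passing the statement-scoped ctx, so the reader can outlive
+	// the statement that first created it.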
+ SessionCtx context.Context // SessionDataStack stores the session variables accessible by the correct // context. Each element on the stack represents the beginning of a new // transaction or nested transaction (savepoints). @@ -320,6 +327,8 @@ type Context struct { // ExecutedStatementCounters contains metrics for successfully executed // statements defined within the body of a UDF/SP. ExecutedRoutineStatementCounters RoutineStatementCounters + + QueueSessionMgr any } // RoutineStatementCounters encapsulates metrics for tracking the execution diff --git a/pkg/sql/sem/eval/deps.go b/pkg/sql/sem/eval/deps.go index 34f6de9bd4ba..573217051f24 100644 --- a/pkg/sql/sem/eval/deps.go +++ b/pkg/sql/sem/eval/deps.go @@ -18,6 +18,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/hintpb" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgnotice" "github.com/cockroachdb/cockroach/pkg/sql/privilege" + "github.com/cockroachdb/cockroach/pkg/sql/queuefeed/queuebase" "github.com/cockroachdb/cockroach/pkg/sql/sem/catid" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" @@ -297,6 +298,9 @@ type Planner interface { // the `system.users` table UserHasAdminRole(ctx context.Context, user username.SQLUsername) (bool, error) + // GetQueueReaderProvider returns the ReaderProvider for queuefeed readers. + GetQueueReaderProvider() queuebase.ReaderProvider + // MemberOfWithAdminOption is used to collect a list of roles (direct and // indirect) that the member is part of. See the comment on the planner // implementation in authorization.go