Skip to content

Commit 73c3fb0

Browse files
authored
feat: add ExecuteBatchDml func with batch result (#454)
Adds an ExecuteBatchDml function that returns the actual batch result. This function can be used to execute a batch of DML statements and get back the exact number of rows affected per statement, instead of only the total number of rows affected by the statements in the batch. Fixes #377
1 parent 3aa5ca2 commit 73c3fb0

File tree

5 files changed

+338
-3
lines changed

5 files changed

+338
-3
lines changed

conn.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,11 @@ func (c *conn) runDDLBatch(ctx context.Context) (driver.Result, error) {
521521
return c.execDDL(ctx, statements...)
522522
}
523523

524-
func (c *conn) runDMLBatch(ctx context.Context) (driver.Result, error) {
524+
func (c *conn) runDMLBatch(ctx context.Context) (SpannerResult, error) {
525+
if c.inTransaction() {
526+
return c.tx.RunDmlBatch(ctx)
527+
}
528+
525529
statements := c.batch.statements
526530
options := c.batch.options
527531
options.QueryOptions.LastStatement = true
@@ -566,7 +570,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr
566570
return driver.ResultNoRows, nil
567571
}
568572

569-
func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, options ExecOptions) (driver.Result, error) {
573+
func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, options ExecOptions) (SpannerResult, error) {
570574
if len(statements) == 0 {
571575
return &result{}, nil
572576
}
@@ -585,7 +589,7 @@ func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement,
585589
return err
586590
}, options.TransactionOptions)
587591
}
588-
return &result{rowsAffected: sum(affected)}, err
592+
return &result{rowsAffected: sum(affected), batchUpdateCounts: affected}, err
589593
}
590594

591595
func sum(affected []int64) int64 {

driver.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,18 @@ func determineDefaultStatementCacheSize() {
112112
}
113113
}
114114

115+
// SpannerResult is the result type returned by Spanner connections for
116+
// DML batches. This interface extends the standard sql.Result interface
117+
// and adds a BatchRowsAffected function that returns the affected rows
118+
// per statement.
119+
type SpannerResult interface {
120+
driver.Result
121+
122+
// BatchRowsAffected returns the affected rows per statement in a DML batch.
123+
// It returns an error if the statement was not a DML batch.
124+
BatchRowsAffected() ([]int64, error)
125+
}
126+
115127
// ExecOptions can be passed in as an argument to the Query, QueryContext,
116128
// Exec, and ExecContext functions to specify additional execution options
117129
// for a statement.
@@ -1033,6 +1045,94 @@ func clearTempReadOnlyTransactionOptions(conn *sql.Conn) {
10331045
_ = conn.Close()
10341046
}
10351047

1048+
// DmlBatch is used to execute a batch of DML statements on Spanner in a single round-trip.
1049+
type DmlBatch interface {
1050+
// ExecContext buffers the given statement for execution on Spanner.
1051+
// All buffered statements are sent to Spanner as a single request when the DmlBatch
1052+
// function returns successfully.
1053+
ExecContext(ctx context.Context, dml string, args ...any) error
1054+
}
1055+
1056+
var _ DmlBatch = &dmlBatch{}
1057+
1058+
type dmlBatch struct {
1059+
conn *sql.Conn
1060+
}
1061+
1062+
// ExecuteBatchDml executes a batch of DML statements in a single round-trip to Spanner.
1063+
func ExecuteBatchDml(ctx context.Context, db *sql.DB, f func(ctx context.Context, batch DmlBatch) error) (SpannerResult, error) {
1064+
conn, err := db.Conn(ctx)
1065+
if err != nil {
1066+
return nil, err
1067+
}
1068+
return ExecuteBatchDmlOnConn(ctx, conn, f)
1069+
}
1070+
1071+
// ExecuteBatchDmlOnConn executes a batch of DML statements on a specific connection in a single round-trip to Spanner.
1072+
func ExecuteBatchDmlOnConn(ctx context.Context, connection *sql.Conn, f func(ctx context.Context, batch DmlBatch) error) (SpannerResult, error) {
1073+
// Start the DML batch.
1074+
if err := connection.Raw(func(driverConn any) error {
1075+
c, ok := driverConn.(*conn)
1076+
if !ok {
1077+
return spanner.ToSpannerError(status.Error(codes.InvalidArgument, "connection is not a Spanner connection"))
1078+
}
1079+
if _, err := c.startBatchDML(false); err != nil {
1080+
return err
1081+
}
1082+
return nil
1083+
}); err != nil {
1084+
return nil, err
1085+
}
1086+
1087+
// Let the callback execute the statements on the batch.
1088+
b := &dmlBatch{conn: connection}
1089+
if err := f(ctx, b); err != nil {
1090+
// The callback returned an error, abort the batch.
1091+
_ = connection.Raw(func(driverConn any) error {
1092+
c, _ := driverConn.(*conn)
1093+
_ = c.AbortBatch()
1094+
return nil
1095+
})
1096+
return nil, err
1097+
}
1098+
1099+
// Send the batch to Spanner.
1100+
var res SpannerResult
1101+
if err := connection.Raw(func(driverConn any) error {
1102+
// We know that the connection is a Spanner connection, so we don't bother to check that again here.
1103+
c, _ := driverConn.(*conn)
1104+
var err error
1105+
res, err = c.runDMLBatch(ctx)
1106+
if err != nil {
1107+
// Make sure the batch is removed from the connection/transaction.
1108+
_ = c.AbortBatch()
1109+
return err
1110+
}
1111+
return nil
1112+
}); err != nil {
1113+
return nil, err
1114+
}
1115+
1116+
return res, nil
1117+
}
1118+
1119+
func (b *dmlBatch) ExecContext(ctx context.Context, dml string, args ...any) error {
1120+
if err := b.conn.Raw(func(driverConn any) error {
1121+
c, ok := driverConn.(*conn)
1122+
if !ok {
1123+
return spanner.ToSpannerError(status.Error(codes.InvalidArgument, "connection is not a Spanner connection"))
1124+
}
1125+
if !c.InDMLBatch() {
1126+
return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "this batch is no longer active"))
1127+
}
1128+
return nil
1129+
}); err != nil {
1130+
return err
1131+
}
1132+
_, err := b.conn.ExecContext(ctx, dml, args...)
1133+
return err
1134+
}
1135+
10361136
// AutocommitDMLMode indicates whether a single DML statement should be executed
10371137
// in a normal atomic transaction or as a Partitioned DML statement.
10381138
// See https://cloud.google.com/spanner/docs/dml-partitioned for more information.

driver_with_mockserver_test.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,6 +2755,133 @@ func TestAutocommitBatchDml(t *testing.T) {
27552755
}
27562756
}
27572757

2758+
func TestExecuteBatchDml(t *testing.T) {
2759+
t.Parallel()
2760+
2761+
ctx := context.Background()
2762+
db, server, teardown := setupTestDBConnection(t)
2763+
defer teardown()
2764+
2765+
_ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{
2766+
Type: testutil.StatementResultUpdateCount,
2767+
UpdateCount: 1,
2768+
})
2769+
_ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')", &testutil.StatementResult{
2770+
Type: testutil.StatementResultUpdateCount,
2771+
UpdateCount: 1,
2772+
})
2773+
2774+
res, err := ExecuteBatchDml(ctx, db, func(ctx context.Context, batch DmlBatch) error {
2775+
if err := batch.ExecContext(ctx, "INSERT INTO Foo (Id, Val) VALUES (1, 'One')"); err != nil {
2776+
return err
2777+
}
2778+
if err := batch.ExecContext(ctx, "INSERT INTO Foo (Id, Val) VALUES (2, 'Two')"); err != nil {
2779+
return err
2780+
}
2781+
return nil
2782+
})
2783+
if err != nil {
2784+
t.Fatalf("failed to execute dml batch: %v", err)
2785+
}
2786+
affected, err := res.RowsAffected()
2787+
if err != nil {
2788+
t.Fatalf("could not get rows affected from batch: %v", err)
2789+
}
2790+
if g, w := affected, int64(2); g != w {
2791+
t.Fatalf("affected rows mismatch\n Got: %v\nWant: %v", g, w)
2792+
}
2793+
batchAffected, err := res.BatchRowsAffected()
2794+
if err != nil {
2795+
t.Fatalf("could not get batch rows affected from batch: %v", err)
2796+
}
2797+
if g, w := batchAffected, []int64{1, 1}; !cmp.Equal(g, w) {
2798+
t.Fatalf("affected batch rows mismatch\n Got: %v\nWant: %v", g, w)
2799+
}
2800+
2801+
requests := drainRequestsFromServer(server.TestSpanner)
2802+
// There should be no ExecuteSqlRequests on the server.
2803+
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
2804+
if g, w := len(sqlRequests), 0; g != w {
2805+
t.Fatalf("sql requests count mismatch\n Got: %v\nWant: %v", g, w)
2806+
}
2807+
batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{}))
2808+
if g, w := len(batchRequests), 1; g != w {
2809+
t.Fatalf("BatchDML requests count mismatch\n Got: %v\nWant: %v", g, w)
2810+
}
2811+
if !batchRequests[0].(*sppb.ExecuteBatchDmlRequest).LastStatements {
2812+
t.Fatal("last statements flag not set")
2813+
}
2814+
// The transaction should have been committed.
2815+
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
2816+
if g, w := len(commitRequests), 1; g != w {
2817+
t.Fatalf("Commit requests count mismatch\n Got: %v\nWant: %v", g, w)
2818+
}
2819+
}
2820+
2821+
func TestExecuteBatchDmlError(t *testing.T) {
2822+
t.Parallel()
2823+
2824+
ctx := context.Background()
2825+
db, server, teardown := setupTestDBConnection(t)
2826+
defer teardown()
2827+
2828+
_ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{
2829+
Type: testutil.StatementResultUpdateCount,
2830+
UpdateCount: 1,
2831+
})
2832+
c, err := db.Conn(ctx)
2833+
defer func() { _ = c.Close() }()
2834+
if err != nil {
2835+
t.Fatalf("failed to obtain connection: %v", err)
2836+
}
2837+
2838+
_, err = ExecuteBatchDmlOnConn(ctx, c, func(ctx context.Context, batch DmlBatch) error {
2839+
if err := batch.ExecContext(ctx, "INSERT INTO Foo (Id, Val) VALUES (1, 'One')"); err != nil {
2840+
return err
2841+
}
2842+
return fmt.Errorf("test error")
2843+
})
2844+
if err == nil {
2845+
t.Fatalf("failed to execute dml batch: %v", err)
2846+
}
2847+
2848+
requests := drainRequestsFromServer(server.TestSpanner)
2849+
// There should be no requests on the server.
2850+
batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{}))
2851+
if g, w := len(batchRequests), 0; g != w {
2852+
t.Fatalf("BatchDML requests count mismatch\n Got: %v\nWant: %v", g, w)
2853+
}
2854+
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
2855+
if g, w := len(commitRequests), 0; g != w {
2856+
t.Fatalf("Commit requests count mismatch\n Got: %v\nWant: %v", g, w)
2857+
}
2858+
2859+
// Verify that the connection is not in a batch, and that it can be used for other statements.
2860+
if err := c.Raw(func(driverConn any) error {
2861+
spannerConn, ok := driverConn.(SpannerConn)
2862+
if !ok {
2863+
return fmt.Errorf("driver connection is not a SpannerConn")
2864+
}
2865+
if spannerConn.InDMLBatch() {
2866+
return fmt.Errorf("connection is still in a batch")
2867+
}
2868+
return nil
2869+
}); err != nil {
2870+
t.Fatalf("check if connection is in a batch failed: %v", err)
2871+
}
2872+
res, err := c.ExecContext(ctx, `INSERT INTO Foo (Id, Val) VALUES (1, 'One')`)
2873+
if err != nil {
2874+
t.Fatalf("failed to execute dml statement: %v", err)
2875+
}
2876+
if affected, err := res.RowsAffected(); err != nil {
2877+
t.Fatalf("failed to obtain rows affected: %v", err)
2878+
} else {
2879+
if g, w := affected, int64(1); g != w {
2880+
t.Fatalf("affected rows mismatch\n Got: %v\nWant: %v", g, w)
2881+
}
2882+
}
2883+
}
2884+
27582885
func TestTransactionBatchDml(t *testing.T) {
27592886
t.Parallel()
27602887

@@ -2872,6 +2999,80 @@ func TestTransactionBatchDml(t *testing.T) {
28722999
}
28733000
}
28743001

3002+
func TestExecuteBatchDmlTransaction(t *testing.T) {
3003+
t.Parallel()
3004+
3005+
ctx := context.Background()
3006+
db, server, teardown := setupTestDBConnection(t)
3007+
defer teardown()
3008+
3009+
_ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (1, 'One')", &testutil.StatementResult{
3010+
Type: testutil.StatementResultUpdateCount,
3011+
UpdateCount: 1,
3012+
})
3013+
_ = server.TestSpanner.PutStatementResult("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')", &testutil.StatementResult{
3014+
Type: testutil.StatementResultUpdateCount,
3015+
UpdateCount: 1,
3016+
})
3017+
3018+
conn, err := db.Conn(ctx)
3019+
if err != nil {
3020+
t.Fatalf("failed to obtain connection: %v", err)
3021+
}
3022+
tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
3023+
if err != nil {
3024+
t.Fatalf("failed to begin transaction: %v", err)
3025+
}
3026+
res, err := ExecuteBatchDmlOnConn(ctx, conn, func(ctx context.Context, batch DmlBatch) error {
3027+
if err := batch.ExecContext(ctx, "INSERT INTO Foo (Id, Val) VALUES (1, 'One')"); err != nil {
3028+
return err
3029+
}
3030+
if err := batch.ExecContext(ctx, "INSERT INTO Foo (Id, Val) VALUES (2, 'Two')"); err != nil {
3031+
return err
3032+
}
3033+
return nil
3034+
})
3035+
if err != nil {
3036+
t.Fatalf("failed to execute dml batch: %v", err)
3037+
}
3038+
affected, err := res.RowsAffected()
3039+
if err != nil {
3040+
t.Fatalf("could not get rows affected from batch: %v", err)
3041+
}
3042+
if g, w := affected, int64(2); g != w {
3043+
t.Fatalf("affected rows mismatch\n Got: %v\nWant: %v", g, w)
3044+
}
3045+
batchAffected, err := res.BatchRowsAffected()
3046+
if err != nil {
3047+
t.Fatalf("could not get batch rows affected from batch: %v", err)
3048+
}
3049+
if g, w := batchAffected, []int64{1, 1}; !cmp.Equal(g, w) {
3050+
t.Fatalf("affected batch rows mismatch\n Got: %v\nWant: %v", g, w)
3051+
}
3052+
3053+
if err := tx.Commit(); err != nil {
3054+
t.Fatalf("failed to commit transaction after batch: %v", err)
3055+
}
3056+
3057+
requests := drainRequestsFromServer(server.TestSpanner)
3058+
// There should be no ExecuteSqlRequests on the server.
3059+
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
3060+
if g, w := len(sqlRequests), 0; g != w {
3061+
t.Fatalf("sql requests count mismatch\n Got: %v\nWant: %v", g, w)
3062+
}
3063+
batchRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{}))
3064+
if g, w := len(batchRequests), 1; g != w {
3065+
t.Fatalf("BatchDML requests count mismatch\n Got: %v\nWant: %v", g, w)
3066+
}
3067+
if batchRequests[0].(*sppb.ExecuteBatchDmlRequest).LastStatements {
3068+
t.Fatal("last statements flag was set, this should not happen for batches in a transaction")
3069+
}
3070+
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
3071+
if g, w := len(commitRequests), 1; g != w {
3072+
t.Fatalf("Commit requests count mismatch\n Got: %v\nWant: %v", g, w)
3073+
}
3074+
}
3075+
28753076
func TestCommitTimestamp(t *testing.T) {
28763077
t.Parallel()
28773078

stmt.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ func convertParam(v driver.Value) driver.Value {
223223
}
224224
}
225225

226+
var _ SpannerResult = &result{}
227+
226228
type result struct {
227229
rowsAffected int64
228230
lastInsertId int64
@@ -246,3 +248,12 @@ func (r *result) LastInsertId() (int64, error) {
246248
func (r *result) RowsAffected() (int64, error) {
247249
return r.rowsAffected, nil
248250
}
251+
252+
var errNoBatchRowsAffected = spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "BatchRowsAffected is only supported for batch DML results"))
253+
254+
func (r *result) BatchRowsAffected() ([]int64, error) {
255+
if r.batchUpdateCounts == nil {
256+
return nil, errNoBatchRowsAffected
257+
}
258+
return r.batchUpdateCounts, nil
259+
}

0 commit comments

Comments
 (0)