Skip to content

Commit 12ead7c

Browse files
authored
feat: add RunDmlBatch function with typed return value (#471)
Adds a custom RunDmlBatch function to SpannerConn that returns the actual number of affected rows per DML statement in a DML batch.
1 parent 19238ce commit 12ead7c

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

conn.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ type SpannerConn interface {
5959
// RunBatch sends all batched DDL or DML statements to Spanner. This is a
6060
// no-op if no statements have been batched or if there is no active batch.
6161
RunBatch(ctx context.Context) error
62+
// RunDmlBatch sends all batched DML statements to Spanner. This is a
63+
// no-op if no statements have been batched or if there is no active DML batch.
64+
RunDmlBatch(ctx context.Context) (SpannerResult, error)
6265
// AbortBatch aborts the current DDL or DML batch and discards all batched
6366
// statements.
6467
AbortBatch() error
@@ -446,6 +449,18 @@ func (c *conn) RunBatch(ctx context.Context) error {
446449
return err
447450
}
448451

452+
func (c *conn) RunDmlBatch(ctx context.Context) (SpannerResult, error) {
453+
res, err := c.runBatch(ctx)
454+
if err != nil {
455+
return nil, err
456+
}
457+
spannerRes, ok := res.(SpannerResult)
458+
if !ok {
459+
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "not a DML batch"))
460+
}
461+
return spannerRes, nil
462+
}
463+
449464
func (c *conn) AbortBatch() error {
450465
_, err := c.abortBatch()
451466
return err

conn_with_mockserver_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,44 @@ func TestDDLUsingQueryContextInReadWriteTransaction(t *testing.T) {
369369
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
370370
}
371371
}
372+
373+
func TestRunDmlBatch(t *testing.T) {
374+
t.Parallel()
375+
376+
db, _, teardown := setupTestDBConnection(t)
377+
defer teardown()
378+
ctx := context.Background()
379+
380+
conn, err := db.Conn(ctx)
381+
if err != nil {
382+
t.Fatal(err)
383+
}
384+
defer silentClose(conn)
385+
if err := conn.Raw(func(driverConn interface{}) error {
386+
spannerConn, _ := driverConn.(SpannerConn)
387+
return spannerConn.StartBatchDML()
388+
}); err != nil {
389+
t.Fatal(err)
390+
}
391+
// Buffer two DML statements.
392+
for range 2 {
393+
if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil {
394+
t.Fatal(err)
395+
}
396+
}
397+
var res SpannerResult
398+
if err := conn.Raw(func(driverConn interface{}) (err error) {
399+
spannerConn, _ := driverConn.(SpannerConn)
400+
res, err = spannerConn.RunDmlBatch(ctx)
401+
return err
402+
}); err != nil {
403+
t.Fatal(err)
404+
}
405+
affected, err := res.BatchRowsAffected()
406+
if err != nil {
407+
t.Fatal(err)
408+
}
409+
if g, w := affected, []int64{testutil.UpdateBarSetFooRowCount, testutil.UpdateBarSetFooRowCount}; !reflect.DeepEqual(g, w) {
410+
t.Fatalf("affected mismatch\n Got: %v\nWant: %v", g, w)
411+
}
412+
}

examples/dml-batches/main.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,17 @@ func dmlBatch(projectId, instanceId, databaseId string) error {
106106
return fmt.Errorf("failed to insert: %v", err)
107107
}
108108
// Run the batch. This will apply all the batched DML statements to the database in one atomic operation.
109-
if err := conn.Raw(func(driverConn interface{}) error {
110-
return driverConn.(spannerdriver.SpannerConn).RunBatch(ctx)
109+
var res spannerdriver.SpannerResult
110+
if err := conn.Raw(func(driverConn interface{}) (err error) {
111+
res, err = driverConn.(spannerdriver.SpannerConn).RunDmlBatch(ctx)
112+
return err
111113
}); err != nil {
112114
return fmt.Errorf("failed to run DML batch: %v", err)
113115
}
116+
// BatchRowsAffected returns a slice with the affected rows per DML statement in the batch.
117+
affected, _ := res.BatchRowsAffected()
118+
fmt.Printf("Affected rows: %v\n", affected)
119+
114120
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM Singers").Scan(&c); err != nil {
115121
return fmt.Errorf("failed to get singers count: %v", err)
116122
}

0 commit comments

Comments
 (0)