Skip to content

Commit 6f8c1e2

Browse files
authored
fix: return batch update counts also when retries are disabled (#566)
The RunDmlBatch function of SpannerConn returns a Spanner-specific result type that also contains the actual update counts per statement in a batch, and not only the total update count. This was however not correctly returned if a DML batch was executed in a read/write transaction with retries disabled.
1 parent 0566936 commit 6f8c1e2

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

conn_with_mockserver_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,6 +1501,7 @@ func TestGenericConnectionState_GoogleSQL(t *testing.T) {
15011501
t.Fatal(err)
15021502
}
15031503
}
1504+
15041505
func TestGenericConnectionState_PostgreSQL(t *testing.T) {
15051506
t.Parallel()
15061507

@@ -1751,6 +1752,40 @@ func TestGenericConnectionState_PostgreSQL(t *testing.T) {
17511752
}
17521753
}
17531754

1755+
func TestDmlBatchReturnsBatchUpdateCountsOutsideTransaction(t *testing.T) {
1756+
t.Parallel()
1757+
db, _, teardown := setupTestDBConnection(t)
1758+
defer teardown()
1759+
ctx := context.Background()
1760+
1761+
conn, err := db.Conn(ctx)
1762+
if err != nil {
1763+
t.Fatal(err)
1764+
}
1765+
defer silentClose(conn)
1766+
1767+
if _, err := conn.ExecContext(ctx, "start batch dml"); err != nil {
1768+
t.Fatal(err)
1769+
}
1770+
_, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo)
1771+
_, _ = conn.ExecContext(ctx, testutil.UpdateSingersSetLastName)
1772+
var res SpannerResult
1773+
if err := conn.Raw(func(driverConn interface{}) error {
1774+
spannerConn, _ := driverConn.(SpannerConn)
1775+
res, err = spannerConn.RunDmlBatch(ctx)
1776+
return err
1777+
}); err != nil {
1778+
t.Fatal(err)
1779+
}
1780+
results, err := res.BatchRowsAffected()
1781+
if err != nil {
1782+
t.Fatal(err)
1783+
}
1784+
if g, w := results, []int64{testutil.UpdateBarSetFooRowCount, testutil.UpdateSingersSetLastNameRowCount}; !reflect.DeepEqual(g, w) {
1785+
t.Fatalf("batch affected mismatch\n Got: %v\nWant: %v", g, w)
1786+
}
1787+
}
1788+
17541789
func verifyConnectionPropertyValue[T comparable](t *testing.T, c *sql.Conn, name string, value T) {
17551790
ctx := context.Background()
17561791
row := c.QueryRowContext(ctx, getShowStatement(c)+name)

transaction.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ func (tx *readWriteTransaction) runDmlBatch(ctx context.Context) (*result, error
744744

745745
if !tx.retryAborts() {
746746
affected, err := tx.rwTx.BatchUpdateWithOptions(ctx, statements, options.QueryOptions)
747-
return &result{rowsAffected: sum(affected)}, err
747+
return &result{rowsAffected: sum(affected), batchUpdateCounts: affected}, err
748748
}
749749

750750
var affected []int64

transaction_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package spannerdriver
33
import (
44
"context"
55
"database/sql"
6+
"fmt"
67
"reflect"
78
"testing"
89

@@ -224,3 +225,49 @@ func TestBeginTransactionDeferrable(t *testing.T) {
224225
t.Fatalf("deferrable mismatch\n Got: %v\nWant: %v", g, w)
225226
}
226227
}
228+
229+
func TestDmlBatchReturnsBatchUpdateCounts(t *testing.T) {
230+
t.Parallel()
231+
db, _, teardown := setupTestDBConnection(t)
232+
defer teardown()
233+
ctx := context.Background()
234+
235+
conn, err := db.Conn(ctx)
236+
if err != nil {
237+
t.Fatal(err)
238+
}
239+
defer silentClose(conn)
240+
241+
for _, retry := range []bool{true, false} {
242+
_, err := conn.ExecContext(ctx, "begin transaction")
243+
if err != nil {
244+
t.Fatal(err)
245+
}
246+
if _, err := conn.ExecContext(ctx, fmt.Sprintf("set local retry_aborts_internally=%v", retry)); err != nil {
247+
t.Fatal(err)
248+
}
249+
if _, err := conn.ExecContext(ctx, "start batch dml"); err != nil {
250+
t.Fatal(err)
251+
}
252+
_, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo)
253+
_, _ = conn.ExecContext(ctx, testutil.UpdateSingersSetLastName)
254+
var res SpannerResult
255+
if err := conn.Raw(func(driverConn interface{}) error {
256+
spannerConn, _ := driverConn.(SpannerConn)
257+
res, err = spannerConn.RunDmlBatch(ctx)
258+
return err
259+
}); err != nil {
260+
t.Fatal(err)
261+
}
262+
results, err := res.BatchRowsAffected()
263+
if err != nil {
264+
t.Fatal(err)
265+
}
266+
if g, w := results, []int64{testutil.UpdateBarSetFooRowCount, testutil.UpdateSingersSetLastNameRowCount}; !reflect.DeepEqual(g, w) {
267+
t.Fatalf("batch affected mismatch\n Got: %v\nWant: %v", g, w)
268+
}
269+
if _, err := conn.ExecContext(ctx, "commit"); err != nil {
270+
t.Fatal(err)
271+
}
272+
}
273+
}

0 commit comments

Comments
 (0)