@@ -1501,6 +1501,7 @@ func TestGenericConnectionState_GoogleSQL(t *testing.T) {
15011501 t.Fatal(err)
15021502 }
15031503}
1504+
15041505func TestGenericConnectionState_PostgreSQL(t *testing.T) {
15051506 t.Parallel()
15061507
@@ -1751,6 +1752,40 @@ func TestGenericConnectionState_PostgreSQL(t *testing.T) {
17511752 }
17521753}
17531754
1755+ func TestDmlBatchReturnsBatchUpdateCountsOutsideTransaction(t *testing.T) {
1756+ t.Parallel()
1757+ db, _, teardown := setupTestDBConnection(t)
1758+ defer teardown()
1759+ ctx := context.Background()
1760+
1761+ conn, err := db.Conn(ctx)
1762+ if err != nil {
1763+ t.Fatal(err)
1764+ }
1765+ defer silentClose(conn)
1766+
1767+ if _, err := conn.ExecContext(ctx, "start batch dml"); err != nil {
1768+ t.Fatal(err)
1769+ }
1770+ _, _ = conn.ExecContext(ctx, testutil.UpdateBarSetFoo)
1771+ _, _ = conn.ExecContext(ctx, testutil.UpdateSingersSetLastName)
1772+ var res SpannerResult
1773+ if err := conn.Raw(func(driverConn interface{}) error {
1774+ spannerConn, _ := driverConn.(SpannerConn)
1775+ res, err = spannerConn.RunDmlBatch(ctx)
1776+ return err
1777+ }); err != nil {
1778+ t.Fatal(err)
1779+ }
1780+ results, err := res.BatchRowsAffected()
1781+ if err != nil {
1782+ t.Fatal(err)
1783+ }
1784+ if g, w := results, []int64{testutil.UpdateBarSetFooRowCount, testutil.UpdateSingersSetLastNameRowCount}; !reflect.DeepEqual(g, w) {
1785+ t.Fatalf("batch affected mismatch\n Got: %v\nWant: %v", g, w)
1786+ }
1787+ }
1788+
17541789func verifyConnectionPropertyValue[T comparable](t *testing.T, c *sql.Conn, name string, value T) {
17551790 ctx := context.Background()
17561791 row := c.QueryRowContext(ctx, getShowStatement(c)+name)
0 commit comments