@@ -2755,6 +2755,133 @@ func TestAutocommitBatchDml(t *testing.T) {
27552755 }
27562756}
27572757
2758+ func TestExecuteBatchDml (t * testing.T ) {
2759+ t .Parallel ()
2760+
2761+ ctx := context .Background ()
2762+ db , server , teardown := setupTestDBConnection (t )
2763+ defer teardown ()
2764+
2765+ _ = server .TestSpanner .PutStatementResult ("INSERT INTO Foo (Id, Val) VALUES (1, 'One')" , & testutil.StatementResult {
2766+ Type : testutil .StatementResultUpdateCount ,
2767+ UpdateCount : 1 ,
2768+ })
2769+ _ = server .TestSpanner .PutStatementResult ("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')" , & testutil.StatementResult {
2770+ Type : testutil .StatementResultUpdateCount ,
2771+ UpdateCount : 1 ,
2772+ })
2773+
2774+ res , err := ExecuteBatchDml (ctx , db , func (ctx context.Context , batch DmlBatch ) error {
2775+ if err := batch .ExecContext (ctx , "INSERT INTO Foo (Id, Val) VALUES (1, 'One')" ); err != nil {
2776+ return err
2777+ }
2778+ if err := batch .ExecContext (ctx , "INSERT INTO Foo (Id, Val) VALUES (2, 'Two')" ); err != nil {
2779+ return err
2780+ }
2781+ return nil
2782+ })
2783+ if err != nil {
2784+ t .Fatalf ("failed to execute dml batch: %v" , err )
2785+ }
2786+ affected , err := res .RowsAffected ()
2787+ if err != nil {
2788+ t .Fatalf ("could not get rows affected from batch: %v" , err )
2789+ }
2790+ if g , w := affected , int64 (2 ); g != w {
2791+ t .Fatalf ("affected rows mismatch\n Got: %v\n Want: %v" , g , w )
2792+ }
2793+ batchAffected , err := res .BatchRowsAffected ()
2794+ if err != nil {
2795+ t .Fatalf ("could not get batch rows affected from batch: %v" , err )
2796+ }
2797+ if g , w := batchAffected , []int64 {1 , 1 }; ! cmp .Equal (g , w ) {
2798+ t .Fatalf ("affected batch rows mismatch\n Got: %v\n Want: %v" , g , w )
2799+ }
2800+
2801+ requests := drainRequestsFromServer (server .TestSpanner )
2802+ // There should be no ExecuteSqlRequests on the server.
2803+ sqlRequests := requestsOfType (requests , reflect .TypeOf (& sppb.ExecuteSqlRequest {}))
2804+ if g , w := len (sqlRequests ), 0 ; g != w {
2805+ t .Fatalf ("sql requests count mismatch\n Got: %v\n Want: %v" , g , w )
2806+ }
2807+ batchRequests := requestsOfType (requests , reflect .TypeOf (& sppb.ExecuteBatchDmlRequest {}))
2808+ if g , w := len (batchRequests ), 1 ; g != w {
2809+ t .Fatalf ("BatchDML requests count mismatch\n Got: %v\n Want: %v" , g , w )
2810+ }
2811+ if ! batchRequests [0 ].(* sppb.ExecuteBatchDmlRequest ).LastStatements {
2812+ t .Fatal ("last statements flag not set" )
2813+ }
2814+ // The transaction should have been committed.
2815+ commitRequests := requestsOfType (requests , reflect .TypeOf (& sppb.CommitRequest {}))
2816+ if g , w := len (commitRequests ), 1 ; g != w {
2817+ t .Fatalf ("Commit requests count mismatch\n Got: %v\n Want: %v" , g , w )
2818+ }
2819+ }
2820+
2821+ func TestExecuteBatchDmlError (t * testing.T ) {
2822+ t .Parallel ()
2823+
2824+ ctx := context .Background ()
2825+ db , server , teardown := setupTestDBConnection (t )
2826+ defer teardown ()
2827+
2828+ _ = server .TestSpanner .PutStatementResult ("INSERT INTO Foo (Id, Val) VALUES (1, 'One')" , & testutil.StatementResult {
2829+ Type : testutil .StatementResultUpdateCount ,
2830+ UpdateCount : 1 ,
2831+ })
2832+ c , err := db .Conn (ctx )
2833+ defer func () { _ = c .Close () }()
2834+ if err != nil {
2835+ t .Fatalf ("failed to obtain connection: %v" , err )
2836+ }
2837+
2838+ _ , err = ExecuteBatchDmlOnConn (ctx , c , func (ctx context.Context , batch DmlBatch ) error {
2839+ if err := batch .ExecContext (ctx , "INSERT INTO Foo (Id, Val) VALUES (1, 'One')" ); err != nil {
2840+ return err
2841+ }
2842+ return fmt .Errorf ("test error" )
2843+ })
2844+ if err == nil {
2845+ t .Fatalf ("failed to execute dml batch: %v" , err )
2846+ }
2847+
2848+ requests := drainRequestsFromServer (server .TestSpanner )
2849+ // There should be no requests on the server.
2850+ batchRequests := requestsOfType (requests , reflect .TypeOf (& sppb.ExecuteBatchDmlRequest {}))
2851+ if g , w := len (batchRequests ), 0 ; g != w {
2852+ t .Fatalf ("BatchDML requests count mismatch\n Got: %v\n Want: %v" , g , w )
2853+ }
2854+ commitRequests := requestsOfType (requests , reflect .TypeOf (& sppb.CommitRequest {}))
2855+ if g , w := len (commitRequests ), 0 ; g != w {
2856+ t .Fatalf ("Commit requests count mismatch\n Got: %v\n Want: %v" , g , w )
2857+ }
2858+
2859+ // Verify that the connection is not in a batch, and that it can be used for other statements.
2860+ if err := c .Raw (func (driverConn any ) error {
2861+ spannerConn , ok := driverConn .(SpannerConn )
2862+ if ! ok {
2863+ return fmt .Errorf ("driver connection is not a SpannerConn" )
2864+ }
2865+ if spannerConn .InDMLBatch () {
2866+ return fmt .Errorf ("connection is still in a batch" )
2867+ }
2868+ return nil
2869+ }); err != nil {
2870+ t .Fatalf ("check if connection is in a batch failed: %v" , err )
2871+ }
2872+ res , err := c .ExecContext (ctx , `INSERT INTO Foo (Id, Val) VALUES (1, 'One')` )
2873+ if err != nil {
2874+ t .Fatalf ("failed to execute dml statement: %v" , err )
2875+ }
2876+ if affected , err := res .RowsAffected (); err != nil {
2877+ t .Fatalf ("failed to obtain rows affected: %v" , err )
2878+ } else {
2879+ if g , w := affected , int64 (1 ); g != w {
2880+ t .Fatalf ("affected rows mismatch\n Got: %v\n Want: %v" , g , w )
2881+ }
2882+ }
2883+ }
2884+
27582885func TestTransactionBatchDml (t * testing.T ) {
27592886 t .Parallel ()
27602887
@@ -2872,6 +2999,80 @@ func TestTransactionBatchDml(t *testing.T) {
28722999 }
28733000}
28743001
3002+ func TestExecuteBatchDmlTransaction (t * testing.T ) {
3003+ t .Parallel ()
3004+
3005+ ctx := context .Background ()
3006+ db , server , teardown := setupTestDBConnection (t )
3007+ defer teardown ()
3008+
3009+ _ = server .TestSpanner .PutStatementResult ("INSERT INTO Foo (Id, Val) VALUES (1, 'One')" , & testutil.StatementResult {
3010+ Type : testutil .StatementResultUpdateCount ,
3011+ UpdateCount : 1 ,
3012+ })
3013+ _ = server .TestSpanner .PutStatementResult ("INSERT INTO Foo (Id, Val) VALUES (2, 'Two')" , & testutil.StatementResult {
3014+ Type : testutil .StatementResultUpdateCount ,
3015+ UpdateCount : 1 ,
3016+ })
3017+
3018+ conn , err := db .Conn (ctx )
3019+ if err != nil {
3020+ t .Fatalf ("failed to obtain connection: %v" , err )
3021+ }
3022+ tx , err := conn .BeginTx (ctx , & sql.TxOptions {})
3023+ if err != nil {
3024+ t .Fatalf ("failed to begin transaction: %v" , err )
3025+ }
3026+ res , err := ExecuteBatchDmlOnConn (ctx , conn , func (ctx context.Context , batch DmlBatch ) error {
3027+ if err := batch .ExecContext (ctx , "INSERT INTO Foo (Id, Val) VALUES (1, 'One')" ); err != nil {
3028+ return err
3029+ }
3030+ if err := batch .ExecContext (ctx , "INSERT INTO Foo (Id, Val) VALUES (2, 'Two')" ); err != nil {
3031+ return err
3032+ }
3033+ return nil
3034+ })
3035+ if err != nil {
3036+ t .Fatalf ("failed to execute dml batch: %v" , err )
3037+ }
3038+ affected , err := res .RowsAffected ()
3039+ if err != nil {
3040+ t .Fatalf ("could not get rows affected from batch: %v" , err )
3041+ }
3042+ if g , w := affected , int64 (2 ); g != w {
3043+ t .Fatalf ("affected rows mismatch\n Got: %v\n Want: %v" , g , w )
3044+ }
3045+ batchAffected , err := res .BatchRowsAffected ()
3046+ if err != nil {
3047+ t .Fatalf ("could not get batch rows affected from batch: %v" , err )
3048+ }
3049+ if g , w := batchAffected , []int64 {1 , 1 }; ! cmp .Equal (g , w ) {
3050+ t .Fatalf ("affected batch rows mismatch\n Got: %v\n Want: %v" , g , w )
3051+ }
3052+
3053+ if err := tx .Commit (); err != nil {
3054+ t .Fatalf ("failed to commit transaction after batch: %v" , err )
3055+ }
3056+
3057+ requests := drainRequestsFromServer (server .TestSpanner )
3058+ // There should be no ExecuteSqlRequests on the server.
3059+ sqlRequests := requestsOfType (requests , reflect .TypeOf (& sppb.ExecuteSqlRequest {}))
3060+ if g , w := len (sqlRequests ), 0 ; g != w {
3061+ t .Fatalf ("sql requests count mismatch\n Got: %v\n Want: %v" , g , w )
3062+ }
3063+ batchRequests := requestsOfType (requests , reflect .TypeOf (& sppb.ExecuteBatchDmlRequest {}))
3064+ if g , w := len (batchRequests ), 1 ; g != w {
3065+ t .Fatalf ("BatchDML requests count mismatch\n Got: %v\n Want: %v" , g , w )
3066+ }
3067+ if batchRequests [0 ].(* sppb.ExecuteBatchDmlRequest ).LastStatements {
3068+ t .Fatal ("last statements flag was set, this should not happen for batches in a transaction" )
3069+ }
3070+ commitRequests := requestsOfType (requests , reflect .TypeOf (& sppb.CommitRequest {}))
3071+ if g , w := len (commitRequests ), 1 ; g != w {
3072+ t .Fatalf ("Commit requests count mismatch\n Got: %v\n Want: %v" , g , w )
3073+ }
3074+ }
3075+
28753076func TestCommitTimestamp (t * testing.T ) {
28763077 t .Parallel ()
28773078
0 commit comments