@@ -21,7 +21,12 @@ import {EventEmitter} from 'events';
2121import * as assert from 'assert' ;
2222import * as extend from 'extend' ;
2323import { google } from '../protos/protos' ;
24- import { CommitCallback , CommitOptions , MutationSet } from '../src/transaction' ;
24+ import {
25+ BatchWriteOptions ,
26+ CommitCallback ,
27+ CommitOptions ,
28+ MutationSet ,
29+ } from '../src/transaction' ;
2530import { util } from '@google-cloud/common' ;
2631import { Transform } from 'stream' ;
2732import * as proxyquire from 'proxyquire' ;
@@ -35,7 +40,7 @@ const {
3540// eslint-disable-next-line n/no-extraneous-require
3641const { SimpleSpanProcessor} = require ( '@opentelemetry/sdk-trace-base' ) ;
3742import * as db from '../src/database' ;
38- import { Instance , Spanner } from '../src' ;
43+ import { Instance , MutationGroup , Spanner } from '../src' ;
3944import * as pfy from '@google-cloud/promisify' ;
4045import { grpc } from 'google-gax' ;
4146import { MockError } from '../test/mockserver/mockspanner' ;
@@ -1215,6 +1220,224 @@ describe('Database', () => {
12151220 } ) ;
12161221 } ) ;
12171222
1223+ describe ( 'batchWriteAtLeastOnce' , ( ) => {
1224+ const mutationGroup1 = new MutationGroup ( ) ;
1225+ mutationGroup1 . insert ( 'MyTable' , {
1226+ Key : 'ks1' ,
1227+ Thing : 'abc' ,
1228+ } ) ;
1229+ const mutationGroup2 = new MutationGroup ( ) ;
1230+ mutationGroup2 . insert ( 'MyTable' , {
1231+ Key : 'ks2' ,
1232+ Thing : 'xyz' ,
1233+ } ) ;
1234+
1235+ const mutationGroups = [ mutationGroup1 , mutationGroup2 ] ;
1236+
1237+ let fakePool : FakeSessionPool ;
1238+ let fakeSession : FakeSession ;
1239+ let fakeDataStream : Transform ;
1240+ let getSessionStub : sinon . SinonStub ;
1241+ let requestStreamStub : sinon . SinonStub ;
1242+
1243+ const options = {
1244+ requestOptions : {
1245+ transactionTag : 'batch-write-tag' ,
1246+ } ,
1247+ excludeTxnFromChangeStream : true ,
1248+ gaxOptions : { autoPaginate : false } ,
1249+ } as BatchWriteOptions ;
1250+
1251+ beforeEach ( ( ) => {
1252+ fakePool = database . pool_ ;
1253+ fakeSession = new FakeSession ( ) ;
1254+ fakeDataStream = through . obj ( ) ;
1255+
1256+ getSessionStub = (
1257+ sandbox . stub ( fakePool , 'getSession' ) as sinon . SinonStub
1258+ ) . callsFake ( callback => callback ( null , fakeSession ) ) ;
1259+
1260+ requestStreamStub = sandbox
1261+ . stub ( database , 'requestStream' )
1262+ . returns ( fakeDataStream ) ;
1263+ } ) ;
1264+
1265+ it ( 'on retry with "Session not found" error' , done => {
1266+ const sessionNotFoundError = {
1267+ code : grpc . status . NOT_FOUND ,
1268+ message : 'Session not found' ,
1269+ } as grpc . ServiceError ;
1270+ let retryCount = 0 ;
1271+
1272+ database
1273+ . batchWriteAtLeastOnce ( mutationGroups , options )
1274+ . on ( 'data' , ( ) => { } )
1275+ . on ( 'error' , err => {
1276+ assert . fail ( err ) ;
1277+ } )
1278+ . on ( 'end' , ( ) => {
1279+ assert . strictEqual ( retryCount , 1 ) ;
1280+
1281+ const spans = traceExporter . getFinishedSpans ( ) ;
1282+ withAllSpansHaveDBName ( spans ) ;
1283+
1284+ const actualSpanNames : string [ ] = [ ] ;
1285+ const actualEventNames : string [ ] = [ ] ;
1286+ spans . forEach ( span => {
1287+ actualSpanNames . push ( span . name ) ;
1288+ span . events . forEach ( event => {
1289+ actualEventNames . push ( event . name ) ;
1290+ } ) ;
1291+ } ) ;
1292+
1293+ const expectedSpanNames = [
1294+ 'CloudSpanner.Database.batchWriteAtLeastOnce' ,
1295+ 'CloudSpanner.Database.batchWriteAtLeastOnce' ,
1296+ ] ;
1297+ assert . deepStrictEqual (
1298+ actualSpanNames ,
1299+ expectedSpanNames ,
1300+ `span names mismatch:\n\tGot: ${ actualSpanNames } \n\tWant: ${ expectedSpanNames } `
1301+ ) ;
1302+
1303+ // Ensure that the span actually produced an error that was recorded.
1304+ const firstSpan = spans [ 0 ] ;
1305+ assert . strictEqual (
1306+ SpanStatusCode . ERROR ,
1307+ firstSpan . status . code ,
1308+ 'Expected an ERROR span status'
1309+ ) ;
1310+
1311+ const errorMessage = firstSpan . status . message ;
1312+ assert . deepStrictEqual (
1313+ firstSpan . status . message ,
1314+ sessionNotFoundError . message
1315+ ) ;
1316+
1317+ // The last span should not have an error status.
1318+ const lastSpan = spans [ spans . length - 1 ] ;
1319+ assert . strictEqual (
1320+ SpanStatusCode . UNSET ,
1321+ lastSpan . status . code ,
1322+ 'Unexpected span status'
1323+ ) ;
1324+
1325+ assert . deepStrictEqual ( lastSpan . status . message , undefined ) ;
1326+
1327+ const expectedEventNames = [
1328+ 'Using Session' ,
1329+ 'No session available' ,
1330+ 'Using Session' ,
1331+ ] ;
1332+ assert . deepStrictEqual ( actualEventNames , expectedEventNames ) ;
1333+
1334+ done ( ) ;
1335+ } ) ;
1336+
1337+ fakeDataStream . emit ( 'error' , sessionNotFoundError ) ;
1338+ retryCount ++ ;
1339+ } ) ;
1340+
1341+ it ( 'on getSession errors' , done => {
1342+ const fakeError = new Error ( 'err' ) ;
1343+
1344+ getSessionStub . callsFake ( callback => callback ( fakeError ) ) ;
1345+ database
1346+ . batchWriteAtLeastOnce ( mutationGroups , options )
1347+ . on ( 'error' , err => {
1348+ assert . strictEqual ( err , fakeError ) ;
1349+
1350+ const spans = traceExporter . getFinishedSpans ( ) ;
1351+ withAllSpansHaveDBName ( spans ) ;
1352+
1353+ const actualSpanNames : string [ ] = [ ] ;
1354+ const actualEventNames : string [ ] = [ ] ;
1355+ spans . forEach ( span => {
1356+ actualSpanNames . push ( span . name ) ;
1357+ span . events . forEach ( event => {
1358+ actualEventNames . push ( event . name ) ;
1359+ } ) ;
1360+ } ) ;
1361+
1362+ const expectedSpanNames = [
1363+ 'CloudSpanner.Database.batchWriteAtLeastOnce' ,
1364+ ] ;
1365+ assert . deepStrictEqual (
1366+ actualSpanNames ,
1367+ expectedSpanNames ,
1368+ `span names mismatch:\n\tGot: ${ actualSpanNames } \n\tWant: ${ expectedSpanNames } `
1369+ ) ;
1370+
1371+ // Ensure that the span actually produced an error that was recorded.
1372+ const firstSpan = spans [ 0 ] ;
1373+ assert . strictEqual (
1374+ SpanStatusCode . ERROR ,
1375+ firstSpan . status . code ,
1376+ 'Expected an ERROR span status'
1377+ ) ;
1378+
1379+ assert . deepStrictEqual ( firstSpan . status . message , fakeError . message ) ;
1380+
1381+ const expectedEventNames = [ ] ;
1382+ assert . deepStrictEqual ( expectedEventNames , actualEventNames ) ;
1383+
1384+ done ( ) ;
1385+ } ) ;
1386+ } ) ;
1387+
1388+ it ( 'with no errors' , done => {
1389+ getSessionStub . callsFake ( callback => callback ( null , { } ) ) ;
1390+ database
1391+ . batchWriteAtLeastOnce ( mutationGroups , options )
1392+ . on ( 'data' , ( ) => { } )
1393+ . on ( 'error' , assert . ifError )
1394+ . on ( 'end' , ( ) => {
1395+ const spans = traceExporter . getFinishedSpans ( ) ;
1396+ withAllSpansHaveDBName ( spans ) ;
1397+
1398+ const actualSpanNames : string [ ] = [ ] ;
1399+ const actualEventNames : string [ ] = [ ] ;
1400+ spans . forEach ( span => {
1401+ actualSpanNames . push ( span . name ) ;
1402+ span . events . forEach ( event => {
1403+ actualEventNames . push ( event . name ) ;
1404+ } ) ;
1405+ } ) ;
1406+
1407+ const expectedSpanNames = [
1408+ 'CloudSpanner.Database.batchWriteAtLeastOnce' ,
1409+ ] ;
1410+ assert . deepStrictEqual (
1411+ actualSpanNames ,
1412+ expectedSpanNames ,
1413+ `span names mismatch:\n\tGot: ${ actualSpanNames } \n\tWant: ${ expectedSpanNames } `
1414+ ) ;
1415+
1416+ // Ensure that the span actually produced an error that was recorded.
1417+ const firstSpan = spans [ 0 ] ;
1418+ assert . strictEqual (
1419+ SpanStatusCode . UNSET ,
1420+ firstSpan . status . code ,
1421+ 'Unexpected span status code'
1422+ ) ;
1423+
1424+ assert . strictEqual (
1425+ undefined ,
1426+ firstSpan . status . message ,
1427+ 'Unexpected span status message'
1428+ ) ;
1429+
1430+ const expectedEventNames = [ 'Using Session' ] ;
1431+ assert . deepStrictEqual ( actualEventNames , expectedEventNames ) ;
1432+
1433+ done ( ) ;
1434+ } ) ;
1435+
1436+ fakeDataStream . emit ( 'data' , 'response' ) ;
1437+ fakeDataStream . end ( 'end' ) ;
1438+ } ) ;
1439+ } ) ;
1440+
12181441 describe ( 'runTransaction' , ( ) => {
12191442 const SESSION = new FakeSession ( ) ;
12201443 const TRANSACTION = new FakeTransaction (
0 commit comments