Skip to content

Commit e285310

Browse files
authored
fix: include mutations from original attempt in retry (#463)
Mutations that had been buffered during a transaction would not be included in internal retries if the original transaction was aborted.
1 parent e001cd1 commit e285310

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

aborted_transactions_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,51 @@ func TestCommitAborted(t *testing.T) {
6060
}
6161
}
6262

63+
func TestCommitWithMutationsAborted(t *testing.T) {
64+
t.Parallel()
65+
66+
db, server, teardown := setupTestDBConnectionWithParams(t, "minSessions=1;maxSessions=1")
67+
defer teardown()
68+
ctx := context.Background()
69+
70+
conn, err := db.Conn(ctx)
71+
if err != nil {
72+
t.Fatalf("failed to open connection: %v", err)
73+
}
74+
defer func() { _ = conn.Close() }()
75+
76+
tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
77+
if err != nil {
78+
t.Fatalf("begin failed: %v", err)
79+
}
80+
if err := conn.Raw(func(driverConn interface{}) error {
81+
spannerConn, _ := driverConn.(SpannerConn)
82+
mutation := spanner.Insert("foo", []string{}, []interface{}{})
83+
return spannerConn.BufferWrite([]*spanner.Mutation{mutation})
84+
}); err != nil {
85+
t.Fatalf("failed to buffer mutations: %v", err)
86+
}
87+
// Abort the transaction on the first commit attempt.
88+
server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
89+
Errors: []error{status.Error(codes.Aborted, "Aborted")},
90+
})
91+
err = tx.Commit()
92+
if err != nil {
93+
t.Fatalf("commit failed: %v", err)
94+
}
95+
reqs := drainRequestsFromServer(server.TestSpanner)
96+
commitReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.CommitRequest{}))
97+
if g, w := len(commitReqs), 2; g != w {
98+
t.Fatalf("commit request count mismatch\nGot: %v\nWant: %v", g, w)
99+
}
100+
for _, req := range commitReqs {
101+
commitReq := req.(*sppb.CommitRequest)
102+
if g, w := len(commitReq.Mutations), 1; g != w {
103+
t.Fatalf("mutation count mismatch\n Got: %v\nWant: %v", g, w)
104+
}
105+
}
106+
}
107+
63108
func TestCommitAbortedWithInternalRetriesDisabled(t *testing.T) {
64109
t.Parallel()
65110

transaction.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ type readWriteTransaction struct {
233233
// transaction so far. These statements will be replayed on a new read write
234234
// transaction if the initial attempt is aborted.
235235
statements []retriableStatement
236+
237+
// mutations contains the buffered mutations of this transaction. These are
238+
// added to the next transaction if the transaction executes an internal retry.
239+
mutations []*spanner.Mutation
236240
}
237241

238242
// retriableStatement is the interface that is used to keep track of statements
@@ -364,6 +368,10 @@ func (tx *readWriteTransaction) retry(ctx context.Context) (err error) {
364368
tx.logger.Log(ctx, LevelNotice, "failed to reset transaction")
365369
return err
366370
}
371+
// Re-apply the mutations from the previous transaction.
372+
if err := tx.rwTx.BufferWrite(tx.mutations); err != nil {
373+
return err
374+
}
367375
for _, stmt := range tx.statements {
368376
tx.logger.Log(ctx, slog.LevelDebug, "retrying statement", "stmt", stmt)
369377
err = stmt.retry(ctx, tx.rwTx)
@@ -600,6 +608,7 @@ func (tx *readWriteTransaction) runDmlBatch(ctx context.Context) (*result, error
600608
}
601609

602610
func (tx *readWriteTransaction) BufferWrite(ms []*spanner.Mutation) error {
611+
tx.mutations = append(tx.mutations, ms...)
603612
return tx.rwTx.BufferWrite(ms)
604613
}
605614

0 commit comments

Comments
 (0)