Skip to content

Commit 1ce8e38

Browse files
authored
Merge pull request groue#1747 from groue/dev/issue-1746
Transaction observers are not impacted by Task Cancellation
2 parents f2aa6b3 + 67b34dc commit 1ce8e38

File tree

4 files changed

+104
-12
lines changed

4 files changed

+104
-12
lines changed

GRDB/Core/Database+Statements.swift

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,9 @@ extension Database {
494494

495495
switch ResultCode(rawValue: resultCode) {
496496
case .SQLITE_INTERRUPT, .SQLITE_ABORT:
497-
if suspensionMutex.load().isCancelled {
498-
// The only error that a user sees when a Task is cancelled
499-
// is CancellationError.
500-
throw CancellationError()
501-
}
497+
// The only error that a user sees when a Task is cancelled
498+
// is CancellationError.
499+
try suspensionMutex.load().checkCancellation()
502500
default:
503501
break
504502
}

GRDB/Core/Database.swift

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,22 @@ public final class Database: CustomStringConvertible, CustomDebugStringConvertib
341341

342342
/// If true, the database access has been cancelled.
343343
var isCancelled: Bool
344+
345+
/// If true, the database throws an error when it is cancelled.
346+
var interruptsWhenCancelled: Bool
347+
348+
func checkCancellation() throws {
349+
if isCancelled, interruptsWhenCancelled {
350+
throw CancellationError()
351+
}
352+
}
344353
}
345354

346355
/// Support for `checkForSuspensionViolation(from:)`
347-
let suspensionMutex = Mutex(Suspension(isSuspended: false, isCancelled: false))
356+
let suspensionMutex = Mutex(Suspension(
357+
isSuspended: false,
358+
isCancelled: false,
359+
interruptsWhenCancelled: true))
348360

349361
// MARK: - Transaction Date
350362

@@ -1222,7 +1234,7 @@ public final class Database: CustomStringConvertible, CustomDebugStringConvertib
12221234
}
12231235

12241236
suspension.isCancelled = true
1225-
return true
1237+
return suspension.interruptsWhenCancelled
12261238
}
12271239

12281240
if needsInterrupt {
@@ -1237,6 +1249,24 @@ public final class Database: CustomStringConvertible, CustomDebugStringConvertib
12371249
}
12381250
}
12391251

1252+
/// Within the given closure, Task cancellation does not interrupt
1253+
/// database accesses.
1254+
func ignoringCancellation<T>(_ value: () throws -> T) rethrows -> T {
1255+
let previous = suspensionMutex.withLock {
1256+
let previous = $0.interruptsWhenCancelled
1257+
$0.interruptsWhenCancelled = false
1258+
return previous
1259+
}
1260+
1261+
defer {
1262+
suspensionMutex.withLock {
1263+
$0.interruptsWhenCancelled = previous
1264+
}
1265+
}
1266+
1267+
return try value()
1268+
}
1269+
12401270
/// Support for `checkForSuspensionViolation(from:)`
12411271
private func journalMode() throws -> String {
12421272
if let journalMode = journalModeCache {
@@ -1303,7 +1333,7 @@ public final class Database: CustomStringConvertible, CustomDebugStringConvertib
13031333
let interrupt: Interrupt? = try suspensionMutex.withLock { suspension in
13041334
// Check for cancellation first, so that the only error that
13051335
// a user sees when a Task is cancelled is CancellationError.
1306-
if suspension.isCancelled {
1336+
if suspension.isCancelled, suspension.interruptsWhenCancelled {
13071337
return .cancel
13081338
}
13091339

GRDB/Core/TransactionObserver.swift

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,8 +549,12 @@ class DatabaseObservationBroker {
549549
savepointStack.clear()
550550

551551
if !database.isReadOnly {
552-
for observation in transactionObservations {
553-
observation.databaseDidCommit(database)
552+
// Observers must be able to access the database, even if the
553+
// task that has performed the commit is cancelled.
554+
database.ignoringCancellation {
555+
for observation in transactionObservations {
556+
observation.databaseDidCommit(database)
557+
}
554558
}
555559
}
556560

@@ -609,12 +613,18 @@ class DatabaseObservationBroker {
609613

610614
// Called from statementDidExecute or statementDidFail
611615
private func databaseDidRollback(notifyTransactionObservers: Bool) {
616+
#warning("TODO")
612617
savepointStack.clear()
613618

614619
if notifyTransactionObservers {
615620
assert(!database.isReadOnly, "Read-only transactions are not notified")
616-
for observation in transactionObservations {
617-
observation.databaseDidRollback(database)
621+
622+
// Observers must be able to access the database, even if the
623+
// task that has performed the commit is cancelled.
624+
database.ignoringCancellation {
625+
for observation in transactionObservations {
626+
observation.databaseDidRollback(database)
627+
}
618628
}
619629
}
620630
databaseDidEndTransaction()

Tests/GRDBTests/ValueObservationTests.swift

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,4 +1352,58 @@ class ValueObservationTests: GRDBTestCase {
13521352
XCTAssertEqual(value, true)
13531353
})
13541354
}
1355+
1356+
// Regression test for <https://github.com/groue/GRDB.swift/discussions/1746>
1357+
@MainActor func testIssue1746() async throws {
1358+
let dbQueue = try makeDatabaseQueue()
1359+
try await dbQueue.write { db in
1360+
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY)")
1361+
}
1362+
1363+
// Start a task that waits for writeContinuation before it writes.
1364+
let (writeStream, writeContinuation) = AsyncStream.makeStream(of: Void.self)
1365+
let task = Task {
1366+
for await _ in writeStream { }
1367+
try? await dbQueue.write { db in
1368+
XCTAssertFalse(Task.isCancelled) // Required for the test to be meaningful
1369+
try db.execute(sql: "INSERT INTO test DEFAULT VALUES")
1370+
}
1371+
}
1372+
1373+
// A transaction observer that cancels a Task after commit.
1374+
class CancelObserver: TransactionObserver {
1375+
let task: Task<Void, Never>
1376+
init(task: Task<Void, Never>) {
1377+
self.task = task
1378+
}
1379+
func observes(eventsOfKind eventKind: DatabaseEventKind) -> Bool { true }
1380+
func databaseDidChange(with event: DatabaseEvent) { }
1381+
func databaseDidRollback(_ db: Database) { }
1382+
func databaseDidCommit(_ db: Database) {
1383+
task.cancel()
1384+
}
1385+
}
1386+
1387+
// Register CancelObserver first, so no other observer could access
1388+
// the database before the task is cancelled.
1389+
dbQueue.add(transactionObserver: CancelObserver(task: task), extent: .databaseLifetime)
1390+
1391+
// Start observing.
1392+
// We expect to see 0, then 1.
1393+
let values = ValueObservation
1394+
.tracking(Table("test").fetchCount)
1395+
.values(in: dbQueue)
1396+
for try await value in values {
1397+
if value == 0 {
1398+
// Perform the write.
1399+
writeContinuation.finish()
1400+
} else if value == 1 {
1401+
// Test passes
1402+
break
1403+
} else {
1404+
XCTFail("Unexpected value \(value)")
1405+
break
1406+
}
1407+
}
1408+
}
13551409
}

0 commit comments

Comments
 (0)