Skip to content

Commit 72f2170

Browse files
committed
Update PSQLRowStream.swift
1 parent 9cc2bd7 commit 72f2170

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,17 @@ final class PSQLRowStream: @unchecked Sendable {
3232
case empty
3333
case failure(Error)
3434
}
35-
35+
36+
private enum Consumed {
37+
case tag(String)
38+
case emptyResponse
39+
}
40+
3641
private enum DownstreamState {
3742
case waitingForConsumer(BufferState)
3843
case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
3944
case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource)
40-
case consumed(Result<String, Error>)
41-
case finished
45+
case consumed(Result<Consumed, Error>)
4246
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ())
4347
}
4448

@@ -108,13 +112,13 @@ final class PSQLRowStream: @unchecked Sendable {
108112
case .empty:
109113
source.finish()
110114
onFinish()
111-
self.downstreamState = .finished
115+
self.downstreamState = .consumed(.success(.emptyResponse))
112116

113117
case .finished(let buffer, let commandTag):
114118
_ = source.yield(contentsOf: buffer)
115119
source.finish()
116120
onFinish()
117-
self.downstreamState = .consumed(.success(commandTag))
121+
self.downstreamState = .consumed(.success(.tag(commandTag)))
118122

119123
case .failure(let error):
120124
source.finish(error)
@@ -139,7 +143,7 @@ final class PSQLRowStream: @unchecked Sendable {
139143
case .waitingForConsumer, .iteratingRows, .waitingForAll:
140144
preconditionFailure("Invalid state: \(self.downstreamState)")
141145

142-
case .consumed, .finished:
146+
case .consumed:
143147
break
144148

145149
case .asyncSequence(_, let dataSource, _):
@@ -164,7 +168,7 @@ final class PSQLRowStream: @unchecked Sendable {
164168
dataSource.cancel(for: self)
165169
onFinish()
166170

167-
case .consumed, .finished:
171+
case .consumed:
168172
return
169173

170174
case .waitingForConsumer, .iteratingRows, .waitingForAll:
@@ -207,15 +211,15 @@ final class PSQLRowStream: @unchecked Sendable {
207211
PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription)
208212
}
209213

210-
self.downstreamState = .consumed(.success(commandTag))
214+
self.downstreamState = .consumed(.success(.tag(commandTag)))
211215
return self.eventLoop.makeSucceededFuture(rows)
212216

213217
case .failure(let error):
214218
self.downstreamState = .consumed(.failure(error))
215219
return self.eventLoop.makeFailedFuture(error)
216220

217221
case .empty:
218-
self.downstreamState = .finished
222+
self.downstreamState = .consumed(.success(.emptyResponse))
219223
return self.eventLoop.makeSucceededFuture([])
220224
}
221225
}
@@ -265,7 +269,7 @@ final class PSQLRowStream: @unchecked Sendable {
265269
return promise.futureResult
266270

267271
case .empty:
268-
self.downstreamState = .finished
272+
self.downstreamState = .consumed(.success(.emptyResponse))
269273
return self.eventLoop.makeSucceededVoidFuture()
270274

271275
case .finished(let buffer, let commandTag):
@@ -279,7 +283,7 @@ final class PSQLRowStream: @unchecked Sendable {
279283
try onRow(row)
280284
}
281285

282-
self.downstreamState = .consumed(.success(commandTag))
286+
self.downstreamState = .consumed(.success(.tag(commandTag)))
283287
return self.eventLoop.makeSucceededVoidFuture()
284288
} catch {
285289
self.downstreamState = .consumed(.failure(error))
@@ -350,9 +354,6 @@ final class PSQLRowStream: @unchecked Sendable {
350354

351355
case .consumed(.failure):
352356
break
353-
354-
case .finished:
355-
preconditionFailure("How can we receive further rows, if we are supposed to be done")
356357
}
357358
}
358359

@@ -376,22 +377,22 @@ final class PSQLRowStream: @unchecked Sendable {
376377
preconditionFailure("How can we get another end, if an end was already signalled?")
377378

378379
case .iteratingRows(_, let promise, _):
379-
self.downstreamState = .consumed(.success(commandTag))
380+
self.downstreamState = .consumed(.success(.tag(commandTag)))
380381
promise.succeed(())
381382

382383
case .waitingForAll(let rows, let promise, _):
383-
self.downstreamState = .consumed(.success(commandTag))
384+
self.downstreamState = .consumed(.success(.tag(commandTag)))
384385
promise.succeed(rows)
385386

386387
case .asyncSequence(let source, _, let onFinish):
387-
self.downstreamState = .consumed(.success(commandTag))
388+
self.downstreamState = .consumed(.success(.tag(commandTag)))
388389
source.finish()
389390
onFinish()
390391

391-
case .consumed:
392+
case .consumed(.success(.tag)), .consumed(.failure):
392393
break
393394

394-
case .finished, .waitingForConsumer(.empty):
395+
case .consumed(.success(.emptyResponse)), .waitingForConsumer(.empty):
395396
preconditionFailure("How can we get an end for empty query response?")
396397
}
397398
}
@@ -417,10 +418,10 @@ final class PSQLRowStream: @unchecked Sendable {
417418
consumer.finish(error)
418419
onFinish()
419420

420-
case .consumed:
421+
case .consumed(.success(.tag)), .consumed(.failure):
421422
break
422423

423-
case .finished:
424+
case .consumed(.success(.emptyResponse)):
424425
preconditionFailure("How can we get an error for empty query response?")
425426
}
426427
}
@@ -442,13 +443,14 @@ final class PSQLRowStream: @unchecked Sendable {
442443
}
443444

444445
var commandTag: String {
445-
switch self.downstreamState {
446-
case .consumed(.success(let commandTag)):
447-
return commandTag
448-
case .finished:
446+
guard case .consumed(.success(let consumed)) = self.downstreamState else {
447+
preconditionFailure("commandTag may only be called if all rows have been consumed")
448+
}
449+
switch consumed {
450+
case .tag(let tag):
451+
return tag
452+
case .emptyResponse:
449453
return ""
450-
default:
451-
preconditionFailure("commandTag may only be called if there are no more rows to be consumed")
452454
}
453455
}
454456
}

0 commit comments

Comments
 (0)