@@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase {
4242 taskGroup. cancelAll ( )
4343 }
4444 }
45+
46+ func testTransaction( ) async throws {
47+ var mlogger = Logger ( label: " test " )
48+ mlogger. logLevel = . debug
49+ let logger = mlogger
50+ let eventLoopGroup = MultiThreadedEventLoopGroup ( numberOfThreads: 8 )
51+ self . addTeardownBlock {
52+ try await eventLoopGroup. shutdownGracefully ( )
53+ }
54+
55+ let tableName = " test_client_transactions "
56+
57+ let clientConfig = PostgresClient . Configuration. makeTestConfiguration ( )
58+ let client = PostgresClient ( configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)
59+
60+ do {
61+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
62+ taskGroup. addTask {
63+ await client. run ( )
64+ }
65+
66+ try await client. query (
67+ """
68+ CREATE TABLE IF NOT EXISTS " \( unescaped: tableName) " (
69+ id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
70+ uuid UUID NOT NULL
71+ );
72+ """ ,
73+ logger: logger
74+ )
75+
76+ let iterations = 1000
77+
78+ for _ in 0 ..< iterations {
79+ taskGroup. addTask {
80+ let _ = try await client. withTransaction { transaction in
81+ try await transaction. query (
82+ """
83+ INSERT INTO " \( unescaped: tableName) " (uuid) VALUES ( \( UUID ( ) ) );
84+ """ ,
85+ logger: logger
86+ )
87+ }
88+ }
89+ }
90+
91+ for _ in 0 ..< iterations {
92+ _ = await taskGroup. nextResult ( ) !
93+ }
94+
95+ let rows = try await client. query ( #"SELECT COUNT(1)::INT AS table_size FROM " \#( unescaped: tableName) ";"# , logger: logger) . decode ( Int . self)
96+ for try await (count) in rows {
97+ XCTAssertEqual ( count, iterations)
98+ }
99+
100+ /// Test roll back
101+ taskGroup. addTask {
102+
103+ do {
104+ let _ = try await client. withTransaction { transaction in
105+ /// insert valid data
106+ try await transaction. query (
107+ """
108+ INSERT INTO " \( unescaped: tableName) " (uuid) VALUES ( \( UUID ( ) ) );
109+ """ ,
110+ logger: logger
111+ )
112+
113+ /// insert invalid data
114+ try await transaction. query (
115+ """
116+ INSERT INTO " \( unescaped: tableName) " (uuid) VALUES ( \( iterations) );
117+ """ ,
118+ logger: logger
119+ )
120+ }
121+ } catch {
122+ XCTAssertNotNil ( error)
123+ guard let error = error as? PSQLError else { return XCTFail ( " Unexpected error type " ) }
124+
125+ XCTAssertEqual ( error. code, . server)
126+ XCTAssertEqual ( error. serverInfo ? [ . severity] , " ERROR " )
127+ }
128+ }
129+
130+ let row = try await client. query ( #"SELECT COUNT(1)::INT AS table_size FROM " \#( unescaped: tableName) ";"# , logger: logger) . decode ( Int . self)
131+
132+ for try await (count) in row {
133+ XCTAssertEqual ( count, iterations)
134+ }
135+
136+ try await client. query (
137+ """
138+ DROP TABLE " \( unescaped: tableName) " ;
139+ """ ,
140+ logger: logger
141+ )
142+
143+ taskGroup. cancelAll ( )
144+ }
145+ } catch {
146+ XCTFail ( " Unexpected error: \( String ( reflecting: error) ) " )
147+ }
148+ }
45149
46150 func testApplicationNameIsForwardedCorrectly( ) async throws {
47151 var mlogger = Logger ( label: " test " )
0 commit comments