@@ -8,6 +8,8 @@ use diesel::QueryResult;
88use scoped_futures:: ScopedBoxFuture ;
99use std:: borrow:: Cow ;
1010use std:: num:: NonZeroU32 ;
11+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
12+ use std:: sync:: Arc ;
1113
1214use crate :: AsyncConnection ;
1315// TODO: refactor this to share more code with diesel
@@ -88,24 +90,31 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
8890 /// in an error state.
8991 #[ doc( hidden) ]
9092 fn is_broken_transaction_manager ( conn : & mut Conn ) -> bool {
91- match Self :: transaction_manager_status_mut ( conn) . transaction_state ( ) {
92- // all transactions are closed
93- // so we don't consider this connection broken
94- Ok ( ValidTransactionManagerStatus {
95- in_transaction : None ,
96- ..
97- } ) => false ,
98- // The transaction manager is in an error state
99- // Therefore we consider this connection broken
100- Err ( _) => true ,
101- // The transaction manager contains a open transaction
102- // we do consider this connection broken
103- // if that transaction was not opened by `begin_test_transaction`
104- Ok ( ValidTransactionManagerStatus {
105- in_transaction : Some ( s) ,
106- ..
107- } ) => !s. test_transaction ,
108- }
93+ check_broken_transaction_state ( conn)
94+ }
95+ }
96+
97+ fn check_broken_transaction_state < Conn > ( conn : & mut Conn ) -> bool
98+ where
99+ Conn : AsyncConnection ,
100+ {
101+ match Conn :: TransactionManager :: transaction_manager_status_mut ( conn) . transaction_state ( ) {
102+ // all transactions are closed
103+ // so we don't consider this connection broken
104+ Ok ( ValidTransactionManagerStatus {
105+ in_transaction : None ,
106+ ..
107+ } ) => false ,
108+ // The transaction manager is in an error state
109+ // Therefore we consider this connection broken
110+ Err ( _) => true ,
111+ // The transaction manager contains a open transaction
112+ // we do consider this connection broken
113+ // if that transaction was not opened by `begin_test_transaction`
114+ Ok ( ValidTransactionManagerStatus {
115+ in_transaction : Some ( s) ,
116+ ..
117+ } ) => !s. test_transaction ,
109118 }
110119}
111120
@@ -114,147 +123,23 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
114123#[ derive( Default , Debug ) ]
115124pub struct AnsiTransactionManager {
116125 pub ( crate ) status : TransactionManagerStatus ,
126+ // this boolean flag tracks whether we are currently in the process
127+ // of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
128+ // if we ever encounter a situation where this flag is set
129+ // while the connection is returned to a pool
130+ // that means the connection is broken as someone dropped the
131+ // transaction future while these commands where executed
132+ // and we cannot know the connection state anymore
133+ //
134+ // We ensure this by wrapping all calls to `.await`
135+ // into `AnsiTransactionManager::critical_transaction_block`
136+ // below
137+ //
138+ // See https://github.com/weiznich/diesel_async/issues/198 for
139+ // details
140+ pub ( crate ) is_broken : Arc < AtomicBool > ,
117141}
118142
119- // /// Status of the transaction manager
120- // #[derive(Debug)]
121- // pub enum TransactionManagerStatus {
122- // /// Valid status, the manager can run operations
123- // Valid(ValidTransactionManagerStatus),
124- // /// Error status, probably following a broken connection. The manager will no longer run operations
125- // InError,
126- // }
127-
128- // impl Default for TransactionManagerStatus {
129- // fn default() -> Self {
130- // TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
131- // }
132- // }
133-
134- // impl TransactionManagerStatus {
135- // /// Returns the transaction depth if the transaction manager's status is valid, or returns
136- // /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
137- // pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
138- // match self {
139- // TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
140- // TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
141- // }
142- // }
143-
144- // /// If in transaction and transaction manager is not broken, registers that the
145- // /// connection can not be used anymore until top-level transaction is rolled back
146- // pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
147- // if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
148- // in_transaction:
149- // Some(InTransactionStatus {
150- // top_level_transaction_requires_rollback,
151- // ..
152- // }),
153- // }) = self
154- // {
155- // *top_level_transaction_requires_rollback = true;
156- // }
157- // }
158-
159- // /// Sets the transaction manager status to InError
160- // ///
161- // /// Subsequent attempts to use transaction-related features will result in a
162- // /// [`Error::BrokenTransactionManager`] error
163- // pub fn set_in_error(&mut self) {
164- // *self = TransactionManagerStatus::InError
165- // }
166-
167- // fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
168- // match self {
169- // TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
170- // TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
171- // }
172- // }
173-
174- // pub(crate) fn set_test_transaction_flag(&mut self) {
175- // if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
176- // in_transaction: Some(s),
177- // }) = self
178- // {
179- // s.test_transaction = true;
180- // }
181- // }
182- // }
183-
184- // /// Valid transaction status for the manager. Can return the current transaction depth
185- // #[allow(missing_copy_implementations)]
186- // #[derive(Debug, Default)]
187- // pub struct ValidTransactionManagerStatus {
188- // in_transaction: Option<InTransactionStatus>,
189- // }
190-
191- // #[allow(missing_copy_implementations)]
192- // #[derive(Debug)]
193- // struct InTransactionStatus {
194- // transaction_depth: NonZeroU32,
195- // top_level_transaction_requires_rollback: bool,
196- // test_transaction: bool,
197- // }
198-
199- // impl ValidTransactionManagerStatus {
200- // /// Return the current transaction depth
201- // ///
202- // /// This value is `None` if no current transaction is running
203- // /// otherwise the number of nested transactions is returned.
204- // pub fn transaction_depth(&self) -> Option<NonZeroU32> {
205- // self.in_transaction.as_ref().map(|it| it.transaction_depth)
206- // }
207-
208- // /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
209- // /// `Ok(())`
210- // pub fn change_transaction_depth(
211- // &mut self,
212- // transaction_depth_change: TransactionDepthChange,
213- // ) -> QueryResult<()> {
214- // match (&mut self.in_transaction, transaction_depth_change) {
215- // (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
216- // // Can be replaced with saturating_add directly on NonZeroU32 once
217- // // <https://github.com/rust-lang/rust/issues/84186> is stable
218- // in_transaction.transaction_depth =
219- // NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
220- // .expect("nz + nz is always non-zero");
221- // Ok(())
222- // }
223- // (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
224- // // This sets `transaction_depth` to `None` as soon as we reach zero
225- // match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
226- // Some(depth) => in_transaction.transaction_depth = depth,
227- // None => self.in_transaction = None,
228- // }
229- // Ok(())
230- // }
231- // (None, TransactionDepthChange::IncreaseDepth) => {
232- // self.in_transaction = Some(InTransactionStatus {
233- // transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
234- // top_level_transaction_requires_rollback: false,
235- // test_transaction: false,
236- // });
237- // Ok(())
238- // }
239- // (None, TransactionDepthChange::DecreaseDepth) => {
240- // // We screwed up something somewhere
241- // // we cannot decrease the transaction count if
242- // // we are not inside a transaction
243- // Err(Error::NotInTransaction)
244- // }
245- // }
246- // }
247- // }
248-
249- // /// Represents a change to apply to the depth of a transaction
250- // #[derive(Debug, Clone, Copy)]
251- // pub enum TransactionDepthChange {
252- // /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
253- // IncreaseDepth,
254- // /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
255- // DecreaseDepth,
256- // }
257-
258143impl AnsiTransactionManager {
259144 fn get_transaction_state < Conn > (
260145 conn : & mut Conn ,
@@ -274,17 +159,34 @@ impl AnsiTransactionManager {
274159 where
275160 Conn : AsyncConnection < TransactionManager = Self > ,
276161 {
162+ let is_broken = conn. transaction_state ( ) . is_broken . clone ( ) ;
277163 let state = Self :: get_transaction_state ( conn) ?;
278164 match state. transaction_depth ( ) {
279165 None => {
280- conn. batch_execute ( sql) . await ?;
166+ Self :: critical_transaction_block ( & is_broken , conn. batch_execute ( sql) ) . await ?;
281167 Self :: get_transaction_state ( conn) ?
282168 . change_transaction_depth ( TransactionDepthChange :: IncreaseDepth ) ?;
283169 Ok ( ( ) )
284170 }
285171 Some ( _depth) => Err ( Error :: AlreadyInTransaction ) ,
286172 }
287173 }
174+
175+ // This function should be used to await any connection
176+ // related future in our transaction manager implementation
177+ //
178+ // It takes care of tracking entering and exiting executing the future
179+ // which in turn is used to determine if it's safe to still use
180+ // the connection in the event of a canceled transaction execution
181+ async fn critical_transaction_block < F > ( is_broken : & AtomicBool , f : F ) -> F :: Output
182+ where
183+ F : std:: future:: Future ,
184+ {
185+ is_broken. store ( true , Ordering :: Relaxed ) ;
186+ let res = f. await ;
187+ is_broken. store ( false , Ordering :: Relaxed ) ;
188+ res
189+ }
288190}
289191
290192#[ async_trait:: async_trait]
@@ -308,7 +210,11 @@ where
308210 . unwrap_or ( NonZeroU32 :: new ( 1 ) . expect ( "It's not 0" ) ) ;
309211 conn. instrumentation ( )
310212 . on_connection_event ( InstrumentationEvent :: begin_transaction ( depth) ) ;
311- conn. batch_execute ( & start_transaction_sql) . await ?;
213+ Self :: critical_transaction_block (
214+ & conn. transaction_state ( ) . is_broken . clone ( ) ,
215+ conn. batch_execute ( & start_transaction_sql) ,
216+ )
217+ . await ?;
312218 Self :: get_transaction_state ( conn) ?
313219 . change_transaction_depth ( TransactionDepthChange :: IncreaseDepth ) ?;
314220
@@ -344,7 +250,10 @@ where
344250 conn. instrumentation ( )
345251 . on_connection_event ( InstrumentationEvent :: rollback_transaction ( depth) ) ;
346252
347- match conn. batch_execute ( & rollback_sql) . await {
253+ let is_broken = conn. transaction_state ( ) . is_broken . clone ( ) ;
254+
255+ match Self :: critical_transaction_block ( & is_broken, conn. batch_execute ( & rollback_sql) ) . await
256+ {
348257 Ok ( ( ) ) => {
349258 match Self :: get_transaction_state ( conn) ?
350259 . change_transaction_depth ( TransactionDepthChange :: DecreaseDepth )
@@ -429,7 +338,9 @@ where
429338 conn. instrumentation ( )
430339 . on_connection_event ( InstrumentationEvent :: commit_transaction ( depth) ) ;
431340
432- match conn. batch_execute ( & commit_sql) . await {
341+ let is_broken = conn. transaction_state ( ) . is_broken . clone ( ) ;
342+
343+ match Self :: critical_transaction_block ( & is_broken, conn. batch_execute ( & commit_sql) ) . await {
433344 Ok ( ( ) ) => {
434345 match Self :: get_transaction_state ( conn) ?
435346 . change_transaction_depth ( TransactionDepthChange :: DecreaseDepth )
@@ -453,7 +364,12 @@ where
453364 ..
454365 } ) = conn. transaction_state ( ) . status
455366 {
456- match Self :: rollback_transaction ( conn) . await {
367+ match Self :: critical_transaction_block (
368+ & is_broken,
369+ Self :: rollback_transaction ( conn) ,
370+ )
371+ . await
372+ {
457373 Ok ( ( ) ) => { }
458374 Err ( rollback_error) => {
459375 conn. transaction_state ( ) . status . set_in_error ( ) ;
@@ -472,4 +388,9 @@ where
472388 fn transaction_manager_status_mut ( conn : & mut Conn ) -> & mut TransactionManagerStatus {
473389 & mut conn. transaction_state ( ) . status
474390 }
391+
392+ fn is_broken_transaction_manager ( conn : & mut Conn ) -> bool {
393+ conn. transaction_state ( ) . is_broken . load ( Ordering :: Relaxed )
394+ || check_broken_transaction_state ( conn)
395+ }
475396}
0 commit comments