66//!
77//! * using a sync Connection implementation in async context
88//! * using the same code base for async crates needing multiple backends
9+ use std:: error:: Error ;
10+ use futures_util:: future:: BoxFuture ;
911
1012#[ cfg( feature = "sqlite" ) ]
1113mod sqlite;
1214
15+ /// This is a helper trait that allows to customize the
16+ /// spawning blocking tasks as part of the
17+ /// [`SyncConnectionWrapper`] type. By default a
18+ /// tokio runtime and its spawn_blocking function is used.
19+ pub trait SpawnBlocking {
20+ /// This function should allow to execute a
21+ /// given blocking task without blocking the caller
22+ /// to get the result
23+ fn spawn_blocking < ' a , R > (
24+ & mut self ,
25+ task : impl FnOnce ( ) -> R + Send + ' static ,
26+ ) -> BoxFuture < ' a , Result < R , Box < dyn Error + Send + Sync + ' static > > >
27+ where
28+ R : Send + ' static ;
29+
30+ /// This function should be used to construct
31+ /// a new runtime instance
32+ fn get_runtime ( ) -> Self ;
33+ }
34+
35+ #[ cfg( feature = "tokio" ) ]
36+ pub type SyncConnectionWrapper < C , B = self :: implementation:: Tokio > = self :: implementation:: SyncConnectionWrapper < C , B > ;
37+
38+ #[ cfg( not( feature = "tokio" ) ) ]
1339pub use self :: implementation:: SyncConnectionWrapper ;
40+
1441pub use self :: implementation:: SyncTransactionManagerWrapper ;
1542
1643mod implementation {
@@ -25,17 +52,17 @@ mod implementation {
2552 } ;
2653 use diesel:: row:: IntoOwnedRow ;
2754 use diesel:: { ConnectionResult , QueryResult } ;
28- use futures_util:: future:: BoxFuture ;
2955 use futures_util:: stream:: BoxStream ;
3056 use futures_util:: { FutureExt , StreamExt , TryFutureExt } ;
3157 use std:: marker:: PhantomData ;
3258 use std:: sync:: { Arc , Mutex } ;
33- use tokio:: task:: JoinError ;
3459
35- fn from_tokio_join_error ( join_error : JoinError ) -> diesel:: result:: Error {
60+ use super :: * ;
61+
62+ fn from_spawn_blocking_error ( error : Box < dyn Error + Send + Sync + ' static > ) -> diesel:: result:: Error {
3663 diesel:: result:: Error :: DatabaseError (
3764 diesel:: result:: DatabaseErrorKind :: UnableToSendCommand ,
38- Box :: new ( join_error . to_string ( ) ) ,
65+ Box :: new ( error . to_string ( ) ) ,
3966 )
4067 }
4168
@@ -77,13 +104,15 @@ mod implementation {
77104 /// # some_async_fn().await;
78105 /// # }
79106 /// ```
80- pub struct SyncConnectionWrapper < C > {
107+ pub struct SyncConnectionWrapper < C , S > {
81108 inner : Arc < Mutex < C > > ,
109+ runtime : S ,
82110 }
83111
84- impl < C > SimpleAsyncConnection for SyncConnectionWrapper < C >
112+ impl < C , S > SimpleAsyncConnection for SyncConnectionWrapper < C , S >
85113 where
86114 C : diesel:: connection:: Connection + ' static ,
115+ S : SpawnBlocking + Send ,
87116 {
88117 async fn batch_execute ( & mut self , query : & str ) -> QueryResult < ( ) > {
89118 let query = query. to_string ( ) ;
@@ -92,7 +121,7 @@ mod implementation {
92121 }
93122 }
94123
95- impl < C , MD , O > AsyncConnection for SyncConnectionWrapper < C >
124+ impl < C , S , MD , O > AsyncConnection for SyncConnectionWrapper < C , S >
96125 where
97126 // Backend bounds
98127 <C as Connection >:: Backend : std:: default:: Default + DieselReserveSpecialization ,
@@ -108,6 +137,8 @@ mod implementation {
108137 O : ' static + Send + for < ' conn > diesel:: row:: Row < ' conn , C :: Backend > ,
109138 for < ' conn , ' query > <C as LoadConnection >:: Row < ' conn , ' query > :
110139 IntoOwnedRow < ' conn , <C as Connection >:: Backend , OwnedRow = O > ,
140+ // SpawnBlocking bounds
141+ S : SpawnBlocking + Send ,
111142 {
112143 type LoadFuture < ' conn , ' query > = BoxFuture < ' query , QueryResult < Self :: Stream < ' conn , ' query > > > ;
113144 type ExecuteFuture < ' conn , ' query > = BoxFuture < ' query , QueryResult < usize > > ;
@@ -118,10 +149,12 @@ mod implementation {
118149
119150 async fn establish ( database_url : & str ) -> ConnectionResult < Self > {
120151 let database_url = database_url. to_string ( ) ;
121- tokio:: task:: spawn_blocking ( move || C :: establish ( & database_url) )
152+ let mut runtime = S :: get_runtime ( ) ;
153+
154+ runtime. spawn_blocking ( move || C :: establish ( & database_url) )
122155 . await
123156 . unwrap_or_else ( |e| Err ( diesel:: ConnectionError :: BadConnection ( e. to_string ( ) ) ) )
124- . map ( |c| SyncConnectionWrapper :: new ( c ) )
157+ . map ( move |c| SyncConnectionWrapper :: with_runtime ( c , runtime ) )
125158 }
126159
127160 fn load < ' conn , ' query , T > ( & ' conn mut self , source : T ) -> Self :: LoadFuture < ' conn , ' query >
@@ -209,44 +242,60 @@ mod implementation {
209242 /// A wrapper of a diesel transaction manager usable in async context.
210243 pub struct SyncTransactionManagerWrapper < T > ( PhantomData < T > ) ;
211244
212- impl < T , C > TransactionManager < SyncConnectionWrapper < C > > for SyncTransactionManagerWrapper < T >
245+ impl < T , C , S > TransactionManager < SyncConnectionWrapper < C , S > > for SyncTransactionManagerWrapper < T >
213246 where
214- SyncConnectionWrapper < C > : AsyncConnection ,
247+ SyncConnectionWrapper < C , S > : AsyncConnection ,
215248 C : Connection + ' static ,
249+ S : SpawnBlocking ,
216250 T : diesel:: connection:: TransactionManager < C > + Send ,
217251 {
218252 type TransactionStateData = T :: TransactionStateData ;
219253
220- async fn begin_transaction ( conn : & mut SyncConnectionWrapper < C > ) -> QueryResult < ( ) > {
254+ async fn begin_transaction ( conn : & mut SyncConnectionWrapper < C , S > ) -> QueryResult < ( ) > {
221255 conn. spawn_blocking ( move |inner| T :: begin_transaction ( inner) )
222256 . await
223257 }
224258
225- async fn commit_transaction ( conn : & mut SyncConnectionWrapper < C > ) -> QueryResult < ( ) > {
259+ async fn commit_transaction ( conn : & mut SyncConnectionWrapper < C , S > ) -> QueryResult < ( ) > {
226260 conn. spawn_blocking ( move |inner| T :: commit_transaction ( inner) )
227261 . await
228262 }
229263
230- async fn rollback_transaction ( conn : & mut SyncConnectionWrapper < C > ) -> QueryResult < ( ) > {
264+ async fn rollback_transaction ( conn : & mut SyncConnectionWrapper < C , S > ) -> QueryResult < ( ) > {
231265 conn. spawn_blocking ( move |inner| T :: rollback_transaction ( inner) )
232266 . await
233267 }
234268
235269 fn transaction_manager_status_mut (
236- conn : & mut SyncConnectionWrapper < C > ,
270+ conn : & mut SyncConnectionWrapper < C , S > ,
237271 ) -> & mut TransactionManagerStatus {
238272 T :: transaction_manager_status_mut ( conn. exclusive_connection ( ) )
239273 }
240274 }
241275
242- impl < C > SyncConnectionWrapper < C > {
276+ impl < C , S > SyncConnectionWrapper < C , S > {
243277 /// Builds a wrapper with this underlying sync connection
244278 pub fn new ( connection : C ) -> Self
245279 where
246280 C : Connection ,
281+ S : SpawnBlocking ,
282+ {
283+ SyncConnectionWrapper {
284+ inner : Arc :: new ( Mutex :: new ( connection) ) ,
285+ runtime : S :: get_runtime ( ) ,
286+ }
287+ }
288+
289+ /// Builds a wrapper with this underlying sync connection
290+ /// and runtime for spawning blocking tasks
291+ pub fn with_runtime ( connection : C , runtime : S ) -> Self
292+ where
293+ C : Connection ,
294+ S : SpawnBlocking ,
247295 {
248296 SyncConnectionWrapper {
249297 inner : Arc :: new ( Mutex :: new ( connection) ) ,
298+ runtime,
250299 }
251300 }
252301
@@ -283,17 +332,18 @@ mod implementation {
283332 where
284333 C : Connection + ' static ,
285334 R : Send + ' static ,
335+ S : SpawnBlocking ,
286336 {
287337 let inner = self . inner . clone ( ) ;
288- tokio :: task :: spawn_blocking ( move || {
338+ self . runtime . spawn_blocking ( move || {
289339 let mut inner = inner. lock ( ) . unwrap_or_else ( |poison| {
290340 // try to be resilient by providing the guard
291341 inner. clear_poison ( ) ;
292342 poison. into_inner ( )
293343 } ) ;
294344 task ( & mut inner)
295345 } )
296- . unwrap_or_else ( |err| QueryResult :: Err ( from_tokio_join_error ( err) ) )
346+ . unwrap_or_else ( |err| QueryResult :: Err ( from_spawn_blocking_error ( err) ) )
297347 . boxed ( )
298348 }
299349
@@ -316,6 +366,8 @@ mod implementation {
316366 // Arguments/Return bounds
317367 Q : QueryFragment < C :: Backend > + QueryId ,
318368 R : Send + ' static ,
369+ // SpawnBlocking bounds
370+ S : SpawnBlocking ,
319371 {
320372 let backend = C :: Backend :: default ( ) ;
321373
@@ -383,4 +435,43 @@ mod implementation {
383435 Self :: TransactionManager :: is_broken_transaction_manager ( self )
384436 }
385437 }
438+
439+ #[ cfg( feature = "tokio" ) ]
440+ pub enum Tokio {
441+ Handle ( tokio:: runtime:: Handle ) ,
442+ Runtime ( tokio:: runtime:: Runtime )
443+ }
444+
445+ #[ cfg( feature = "tokio" ) ]
446+ impl SpawnBlocking for Tokio {
447+ fn spawn_blocking < ' a , R > (
448+ & mut self ,
449+ task : impl FnOnce ( ) -> R + Send + ' static ,
450+ ) -> BoxFuture < ' a , Result < R , Box < dyn Error + Send + Sync + ' static > > >
451+ where
452+ R : Send + ' static ,
453+ {
454+ let fut = match self {
455+ Tokio :: Handle ( handle) => handle. spawn_blocking ( task) ,
456+ Tokio :: Runtime ( runtime) => runtime. spawn_blocking ( task)
457+ } ;
458+
459+ fut
460+ . map_err ( |err| Box :: from ( err) )
461+ . boxed ( )
462+ }
463+
464+ fn get_runtime ( ) -> Self {
465+ if let Ok ( handle) = tokio:: runtime:: Handle :: try_current ( ) {
466+ Tokio :: Handle ( handle)
467+ } else {
468+ let runtime = tokio:: runtime:: Builder :: new_current_thread ( )
469+ . enable_io ( )
470+ . build ( )
471+ . unwrap ( ) ;
472+
473+ Tokio :: Runtime ( runtime)
474+ }
475+ }
476+ }
386477}
0 commit comments