Skip to content

Commit e231a5f

Browse files
committed
Fix: Use LocalSet and spawn_local for single threading
Signed-off-by: Aditya <aditya.salunkh919@gmail.com>
1 parent b3e7835 commit e231a5f

File tree

8 files changed

+79
-134
lines changed

8 files changed

+79
-134
lines changed

examples/mysql/todos/src/lib.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,24 @@ struct Component;
101101

102102
impl wasip3::exports::cli::run::Guest for Component {
103103
async fn run() -> Result<(), ()> {
104-
if let Err(err) = run().await {
105-
let (mut tx, rx) = wasip3::wit_stream::new();
106-
107-
futures::join!(
108-
async { wasip3::cli::stderr::write_via_stream(rx).await.unwrap() },
109-
async {
110-
let remaining = tx.write_all(format!("{err:#}\n").into_bytes()).await;
111-
assert!(remaining.is_empty());
112-
drop(tx);
104+
tokio::task::LocalSet::new()
105+
.run_until(async {
106+
if let Err(err) = run().await {
107+
let (mut tx, rx) = wasip3::wit_stream::new();
108+
109+
futures::join!(
110+
async { wasip3::cli::stderr::write_via_stream(rx).await.unwrap() },
111+
async {
112+
let remaining = tx.write_all(format!("{err:#}\n").into_bytes()).await;
113+
assert!(remaining.is_empty());
114+
drop(tx);
115+
}
116+
);
117+
Err(())
118+
} else {
119+
Ok(())
113120
}
114-
);
115-
Err(())
116-
} else {
117-
Ok(())
118-
}
121+
})
122+
.await
119123
}
120124
}

sqlx-core/src/migrate/source.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ impl MigrationSource<'static> for PathBuf {
4242
}
4343

4444
#[cfg(target_arch = "wasm32")]
45-
pub async fn resolve(path: &Path) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
45+
pub async fn resolve(path: &PathBuf) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
4646
todo!();
4747
}
4848

@@ -59,11 +59,12 @@ impl<'s, S: Debug + Into<PathBuf> + Send + 's> MigrationSource<'s> for ResolveWi
5959
Box::pin(async move {
6060
let path = self.0.into();
6161
let config = self.1;
62-
62+
#[cfg(not(target_arch = "wasm32"))]
6363
let migrations_with_paths =
6464
crate::rt::spawn_blocking(move || resolve_blocking_with_config(&path, &config))
6565
.await?;
66-
66+
#[cfg(target_arch = "wasm32")]
67+
let migrations_with_paths = resolve(&path).await?;
6768
Ok(migrations_with_paths.into_iter().map(|(m, _p)| m).collect())
6869
})
6970
}

sqlx-core/src/net/socket/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ pub async fn connect_tcp<Ws: WithSocket>(
199199
.await);
200200
}
201201

202+
#[cfg(target_arch = "wasm32")]
203+
{
204+
todo!("outer socket impl")
205+
}
206+
202207
cfg_if! {
203208
if #[cfg(feature = "_rt-async-io")] {
204209
Ok(with_socket.with_socket(connect_tcp_async_io(host, port).await?).await)

sqlx-core/src/rt/mod.rs

Lines changed: 35 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,9 @@ pub enum JoinHandle<T> {
2323
#[cfg(feature = "_rt-async-std")]
2424
AsyncStd(async_std::task::JoinHandle<T>),
2525

26-
#[cfg(feature = "_rt-tokio")]
26+
#[cfg(any(feature = "_rt-tokio", target_arch = "wasm32"))]
2727
Tokio(tokio::task::JoinHandle<T>),
2828

29-
// WASI P3 runtime
30-
#[cfg(target_arch = "wasm32")]
31-
Wasip3(crate::rt::rt_wasip3::JoinHandle<T>),
32-
3329
// Implementation shared by `smol` and `async-global-executor`
3430
#[cfg(feature = "_rt-async-task")]
3531
AsyncTask(Option<async_task::Task<T>>),
@@ -39,37 +35,18 @@ pub enum JoinHandle<T> {
3935
}
4036

4137
pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
42-
#[cfg(all(feature = "_rt-tokio", target_arch = "wasm32"))]
43-
{
44-
let timeout_future = wasip3::clocks::monotonic_clock::wait_for(
45-
duration.as_nanos().try_into().unwrap_or(u64::MAX),
46-
);
47-
let mut timeout = core::pin::pin!(timeout_future);
48-
let mut f = core::pin::pin!(f);
49-
50-
return core::future::poll_fn(|cx| match timeout.as_mut().poll(cx) {
51-
Poll::Ready(_) => Poll::Ready(Err(TimeoutError)),
52-
Poll::Pending => f.as_mut().poll(cx).map(Ok),
53-
})
54-
.await;
38+
#[cfg(debug_assertions)]
39+
let f = Box::pin(f);
40+
41+
#[cfg(feature = "_rt-tokio")]
42+
if rt_tokio::available() {
43+
return tokio::time::timeout(duration, f)
44+
.await
45+
.map_err(|_| TimeoutError);
5546
}
5647

5748
cfg_if! {
58-
if #[cfg(feature = "_rt-tokio")] {
59-
if rt_tokio::available() {
60-
tokio::time::timeout(duration, f)
61-
.await
62-
.map_err(|_| TimeoutError)
63-
} else {
64-
cfg_if! {
65-
if #[cfg(feature = "_rt-async-io")] {
66-
rt_async_io::timeout(duration, f).await
67-
} else {
68-
missing_rt((duration, f))
69-
}
70-
}
71-
}
72-
} else if #[cfg(feature = "_rt-async-io")] {
49+
if #[cfg(feature = "_rt-async-io")] {
7350
rt_async_io::timeout(duration, f).await
7451
} else {
7552
missing_rt((duration, f))
@@ -80,47 +57,33 @@ pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, T
8057
pub async fn sleep(duration: Duration) {
8158
#[cfg(target_arch = "wasm32")]
8259
{
83-
wasip3::clocks::monotonic_clock::wait_for(
60+
return crate::rt::rt_wasip3::spawn(wasip3::clocks::monotonic_clock::wait_for(
8461
duration.as_nanos().try_into().unwrap_or(u64::MAX),
85-
)
62+
))
8663
.await;
87-
return;
64+
}
65+
66+
#[cfg(feature = "_rt-tokio")]
67+
if rt_tokio::available() {
68+
return tokio::time::sleep(duration).await;
8869
}
8970

9071
cfg_if! {
91-
if #[cfg(feature = "_rt-tokio")] {
92-
if rt_tokio::available() {
93-
tokio::time::sleep(duration).await
94-
} else {
95-
cfg_if! {
96-
if #[cfg(feature = "_rt-async-io")] {
97-
rt_async_io::sleep(duration).await
98-
} else {
99-
#[cfg(not(any(feature = "_rt-async-std", target_arch = "wasm32")))]
100-
missing_rt(duration)
101-
}
102-
}
103-
}
104-
} else if #[cfg(feature = "_rt-async-io")] {
72+
if #[cfg(feature = "_rt-async-io")] {
10573
rt_async_io::sleep(duration).await
10674
} else {
107-
#[cfg(not(any(feature = "_rt-async-std", target_arch = "wasm32")))]
10875
missing_rt(duration)
10976
}
11077
}
11178
}
11279

80+
#[cfg(not(target_arch = "wasm32"))]
11381
#[track_caller]
11482
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
11583
where
11684
F: Future + Send + 'static,
11785
F::Output: Send + 'static,
11886
{
119-
#[cfg(all(feature = "_rt-tokio", target_arch = "wasm32"))]
120-
{
121-
return JoinHandle::Wasip3(crate::rt::rt_wasip3::spawn(fut));
122-
}
123-
12487
#[cfg(feature = "_rt-tokio")]
12588
if let Ok(handle) = tokio::runtime::Handle::try_current() {
12689
return JoinHandle::Tokio(handle.spawn(fut));
@@ -139,16 +102,14 @@ where
139102
}
140103
}
141104

142-
#[cfg(all(feature = "_rt-tokio", target_arch = "wasm32"))]
105+
#[cfg(target_arch = "wasm32")]
143106
#[track_caller]
144-
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
107+
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
145108
where
146-
F: FnOnce() -> R + Send + 'static,
147-
R: Send + 'static,
109+
F: Future + 'static,
110+
F::Output: 'static,
148111
{
149-
JoinHandle::Wasip3(crate::rt::rt_wasip3::spawn(
150-
crate::rt::rt_wasip3::spawn_blocking(f),
151-
))
112+
JoinHandle::Tokio(tokio::task::spawn_local(fut))
152113
}
153114

154115
#[cfg(not(target_arch = "wasm32"))]
@@ -163,25 +124,16 @@ where
163124
return JoinHandle::Tokio(handle.spawn_blocking(f));
164125
}
165126

166-
cfg_if! {
167-
if #[cfg(feature = "_rt-async-global-executor")] {
168-
JoinHandle::AsyncTask(Some(async_global_executor::spawn_blocking(f)))
169-
} else if #[cfg(feature = "_rt-smol")] {
170-
JoinHandle::AsyncTask(Some(smol::unblock(f)))
171-
} else if #[cfg(feature = "_rt-async-std")] {
172-
JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
173-
} else {
174-
missing_rt(f)
175-
}
127+
#[cfg(feature = "_rt-async-std")]
128+
{
129+
JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
176130
}
131+
132+
#[cfg(not(feature = "_rt-async-std"))]
133+
missing_rt(f)
177134
}
178135

179136
pub async fn yield_now() {
180-
#[cfg(all(feature = "_rt-tokio", target_arch = "wasm32"))]
181-
{
182-
return crate::rt::rt_wasip3::yield_now().await;
183-
}
184-
185137
#[cfg(feature = "_rt-tokio")]
186138
if rt_tokio::available() {
187139
return tokio::task::yield_now().await;
@@ -210,15 +162,14 @@ pub async fn yield_now() {
210162
.await
211163
}
212164

213-
#[cfg(not(target_arch = "wasm32"))]
214165
#[track_caller]
215166
pub fn test_block_on<F: Future>(f: F) -> F::Output {
216167
#[cfg(feature = "_rt-async-io")]
217168
{
218169
return async_io::block_on(f);
219170
}
220171

221-
#[cfg(feature = "_rt-tokio")]
172+
#[cfg(any(feature = "_rt-tokio", target_arch = "wasm32"))]
222173
{
223174
return tokio::runtime::Builder::new_current_thread()
224175
.enable_all()
@@ -230,7 +181,7 @@ pub fn test_block_on<F: Future>(f: F) -> F::Output {
230181
#[cfg(all(
231182
feature = "_rt-async-std",
232183
not(feature = "_rt-async-io"),
233-
not(feature = "_rt-tokio")
184+
not(any(feature = "_rt-tokio", target_arch = "wasm32"))
234185
))]
235186
{
236187
return async_std::task::block_on(f);
@@ -239,22 +190,14 @@ pub fn test_block_on<F: Future>(f: F) -> F::Output {
239190
#[cfg(not(any(
240191
feature = "_rt-async-io",
241192
feature = "_rt-async-std",
242-
feature = "_rt-tokio"
193+
feature = "_rt-tokio",
194+
target_arch = "wasm32",
243195
)))]
244196
{
245197
missing_rt(f)
246198
}
247199
}
248200

249-
#[cfg(target_arch = "wasm32")]
250-
#[track_caller]
251-
pub fn test_block_on<F: Future + 'static>(f: F) -> F::Output
252-
where
253-
F::Output: 'static,
254-
{
255-
wasip3::wit_bindgen::rt::async_support::block_on(f)
256-
}
257-
258201
#[track_caller]
259202
pub const fn missing_rt<T>(_unused: T) -> ! {
260203
if cfg!(feature = "_rt-tokio") {
@@ -279,14 +222,11 @@ impl<T: Send + 'static> Future for JoinHandle<T> {
279222
.expect("BUG: task taken")
280223
.poll(cx),
281224

282-
#[cfg(feature = "_rt-tokio")]
225+
#[cfg(any(feature = "_rt-tokio", target_arch = "wasm32"))]
283226
Self::Tokio(handle) => Pin::new(handle)
284227
.poll(cx)
285228
.map(|res| res.expect("spawned task panicked")),
286229

287-
#[cfg(target_arch = "wasm32")]
288-
Self::Wasip3(handle) => Pin::new(handle).poll(cx),
289-
290230
Self::_Phantom(_) => {
291231
let _ = cx;
292232
unreachable!("runtime should have been checked on spawn")

sqlx-core/src/rt/rt_wasip3/mod.rs

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,6 @@ impl<T> Future for JoinHandle<T> {
2727
}
2828
}
2929

30-
pub async fn yield_now() {
31-
wasip3::wit_bindgen::yield_async().await;
32-
}
33-
34-
pub fn spawn_blocking<F, R>(f: F) -> impl Future<Output = R>
35-
where
36-
F: FnOnce() -> R + Send + 'static,
37-
R: Send + 'static,
38-
{
39-
async move {
40-
wasip3::wit_bindgen::yield_blocking();
41-
f()
42-
}
43-
}
44-
4530
pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
4631
let (tx, rx) = oneshot::channel();
4732
async_support::spawn(async move {
@@ -126,7 +111,7 @@ pub async fn connect_tcp<Ws: WithSocket>(
126111
let (mut send_tx, send_rx) = wasip3::wit_stream::new();
127112
let (mut recv_rx, recv_fut) = sock.receive();
128113

129-
let task = tokio::task::spawn(async move {
114+
let task = tokio::task::spawn_local(async move {
130115
let sock = Arc::new(sock);
131116

132117
let (ready_tx, ready_rx) = oneshot::channel();

sqlx-core/src/testing/mod.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,15 @@ where
105105
DB::Connection: Migrate,
106106
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
107107
Fut: Future + 'static,
108-
Fut::Output: TestTermination, <Fut as futures_core::Future>::Output: 'static
108+
Fut::Output: TestTermination,
109+
<Fut as futures_core::Future>::Output: 'static,
109110
{
110111
type Output = Fut::Output;
111112

112-
fn run_test(self, args: TestArgs) -> Self::Output where <Fut as futures_core::Future>::Output: 'static {
113+
fn run_test(self, args: TestArgs) -> Self::Output
114+
where
115+
<Fut as futures_core::Future>::Output: 'static,
116+
{
113117
run_test_with_pool(args, move |pool| async move {
114118
let conn = pool
115119
.acquire()
@@ -132,7 +136,10 @@ where
132136
{
133137
type Output = Fut::Output;
134138

135-
fn run_test(self, args: TestArgs) -> Self::Output where <Fut as futures_core::Future>::Output: 'static {
139+
fn run_test(self, args: TestArgs) -> Self::Output
140+
where
141+
<Fut as futures_core::Future>::Output: 'static,
142+
{
136143
run_test(args, self)
137144
}
138145
}
@@ -189,7 +196,8 @@ where
189196
for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>,
190197
F: FnOnce(Pool<DB>) -> Fut + 'static,
191198
Fut: Future,
192-
Fut::Output: TestTermination, <Fut as futures_core::Future>::Output: 'static
199+
Fut::Output: TestTermination,
200+
<Fut as futures_core::Future>::Output: 'static,
193201
{
194202
let test_path = args.test_path;
195203
run_test::<DB, _, _>(args, move |pool_opts, connect_opts| async move {

sqlx-mysql/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ whoami = { version = "2.0.0-pre", default-features = false }
7272
serde = { version = "1.0.144", optional = true }
7373

7474
[dev-dependencies]
75-
sqlx = { workspace = true, features = ["mysql"] }
75+
sqlx = { workspace = true, features = ["mysql","runtime-tokio"] }
7676

7777
[lints]
7878
workspace = true

0 commit comments

Comments
 (0)