diff --git a/Cargo.toml b/Cargo.toml index a313b6e..43e0ae5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,9 +15,11 @@ readme = "README.md" [features] default = ["es_modules"] es_modules = [] +keep_worker_alive = [] [dependencies] wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4.42" web-sys = { version = "0.3", features = [ "Blob", "DedicatedWorkerGlobalScope", @@ -32,6 +34,8 @@ web-sys = { version = "0.3", features = [ ] } js-sys = "0.3" futures = "0.3" +async-std = "1.12.0" +once_cell = "1.8" [dev-dependencies] log = "0.4" diff --git a/src/wasm32/js/web_worker.js b/src/wasm32/js/web_worker.js index 7f02ee4..e5a0269 100644 --- a/src/wasm32/js/web_worker.js +++ b/src/wasm32/js/web_worker.js @@ -5,7 +5,7 @@ importScripts('WASM_BINDGEN_SHIM_URL'); // Once we've got it, initialize it all with the `wasm_bindgen` global we imported via // `importScripts`. self.onmessage = event => { - let [ module, memory, work ] = event.data; + let [ module, memory, work, thread_key ] = event.data; wasm_bindgen(module, memory).catch(err => { console.log(err); @@ -20,9 +20,6 @@ self.onmessage = event => { // Enter rust code by calling entry point defined in `lib.rs`. // This executes closure defined by work context. wasm.wasm_thread_entry_point(work); - - // Once done, terminate web worker - close(); }); }; \ No newline at end of file diff --git a/src/wasm32/js/web_worker_module.js b/src/wasm32/js/web_worker_module.js index 84945b9..375a5d4 100644 --- a/src/wasm32/js/web_worker_module.js +++ b/src/wasm32/js/web_worker_module.js @@ -5,7 +5,7 @@ import init, {wasm_thread_entry_point} from "WASM_BINDGEN_SHIM_URL"; // Once we've got it, initialize it all with the `wasm_bindgen` global we imported via // `importScripts`. self.onmessage = event => { - let [ module, memory, work ] = event.data; + let [ module, memory, work, thread_key ] = event.data; init(module, memory).catch(err => { console.log(err); @@ -20,8 +20,5 @@ self.onmessage = event => { // Enter rust code by calling entry point defined in `lib.rs`. // This executes closure defined by work context. wasm_thread_entry_point(work); - - // Once done, terminate web worker - close(); }); }; \ No newline at end of file diff --git a/src/wasm32/mod.rs b/src/wasm32/mod.rs index a413b20..a9c5a85 100644 --- a/src/wasm32/mod.rs +++ b/src/wasm32/mod.rs @@ -1,4 +1,4 @@ -pub use std::thread::{current, sleep, Result, Thread, ThreadId}; +pub use std::thread::{Result, Thread}; use std::{ cell::UnsafeCell, fmt, @@ -16,11 +16,17 @@ use utils::SpinLockMutex; pub use utils::{available_parallelism, get_wasm_bindgen_shim_script_path, get_worker_script, is_web_worker_thread}; use wasm_bindgen::prelude::*; use web_sys::{DedicatedWorkerGlobalScope, Worker, WorkerOptions, WorkerType}; +use std::future::Future; +use std::pin::Pin; + mod scoped; mod signal; mod utils; +// Use a thread safe static hashmap to keep track of whether each thread can be closed + + struct WebWorkerContext { func: Box, } @@ -66,6 +72,30 @@ impl WorkerMessage { } } +// Pass `f` in `MaybeUninit` because actually that closure might *run longer than the lifetime of `F`*. +// See for more details. +// To prevent leaks we use a wrapper that drops its contents. +#[repr(transparent)] +struct MaybeDangling(mem::MaybeUninit); +impl MaybeDangling { + fn new(x: T) -> Self { + MaybeDangling(mem::MaybeUninit::new(x)) + } + fn into_inner(self) -> T { + // SAFETY: we are always initiailized. + let ret = unsafe { self.0.assume_init_read() }; + // Make sure we don't drop. + mem::forget(self); + ret + } +} +impl Drop for MaybeDangling { + fn drop(&mut self) { + // SAFETY: we are always initiailized. + unsafe { self.0.assume_init_drop() }; + } +} + static DEFAULT_BUILDER: Mutex> = Mutex::new(None); /// Thread factory, which can be used in order to configure the properties of a new thread. @@ -148,6 +178,7 @@ impl Builder { /// Spawns a new thread by taking ownership of the `Builder`, and returns an /// [std::io::Result] to its [`JoinHandle`]. + pub fn spawn(self, f: F) -> std::io::Result> where F: FnOnce() -> T, @@ -157,6 +188,14 @@ impl Builder { unsafe { self.spawn_unchecked(f) } } + pub fn spawn_async(self, f: F) -> std::io::Result> + where + F: AsyncClosure + Send + 'static, + T: Send + 'static, + { + unsafe { self.spawn_unchecked_async(f) } + } + /// Spawns a new thread without any lifetime restrictions by taking ownership /// of the `Builder`, and returns an [std::io::Result] to its [`JoinHandle`]. /// @@ -180,6 +219,14 @@ impl Builder { Ok(JoinHandle(unsafe { self.spawn_unchecked_(f, None) }?)) } + pub unsafe fn spawn_unchecked_async<'a, F, T>(self, f: F) -> std::io::Result> + where + F: AsyncClosure + Send + 'static, + T: Send + 'static, + { + Ok(JoinHandle(unsafe { self.spawn_unchecked_async_(f, None) }?)) + } + pub(crate) unsafe fn spawn_unchecked_<'a, 'scope, F, T>( self, f: F, @@ -201,36 +248,81 @@ impl Builder { }); let their_packet = my_packet.clone(); - // Pass `f` in `MaybeUninit` because actually that closure might *run longer than the lifetime of `F`*. - // See for more details. - // To prevent leaks we use a wrapper that drops its contents. - #[repr(transparent)] - struct MaybeDangling(mem::MaybeUninit); - impl MaybeDangling { - fn new(x: T) -> Self { - MaybeDangling(mem::MaybeUninit::new(x)) - } - fn into_inner(self) -> T { - // SAFETY: we are always initiailized. - let ret = unsafe { self.0.assume_init_read() }; - // Make sure we don't drop. - mem::forget(self); - ret - } + let f = MaybeDangling::new(f); + let main = Box::new(move || { + // SAFETY: we constructed `f` initialized. + let f = f.into_inner(); + // Execute the closure and catch any panics + let try_result = catch_unwind(AssertUnwindSafe(|| f())); + // SAFETY: `their_packet` as been built just above and moved by the + // closure (it is an Arc<...>) and `my_packet` will be stored in the + // same `JoinInner` as this closure meaning the mutation will be + // safe (not modify it and affect a value far away). + unsafe { *their_packet.result.get() = Some(try_result) }; + // Here `their_packet` gets dropped, and if this is the last `Arc` for that packet that + // will call `decrement_num_running_threads` and therefore signal that this thread is + // done. + drop(their_packet); + // Notify waiting handles + their_signal.signal(); + // Here, the lifetime `'a` and even `'scope` can end, so the thread can be closed. `main` keeps running for a bit + // after that before returning itself. + #[cfg(not(feature = "keep_worker_alive"))] + js_sys::eval("self") + .unwrap() + .dyn_into::() + .unwrap() + .close(); + }); + + // Erase lifetime + let context = WebWorkerContext { + func: mem::transmute::, Box>(main), + }; + + if is_web_worker_thread() { + WorkerMessage::SpawnThread(BuilderRequest { builder: self, context }).post(); + } else { + self.spawn_for_context(context); } - impl Drop for MaybeDangling { - fn drop(&mut self) { - // SAFETY: we are always initiailized. - unsafe { self.0.assume_init_drop() }; - } + + if let Some(scope) = &my_packet.scope { + scope.increment_num_running_threads(); } + Ok(JoinInner { + signal: my_signal, + packet: my_packet, + }) + } + + pub(crate) unsafe fn spawn_unchecked_async_<'a, F, T>( + self, + f: F, + scope_data: Option>, + ) -> std::io::Result> + where + F: AsyncClosure + 'static, + T: Send + 'static, + { + let my_signal = Arc::new(Signal::new()); + let their_signal = my_signal.clone(); + + let my_packet: Arc> = Arc::new(Packet { + scope: scope_data, + result: UnsafeCell::new(None), + _marker: PhantomData, + }); + let their_packet = my_packet.clone(); + let f = MaybeDangling::new(f); - let main = Box::new(move || { + let spawn_local = wasm_bindgen_futures::spawn_local(async move { // SAFETY: we constructed `f` initialized. let f = f.into_inner(); // Execute the closure and catch any panics - let try_result = catch_unwind(AssertUnwindSafe(|| f())); + + let try_result = Ok(f.call_once().await); + // SAFETY: `their_packet` as been built just above and moved by the // closure (it is an Arc<...>) and `my_packet` will be stored in the // same `JoinInner` as this closure meaning the mutation will be @@ -242,9 +334,16 @@ impl Builder { drop(their_packet); // Notify waiting handles their_signal.signal(); - // Here, the lifetime `'a` and even `'scope` can end. `main` keeps running for a bit + // Here, the lifetime `'a` and even `'scope` can end, so the thread can be closed. `main` keeps running for a bit // after that before returning itself. + #[cfg(not(feature = "keep_worker_alive"))] + js_sys::eval("self") + .unwrap() + .dyn_into::() + .unwrap() + .close(); }); + let main = Box::new(move || spawn_local ); // Erase lifetime let context = WebWorkerContext { @@ -453,5 +552,30 @@ where F: Send + 'static, T: Send + 'static, { - Builder::new().spawn(f).expect("failed to spawn thread") + Builder::new().spawn(f).expect("Failed to spawn thread") +} + +pub trait AsyncClosure { + fn call_once(self) -> Pin>>; +} + +impl AsyncClosure for F +where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future + 'static, + T: Send + 'static, +{ + fn call_once(self) -> Pin>> { + Box::pin(self()) + } } + +// JoinHandle is of type () because the future immediately returns. +pub fn spawn_async(f: F) -> JoinHandle +where + F: AsyncClosure, + F: Send + 'static, + T: Send + 'static, +{ + Builder::new().spawn_async(f).expect("failed to spawn thread") +} \ No newline at end of file diff --git a/src/wasm32/scoped.rs b/src/wasm32/scoped.rs index cd18532..fdf1891 100644 --- a/src/wasm32/scoped.rs +++ b/src/wasm32/scoped.rs @@ -9,6 +9,7 @@ use std::{ use super::{signal::Signal, utils::is_web_worker_thread, Builder, JoinInner}; + /// A scope to spawn scoped threads in. /// /// See [`scope`] for details. diff --git a/src/wasm32/utils.rs b/src/wasm32/utils.rs index 11da5e2..f42590f 100644 --- a/src/wasm32/utils.rs +++ b/src/wasm32/utils.rs @@ -103,3 +103,4 @@ impl SpinLockMutex for Mutex { } } } + diff --git a/tests/wasm.rs b/tests/wasm.rs index 9c90ec1..90c7276 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -1,8 +1,12 @@ #![cfg(target_arch = "wasm32")] use core::{ - sync::atomic::{AtomicBool, Ordering}, - time::Duration, + sync::atomic::{AtomicBool, Ordering}, time::Duration +}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, }; use wasm_bindgen_test::*; @@ -121,3 +125,77 @@ async fn thread_async_channel() { let result = main_rx.recv().await.unwrap(); assert_eq!(result, "Hello world!"); } + +//TODO: doesn't fail when keep_worker_alive is enabled. Can threads be closed from wasm? +// This test should fail if "keep_worker_alive" enabled +#[wasm_bindgen_test] +async fn keep_worker_alive(){ + thread::spawn(|| { + wasm_bindgen_futures::spawn_local(async move { + let promise = js_sys::Promise::resolve(&wasm_bindgen::JsValue::from(42)); + wasm_bindgen_futures::JsFuture::from(promise).await.unwrap(); + //additional wait to simulate a js future that takes more time + async_std::task::sleep(std::time::Duration::from_secs(1)).await; + // This should only run if "keep_worker_alive" is enabled. If disabled, + // the thread will close before it can run + assert_eq!(1, 2); + }); + }); +} + +#[wasm_bindgen_test] +async fn spawn_async(){ + let (thread_tx, main_rx) = async_channel::unbounded::(); + //since spawn_async closes the thread once the provided closure is complete, + //"keep_worker_alive" is not necessary + thread::spawn_async(|| async move{ + + let promise = js_sys::Promise::resolve(&wasm_bindgen::JsValue::from(42)); + wasm_bindgen_futures::JsFuture::from(promise).await.unwrap(); + // //additional wait to simulate a js future that takes more time + async_std::task::sleep(std::time::Duration::from_secs(1)).await; + thread_tx.send("After js future".to_string()).await.unwrap(); + }); + let msg = main_rx.recv().await.unwrap(); + + assert_eq!(msg, "After js future"); +} + +struct DelayedValue { + start_time: f64, + delay_time: f64, +} + +impl Future for DelayedValue { + type Output = u32; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Check if the delay has elapsed + let performance_now = js_sys::Date::now(); + if self.start_time+self.delay_time < performance_now { + Poll::Ready(1234) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + } + + } +} + +impl DelayedValue { + pub fn new(duration: f64) -> Self { + let performance_now = js_sys::Date::now(); + DelayedValue {start_time: performance_now, delay_time: duration} + } +} + +#[wasm_bindgen_test] +async fn async_thread_join_async() { + + let handle = thread::spawn_async(|| async move { + DelayedValue::new(1000.0).await + }); + + assert_eq!(handle.join_async().await.unwrap(), 1234); +} +