wasmtime_wasi_threads/
lib.rs

1//! Implement [`wasi-threads`].
2//!
3//! [`wasi-threads`]: https://github.com/WebAssembly/wasi-threads
4
5use anyhow::{Result, anyhow};
6use std::panic::{AssertUnwindSafe, catch_unwind};
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI32, Ordering};
9use std::thread;
10use wasmtime::{Caller, ExternType, InstancePre, Linker, Module, SharedMemory, Store};
11
// `wasi_thread_start` is the function export that the wasi-threads
// specification designates as the entry point invoked on each spawned thread;
// see https://github.com/WebAssembly/wasi-threads/#detailed-design-discussion
const WASI_ENTRY_POINT: &str = "wasi_thread_start";
15
16pub struct WasiThreadsCtx<T> {
17    instance_pre: Arc<InstancePre<T>>,
18    tid: AtomicI32,
19}
20
21impl<T: Clone + Send + 'static> WasiThreadsCtx<T> {
22    pub fn new(module: Module, linker: Arc<Linker<T>>) -> Result<Self> {
23        let instance_pre = Arc::new(linker.instantiate_pre(&module)?);
24        let tid = AtomicI32::new(0);
25        Ok(Self { instance_pre, tid })
26    }
27
28    pub fn spawn(&self, host: T, thread_start_arg: i32) -> Result<i32> {
29        let instance_pre = self.instance_pre.clone();
30
31        // Check that the thread entry point is present. Why here? If we check
32        // for this too early, then we cannot accept modules that do not have an
33        // entry point but never spawn a thread. As pointed out in
34        // https://github.com/bytecodealliance/wasmtime/issues/6153, checking
35        // the entry point here allows wasi-threads to be compatible with more
36        // modules.
37        //
38        // As defined in the wasi-threads specification, returning a negative
39        // result here indicates to the guest module that the spawn failed.
40        if !has_entry_point(instance_pre.module()) {
41            log::error!(
42                "failed to find a wasi-threads entry point function; expected an export with name: {WASI_ENTRY_POINT}"
43            );
44            return Ok(-1);
45        }
46        if !has_correct_signature(instance_pre.module()) {
47            log::error!(
48                "the exported entry point function has an incorrect signature: expected `(i32, i32) -> ()`"
49            );
50            return Ok(-1);
51        }
52
53        let wasi_thread_id = self.next_thread_id();
54        if wasi_thread_id.is_none() {
55            log::error!("ran out of valid thread IDs");
56            return Ok(-1);
57        }
58        let wasi_thread_id = wasi_thread_id.unwrap();
59
60        // Start a Rust thread running a new instance of the current module.
61        let builder = thread::Builder::new().name(format!("wasi-thread-{wasi_thread_id}"));
62        builder.spawn(move || {
63            // Catch any panic failures in host code; e.g., if a WASI module
64            // were to crash, we want all threads to exit, not just this one.
65            let result = catch_unwind(AssertUnwindSafe(|| {
66                // Each new instance is created in its own store.
67                let mut store = Store::new(&instance_pre.module().engine(), host);
68
69                let instance = if instance_pre.module().engine().is_async() {
70                    wasmtime_wasi::runtime::in_tokio(instance_pre.instantiate_async(&mut store))
71                } else {
72                    instance_pre.instantiate(&mut store)
73                }
74                .unwrap();
75
76                let thread_entry_point = instance
77                    .get_typed_func::<(i32, i32), ()>(&mut store, WASI_ENTRY_POINT)
78                    .unwrap();
79
80                // Start the thread's entry point. Any traps or calls to
81                // `proc_exit`, by specification, should end execution for all
82                // threads. This code uses `process::exit` to do so, which is
83                // what the user expects from the CLI but probably not in a
84                // Wasmtime embedding.
85                log::trace!(
86                    "spawned thread id = {}; calling start function `{}` with: {}",
87                    wasi_thread_id,
88                    WASI_ENTRY_POINT,
89                    thread_start_arg
90                );
91                let res = if instance_pre.module().engine().is_async() {
92                    wasmtime_wasi::runtime::in_tokio(
93                        thread_entry_point
94                            .call_async(&mut store, (wasi_thread_id, thread_start_arg)),
95                    )
96                } else {
97                    thread_entry_point.call(&mut store, (wasi_thread_id, thread_start_arg))
98                };
99                match res {
100                    Ok(_) => log::trace!("exiting thread id = {} normally", wasi_thread_id),
101                    Err(e) => {
102                        log::trace!("exiting thread id = {} due to error", wasi_thread_id);
103                        let e = wasi_common::maybe_exit_on_error(e);
104                        eprintln!("Error: {e:?}");
105                        std::process::exit(1);
106                    }
107                }
108            }));
109
110            if let Err(e) = result {
111                eprintln!("wasi-thread-{wasi_thread_id} panicked: {e:?}");
112                std::process::exit(1);
113            }
114        })?;
115
116        Ok(wasi_thread_id)
117    }
118
119    /// Helper for generating valid WASI thread IDs (TID).
120    ///
121    /// Callers of `wasi_thread_spawn` expect a TID in range of 0 < TID <= 0x1FFFFFFF
122    /// to indicate a successful spawning of the thread whereas a negative
123    /// return value indicates an failure to spawn.
124    fn next_thread_id(&self) -> Option<i32> {
125        match self
126            .tid
127            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| match v {
128                ..=0x1ffffffe => Some(v + 1),
129                _ => None,
130            }) {
131            Ok(v) => Some(v + 1),
132            Err(_) => None,
133        }
134    }
135}
136
137/// Manually add the WASI `thread_spawn` function to the linker.
138///
139/// It is unclear what namespace the `wasi-threads` proposal should live under:
140/// it is not clear if it should be included in any of the `preview*` releases
141/// so for the time being its module namespace is simply `"wasi"` (TODO).
142pub fn add_to_linker<T: Clone + Send + 'static>(
143    linker: &mut wasmtime::Linker<T>,
144    store: &wasmtime::Store<T>,
145    module: &Module,
146    get_cx: impl Fn(&mut T) -> &WasiThreadsCtx<T> + Send + Sync + Copy + 'static,
147) -> anyhow::Result<()> {
148    linker.func_wrap(
149        "wasi",
150        "thread-spawn",
151        move |mut caller: Caller<'_, T>, start_arg: i32| -> i32 {
152            log::trace!("new thread requested via `wasi::thread_spawn` call");
153            let host = caller.data().clone();
154            let ctx = get_cx(caller.data_mut());
155            match ctx.spawn(host, start_arg) {
156                Ok(thread_id) => {
157                    assert!(thread_id >= 0, "thread_id = {thread_id}");
158                    thread_id
159                }
160                Err(e) => {
161                    log::error!("failed to spawn thread: {}", e);
162                    -1
163                }
164            }
165        },
166    )?;
167
168    // Find the shared memory import and satisfy it with a newly-created shared
169    // memory import.
170    for import in module.imports() {
171        if let Some(m) = import.ty().memory() {
172            if m.is_shared() {
173                let mem = SharedMemory::new(module.engine(), m.clone())?;
174                linker.define(store, import.module(), import.name(), mem.clone())?;
175            } else {
176                return Err(anyhow!(
177                    "memory was not shared; a `wasi-threads` must import \
178                     a shared memory as \"memory\""
179                ));
180            }
181        }
182    }
183    Ok(())
184}
185
186/// Check if wasi-threads' `wasi_thread_start` export is present.
187fn has_entry_point(module: &Module) -> bool {
188    module.get_export(WASI_ENTRY_POINT).is_some()
189}
190
191/// Check if the entry function has the correct signature `(i32, i32) -> ()`.
192fn has_correct_signature(module: &Module) -> bool {
193    match module.get_export(WASI_ENTRY_POINT) {
194        Some(ExternType::Func(ty)) => {
195            ty.params().len() == 2
196                && ty.params().nth(0).unwrap().is_i32()
197                && ty.params().nth(1).unwrap().is_i32()
198                && ty.results().len() == 0
199        }
200        _ => false,
201    }
202}