wasmtime_wasi_threads/
lib.rs1use std::panic::{AssertUnwindSafe, catch_unwind};
6use std::sync::Arc;
7use std::sync::atomic::{AtomicI32, Ordering};
8use std::thread;
9use wasmtime::{
10 Caller, ExternType, InstancePre, Linker, Module, Result, SharedMemory, Store, format_err,
11};
12
13const WASI_ENTRY_POINT: &str = "wasi_thread_start";
16
17pub struct WasiThreadsCtx<T> {
18 instance_pre: Arc<InstancePre<T>>,
19 tid: AtomicI32,
20 use_async: bool,
21}
22
23impl<T: Clone + Send + 'static> WasiThreadsCtx<T> {
24 pub fn new(module: Module, linker: Arc<Linker<T>>, use_async: bool) -> Result<Self> {
25 let instance_pre = Arc::new(linker.instantiate_pre(&module)?);
26 let tid = AtomicI32::new(0);
27 Ok(Self {
28 instance_pre,
29 tid,
30 use_async,
31 })
32 }
33
34 pub fn spawn(&self, host: T, thread_start_arg: i32) -> Result<i32> {
35 let instance_pre = self.instance_pre.clone();
36
37 if !has_entry_point(instance_pre.module()) {
47 log::error!(
48 "failed to find a wasi-threads entry point function; expected an export with name: {WASI_ENTRY_POINT}"
49 );
50 return Ok(-1);
51 }
52 if !has_correct_signature(instance_pre.module()) {
53 log::error!(
54 "the exported entry point function has an incorrect signature: expected `(i32, i32) -> ()`"
55 );
56 return Ok(-1);
57 }
58
59 let wasi_thread_id = self.next_thread_id();
60 if wasi_thread_id.is_none() {
61 log::error!("ran out of valid thread IDs");
62 return Ok(-1);
63 }
64 let wasi_thread_id = wasi_thread_id.unwrap();
65
66 let builder = thread::Builder::new().name(format!("wasi-thread-{wasi_thread_id}"));
68 let use_async = self.use_async;
69 builder.spawn(move || {
70 let result = catch_unwind(AssertUnwindSafe(|| {
73 let mut store = Store::new(&instance_pre.module().engine(), host);
75
76 let instance = if use_async {
77 wasmtime_wasi::runtime::in_tokio(instance_pre.instantiate_async(&mut store))
78 } else {
79 instance_pre.instantiate(&mut store)
80 }
81 .unwrap();
82
83 let thread_entry_point = instance
84 .get_typed_func::<(i32, i32), ()>(&mut store, WASI_ENTRY_POINT)
85 .unwrap();
86
87 log::trace!(
93 "spawned thread id = {wasi_thread_id}; calling start function `{WASI_ENTRY_POINT}` with: {thread_start_arg}"
94 );
95 let res = if use_async {
96 wasmtime_wasi::runtime::in_tokio(
97 thread_entry_point
98 .call_async(&mut store, (wasi_thread_id, thread_start_arg)),
99 )
100 } else {
101 thread_entry_point.call(&mut store, (wasi_thread_id, thread_start_arg))
102 };
103 match res {
104 Ok(_) => log::trace!("exiting thread id = {wasi_thread_id} normally"),
105 Err(e) => {
106 log::trace!("exiting thread id = {wasi_thread_id} due to error");
107 let e = wasi_common::maybe_exit_on_error(e);
108 eprintln!("Error: {e:?}");
109 std::process::exit(1);
110 }
111 }
112 }));
113
114 if let Err(e) = result {
115 eprintln!("wasi-thread-{wasi_thread_id} panicked: {e:?}");
116 std::process::exit(1);
117 }
118 })?;
119
120 Ok(wasi_thread_id)
121 }
122
123 fn next_thread_id(&self) -> Option<i32> {
129 match self
130 .tid
131 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| match v {
132 ..=0x1ffffffe => Some(v + 1),
133 _ => None,
134 }) {
135 Ok(v) => Some(v + 1),
136 Err(_) => None,
137 }
138 }
139}
140
141pub fn add_to_linker<T: Clone + Send + 'static>(
147 linker: &mut wasmtime::Linker<T>,
148 store: &wasmtime::Store<T>,
149 module: &Module,
150 get_cx: impl Fn(&mut T) -> &WasiThreadsCtx<T> + Send + Sync + Copy + 'static,
151) -> wasmtime::Result<()> {
152 linker.func_wrap(
153 "wasi",
154 "thread-spawn",
155 move |mut caller: Caller<'_, T>, start_arg: i32| -> i32 {
156 log::trace!("new thread requested via `wasi::thread_spawn` call");
157 let host = caller.data().clone();
158 let ctx = get_cx(caller.data_mut());
159 match ctx.spawn(host, start_arg) {
160 Ok(thread_id) => {
161 assert!(thread_id >= 0, "thread_id = {thread_id}");
162 thread_id
163 }
164 Err(e) => {
165 log::error!("failed to spawn thread: {e}");
166 -1
167 }
168 }
169 },
170 )?;
171
172 for import in module.imports() {
175 if let Some(m) = import.ty().memory() {
176 if m.is_shared() {
177 let mem = SharedMemory::new(module.engine(), m.clone())?;
178 linker.define(store, import.module(), import.name(), mem.clone())?;
179 } else {
180 return Err(format_err!(
181 "memory was not shared; a `wasi-threads` must import \
182 a shared memory as \"memory\""
183 ));
184 }
185 }
186 }
187 Ok(())
188}
189
190fn has_entry_point(module: &Module) -> bool {
192 module.get_export(WASI_ENTRY_POINT).is_some()
193}
194
195fn has_correct_signature(module: &Module) -> bool {
197 match module.get_export(WASI_ENTRY_POINT) {
198 Some(ExternType::Func(ty)) => {
199 ty.params().len() == 2
200 && ty.params().nth(0).unwrap().is_i32()
201 && ty.params().nth(1).unwrap().is_i32()
202 && ty.results().len() == 0
203 }
204 _ => false,
205 }
206}