wasmtime_wasi_threads/
lib.rs1use anyhow::{Result, anyhow};
6use std::panic::{AssertUnwindSafe, catch_unwind};
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI32, Ordering};
9use std::thread;
10use wasmtime::{Caller, ExternType, InstancePre, Linker, Module, SharedMemory, Store};
11
/// Name of the exported guest function that every spawned thread calls as its
/// entry point; it is expected to have the signature `(i32, i32) -> ()`.
const WASI_ENTRY_POINT: &str = "wasi_thread_start";
15
16pub struct WasiThreadsCtx<T> {
17 instance_pre: Arc<InstancePre<T>>,
18 tid: AtomicI32,
19}
20
21impl<T: Clone + Send + 'static> WasiThreadsCtx<T> {
22 pub fn new(module: Module, linker: Arc<Linker<T>>) -> Result<Self> {
23 let instance_pre = Arc::new(linker.instantiate_pre(&module)?);
24 let tid = AtomicI32::new(0);
25 Ok(Self { instance_pre, tid })
26 }
27
28 pub fn spawn(&self, host: T, thread_start_arg: i32) -> Result<i32> {
29 let instance_pre = self.instance_pre.clone();
30
31 if !has_entry_point(instance_pre.module()) {
41 log::error!(
42 "failed to find a wasi-threads entry point function; expected an export with name: {WASI_ENTRY_POINT}"
43 );
44 return Ok(-1);
45 }
46 if !has_correct_signature(instance_pre.module()) {
47 log::error!(
48 "the exported entry point function has an incorrect signature: expected `(i32, i32) -> ()`"
49 );
50 return Ok(-1);
51 }
52
53 let wasi_thread_id = self.next_thread_id();
54 if wasi_thread_id.is_none() {
55 log::error!("ran out of valid thread IDs");
56 return Ok(-1);
57 }
58 let wasi_thread_id = wasi_thread_id.unwrap();
59
60 let builder = thread::Builder::new().name(format!("wasi-thread-{wasi_thread_id}"));
62 builder.spawn(move || {
63 let result = catch_unwind(AssertUnwindSafe(|| {
66 let mut store = Store::new(&instance_pre.module().engine(), host);
68
69 let instance = if instance_pre.module().engine().is_async() {
70 wasmtime_wasi::runtime::in_tokio(instance_pre.instantiate_async(&mut store))
71 } else {
72 instance_pre.instantiate(&mut store)
73 }
74 .unwrap();
75
76 let thread_entry_point = instance
77 .get_typed_func::<(i32, i32), ()>(&mut store, WASI_ENTRY_POINT)
78 .unwrap();
79
80 log::trace!(
86 "spawned thread id = {}; calling start function `{}` with: {}",
87 wasi_thread_id,
88 WASI_ENTRY_POINT,
89 thread_start_arg
90 );
91 let res = if instance_pre.module().engine().is_async() {
92 wasmtime_wasi::runtime::in_tokio(
93 thread_entry_point
94 .call_async(&mut store, (wasi_thread_id, thread_start_arg)),
95 )
96 } else {
97 thread_entry_point.call(&mut store, (wasi_thread_id, thread_start_arg))
98 };
99 match res {
100 Ok(_) => log::trace!("exiting thread id = {} normally", wasi_thread_id),
101 Err(e) => {
102 log::trace!("exiting thread id = {} due to error", wasi_thread_id);
103 let e = wasi_common::maybe_exit_on_error(e);
104 eprintln!("Error: {e:?}");
105 std::process::exit(1);
106 }
107 }
108 }));
109
110 if let Err(e) = result {
111 eprintln!("wasi-thread-{wasi_thread_id} panicked: {e:?}");
112 std::process::exit(1);
113 }
114 })?;
115
116 Ok(wasi_thread_id)
117 }
118
119 fn next_thread_id(&self) -> Option<i32> {
125 match self
126 .tid
127 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| match v {
128 ..=0x1ffffffe => Some(v + 1),
129 _ => None,
130 }) {
131 Ok(v) => Some(v + 1),
132 Err(_) => None,
133 }
134 }
135}
136
137pub fn add_to_linker<T: Clone + Send + 'static>(
143 linker: &mut wasmtime::Linker<T>,
144 store: &wasmtime::Store<T>,
145 module: &Module,
146 get_cx: impl Fn(&mut T) -> &WasiThreadsCtx<T> + Send + Sync + Copy + 'static,
147) -> anyhow::Result<()> {
148 linker.func_wrap(
149 "wasi",
150 "thread-spawn",
151 move |mut caller: Caller<'_, T>, start_arg: i32| -> i32 {
152 log::trace!("new thread requested via `wasi::thread_spawn` call");
153 let host = caller.data().clone();
154 let ctx = get_cx(caller.data_mut());
155 match ctx.spawn(host, start_arg) {
156 Ok(thread_id) => {
157 assert!(thread_id >= 0, "thread_id = {thread_id}");
158 thread_id
159 }
160 Err(e) => {
161 log::error!("failed to spawn thread: {}", e);
162 -1
163 }
164 }
165 },
166 )?;
167
168 for import in module.imports() {
171 if let Some(m) = import.ty().memory() {
172 if m.is_shared() {
173 let mem = SharedMemory::new(module.engine(), m.clone())?;
174 linker.define(store, import.module(), import.name(), mem.clone())?;
175 } else {
176 return Err(anyhow!(
177 "memory was not shared; a `wasi-threads` must import \
178 a shared memory as \"memory\""
179 ));
180 }
181 }
182 }
183 Ok(())
184}
185
186fn has_entry_point(module: &Module) -> bool {
188 module.get_export(WASI_ENTRY_POINT).is_some()
189}
190
191fn has_correct_signature(module: &Module) -> bool {
193 match module.get_export(WASI_ENTRY_POINT) {
194 Some(ExternType::Func(ty)) => {
195 ty.params().len() == 2
196 && ty.params().nth(0).unwrap().is_i32()
197 && ty.params().nth(1).unwrap().is_i32()
198 && ty.results().len() == 0
199 }
200 _ => false,
201 }
202}