1use crate::component::func::{LiftContext, LowerContext, Options};
2use crate::component::matching::InstanceType;
3use crate::component::storage::slice_to_storage_mut;
4use crate::component::{ComponentNamedList, ComponentType, Lift, Lower, Val};
5use crate::prelude::*;
6use crate::runtime::vm::component::{
7 ComponentInstance, InstanceFlags, VMComponentContext, VMLowering, VMLoweringCallee,
8};
9use crate::runtime::vm::{VMFuncRef, VMGlobalDefinition, VMMemoryDefinition, VMOpaqueContext};
10use crate::{AsContextMut, CallHook, StoreContextMut, ValRaw};
11use alloc::sync::Arc;
12use core::any::Any;
13use core::mem::{self, MaybeUninit};
14use core::ptr::NonNull;
15use wasmtime_environ::component::{
16 CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeFuncIndex,
17 MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
18};
19
20pub struct HostFunc {
21 entrypoint: VMLoweringCallee,
22 typecheck: Box<dyn (Fn(TypeFuncIndex, &InstanceType<'_>) -> Result<()>) + Send + Sync>,
23 func: Box<dyn Any + Send + Sync>,
24}
25
26impl core::fmt::Debug for HostFunc {
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 f.debug_struct("HostFunc").finish_non_exhaustive()
29 }
30}
31
32impl HostFunc {
33 pub(crate) fn from_closure<T, F, P, R>(func: F) -> Arc<HostFunc>
34 where
35 F: Fn(StoreContextMut<T>, P) -> Result<R> + Send + Sync + 'static,
36 P: ComponentNamedList + Lift + 'static,
37 R: ComponentNamedList + Lower + 'static,
38 {
39 let entrypoint = Self::entrypoint::<T, F, P, R>;
40 Arc::new(HostFunc {
41 entrypoint,
42 typecheck: Box::new(typecheck::<P, R>),
43 func: Box::new(func),
44 })
45 }
46
47 extern "C" fn entrypoint<T, F, P, R>(
48 cx: NonNull<VMOpaqueContext>,
49 data: NonNull<u8>,
50 ty: u32,
51 _caller_instance: u32,
52 flags: NonNull<VMGlobalDefinition>,
53 memory: *mut VMMemoryDefinition,
54 realloc: *mut VMFuncRef,
55 string_encoding: u8,
56 async_: u8,
57 storage: NonNull<MaybeUninit<ValRaw>>,
58 storage_len: usize,
59 ) -> bool
60 where
61 F: Fn(StoreContextMut<T>, P) -> Result<R>,
62 P: ComponentNamedList + Lift + 'static,
63 R: ComponentNamedList + Lower + 'static,
64 {
65 let data = data.as_ptr() as *const F;
66 unsafe {
67 call_host_and_handle_result::<T>(cx, |instance, types, store| {
68 call_host::<_, _, _, _>(
69 instance,
70 types,
71 store,
72 TypeFuncIndex::from_u32(ty),
73 InstanceFlags::from_raw(flags),
74 memory,
75 realloc,
76 StringEncoding::from_u8(string_encoding).unwrap(),
77 async_ != 0,
78 NonNull::slice_from_raw_parts(storage, storage_len).as_mut(),
79 |store, args| (*data)(store, args),
80 )
81 })
82 }
83 }
84
85 pub(crate) fn new_dynamic<T, F>(func: F) -> Arc<HostFunc>
86 where
87 F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static,
88 {
89 Arc::new(HostFunc {
90 entrypoint: dynamic_entrypoint::<T, F>,
91 typecheck: Box::new(move |_expected_index, _expected_types| Ok(())),
95 func: Box::new(func),
96 })
97 }
98
99 pub fn typecheck(&self, ty: TypeFuncIndex, types: &InstanceType<'_>) -> Result<()> {
100 (self.typecheck)(ty, types)
101 }
102
103 pub fn lowering(&self) -> VMLowering {
104 let data = NonNull::from(&*self.func).cast();
105 VMLowering {
106 callee: self.entrypoint,
107 data: data.into(),
108 }
109 }
110}
111
112fn typecheck<P, R>(ty: TypeFuncIndex, types: &InstanceType<'_>) -> Result<()>
113where
114 P: ComponentNamedList + Lift,
115 R: ComponentNamedList + Lower,
116{
117 let ty = &types.types[ty];
118 P::typecheck(&InterfaceType::Tuple(ty.params), types)
119 .context("type mismatch with parameters")?;
120 R::typecheck(&InterfaceType::Tuple(ty.results), types).context("type mismatch with results")?;
121 Ok(())
122}
123
/// Shared implementation of a statically-typed host call.
///
/// Lifts the parameters from `storage` (directly, or via a pointer into
/// linear memory), invokes `closure`, and lowers its return value back into
/// `storage` (directly, or through a return pointer into linear memory).
///
/// # Safety
///
/// NOTE(review): the caller must guarantee that `instance`, `memory`,
/// `realloc`, and `storage` are valid for this call. `storage` must also be
/// large enough for the flattened params/results of `ty`. It is
/// reinterpreted via `slice_to_storage_mut` and `assume_init_ref` without
/// further checks here. TODO confirm the exact contract.
///
/// # Errors
///
/// Fails if the instance's `may_leave` flag is cleared, if lifting or
/// lowering fails (including misaligned/out-of-bounds pointers), or if
/// `closure` returns an error.
///
/// # Panics
///
/// Panics via `todo!()` when `async_` is set.
unsafe fn call_host<T, Params, Return, F>(
    instance: *mut ComponentInstance,
    types: &Arc<ComponentTypes>,
    mut cx: StoreContextMut<'_, T>,
    ty: TypeFuncIndex,
    mut flags: InstanceFlags,
    memory: *mut VMMemoryDefinition,
    realloc: *mut VMFuncRef,
    string_encoding: StringEncoding,
    async_: bool,
    storage: &mut [MaybeUninit<ValRaw>],
    closure: F,
) -> Result<()>
where
    Params: Lift,
    Return: Lower,
    F: FnOnce(StoreContextMut<'_, T>, Params) -> Result<Return>,
{
    // Async host calls are not implemented yet; this panics.
    if async_ {
        todo!()
    }

    /// Layout of `storage` when results are returned indirectly: the
    /// parameter slot(s) come first, followed by `retptr`, the address at
    /// which the results are stored.
    #[repr(C)]
    struct ReturnPointer<T> {
        args: T,
        retptr: ValRaw,
    }

    /// Layout of `storage` when results are returned directly: the same
    /// slots hold the parameters on entry and are overwritten with the
    /// flattened results on the way out.
    #[repr(C)]
    union ReturnStack<T: Copy, U: Copy> {
        args: T,
        ret: U,
    }

    let options = Options::new(
        cx.0.id(),
        NonNull::new(memory),
        NonNull::new(realloc),
        string_encoding,
    );

    // Refuse to run the host function while the instance's `may_leave` flag
    // is cleared.
    if !flags.may_leave() {
        bail!("cannot leave component instance");
    }

    let ty = &types[ty];
    let param_tys = InterfaceType::Tuple(ty.params);
    let result_tys = InterfaceType::Tuple(ty.results);

    // Pick a view of `storage` based on whether params and results fit in
    // the flat representation (`MAX_FLAT_PARAMS` / `MAX_FLAT_RESULTS`). When
    // results go through a return pointer, storage only carries inputs and
    // is viewed as already initialized through a shared reference.
    let mut storage: Storage<'_, Params, Return> = if Params::flatten_count() <= MAX_FLAT_PARAMS {
        if Return::flatten_count() <= MAX_FLAT_RESULTS {
            Storage::Direct(slice_to_storage_mut(storage))
        } else {
            Storage::ResultsIndirect(slice_to_storage_mut(storage).assume_init_ref())
        }
    } else {
        if Return::flatten_count() <= MAX_FLAT_RESULTS {
            Storage::ParamsIndirect(slice_to_storage_mut(storage))
        } else {
            Storage::Indirect(slice_to_storage_mut(storage).assume_init_ref())
        }
    };
    let mut lift = LiftContext::new(cx.0, &options, types, instance);
    lift.enter_call();
    let params = storage.lift_params(&mut lift, param_tys)?;

    let ret = closure(cx.as_context_mut(), params)?;
    // `may_leave` is cleared while results are lowered. It is set again only
    // after lowering succeeds; an error from `lower_results` returns early
    // with the flag still cleared.
    flags.set_may_leave(false);
    let mut lower = LowerContext::new(cx, &options, types, instance);
    storage.lower_results(&mut lower, result_tys, ret)?;
    flags.set_may_leave(true);

    lower.exit_call()?;

    return Ok(());

    /// `storage` viewed according to which of params/results are passed
    /// indirectly through a pointer into linear memory.
    enum Storage<'a, P: ComponentType, R: ComponentType> {
        /// Params and results both flat, sharing the same slots.
        Direct(&'a mut MaybeUninit<ReturnStack<P::Lower, R::Lower>>),
        /// Params passed by pointer, results flat.
        ParamsIndirect(&'a mut MaybeUninit<ReturnStack<ValRaw, R::Lower>>),
        /// Params flat, results stored through `retptr`.
        ResultsIndirect(&'a ReturnPointer<P::Lower>),
        /// Params passed by pointer and results stored through `retptr`.
        Indirect(&'a ReturnPointer<ValRaw>),
    }

    impl<P, R> Storage<'_, P, R>
    where
        P: ComponentType + Lift,
        R: ComponentType + Lower,
    {
        /// Lifts the parameters, either from the flat slots or by loading
        /// from linear memory after validating the parameter pointer.
        unsafe fn lift_params(&self, cx: &mut LiftContext<'_>, ty: InterfaceType) -> Result<P> {
            match self {
                Storage::Direct(storage) => P::lift(cx, ty, &storage.assume_init_ref().args),
                Storage::ResultsIndirect(storage) => P::lift(cx, ty, &storage.args),
                Storage::ParamsIndirect(storage) => {
                    let ptr = validate_inbounds::<P>(cx.memory(), &storage.assume_init_ref().args)?;
                    P::load(cx, ty, &cx.memory()[ptr..][..P::SIZE32])
                }
                Storage::Indirect(storage) => {
                    let ptr = validate_inbounds::<P>(cx.memory(), &storage.args)?;
                    P::load(cx, ty, &cx.memory()[ptr..][..P::SIZE32])
                }
            }
        }

        /// Lowers `ret`, either into the flat result slots or into memory at
        /// the validated return pointer.
        unsafe fn lower_results<T>(
            &mut self,
            cx: &mut LowerContext<'_, T>,
            ty: InterfaceType,
            ret: R,
        ) -> Result<()> {
            match self {
                Storage::Direct(storage) => ret.lower(cx, ty, map_maybe_uninit!(storage.ret)),
                Storage::ParamsIndirect(storage) => {
                    ret.lower(cx, ty, map_maybe_uninit!(storage.ret))
                }
                Storage::ResultsIndirect(storage) => {
                    let ptr = validate_inbounds::<R>(cx.as_slice_mut(), &storage.retptr)?;
                    ret.store(cx, ty, ptr)
                }
                Storage::Indirect(storage) => {
                    let ptr = validate_inbounds::<R>(cx.as_slice_mut(), &storage.retptr)?;
                    ret.store(cx, ty, ptr)
                }
            }
        }
    }
}
289
290fn validate_inbounds<T: ComponentType>(memory: &[u8], ptr: &ValRaw) -> Result<usize> {
291 let ptr = usize::try_from(ptr.get_u32())?;
293 if ptr % usize::try_from(T::ALIGN32)? != 0 {
294 bail!("pointer not aligned");
295 }
296 let end = match ptr.checked_add(T::SIZE32) {
297 Some(n) => n,
298 None => bail!("pointer size overflow"),
299 };
300 if end > memory.len() {
301 bail!("pointer out of bounds")
302 }
303 Ok(ptr)
304}
305
/// Recovers the component instance, its type information, and its store
/// from the raw `cx` pointer. It then runs `func` inside
/// `catch_unwind_and_record_trap`, bracketed by the `CallingHost` and
/// `ReturningFromHost` call hooks.
///
/// Returns whatever `catch_unwind_and_record_trap` yields for the closure's
/// result. NOTE(review): presumably `true` on success and `false` once an
/// error or panic has been recorded as a trap; confirm against its
/// definition.
///
/// # Safety
///
/// `cx` must point to a live `VMComponentContext`. The instance's raw store
/// pointer is cast to the store for `T` without any check, so it must really
/// belong to a store with data type `T`.
unsafe fn call_host_and_handle_result<T>(
    cx: NonNull<VMOpaqueContext>,
    func: impl FnOnce(
        *mut ComponentInstance,
        &Arc<ComponentTypes>,
        StoreContextMut<'_, T>,
    ) -> Result<()>,
) -> bool {
    let cx = VMComponentContext::from_opaque(cx);
    let instance = cx.as_ref().instance();
    let types = (*instance).component_types();
    let raw_store = (*instance).store();
    // Reinterpret the type-erased store pointer as the store for `T`.
    let mut store = StoreContextMut(&mut *raw_store.cast());

    crate::runtime::vm::catch_unwind_and_record_trap(|| {
        // If the entry hook fails, `func` is never called.
        store.0.call_hook(CallHook::CallingHost)?;
        let res = func(instance, types, store.as_context_mut());
        // The return hook runs even if `func` failed; if the hook itself
        // fails, its error is returned instead of `res`.
        store.0.call_hook(CallHook::ReturningFromHost)?;
        res
    })
}
327
328unsafe fn call_host_dynamic<T, F>(
329 instance: *mut ComponentInstance,
330 types: &Arc<ComponentTypes>,
331 mut store: StoreContextMut<'_, T>,
332 ty: TypeFuncIndex,
333 mut flags: InstanceFlags,
334 memory: *mut VMMemoryDefinition,
335 realloc: *mut VMFuncRef,
336 string_encoding: StringEncoding,
337 async_: bool,
338 storage: &mut [MaybeUninit<ValRaw>],
339 closure: F,
340) -> Result<()>
341where
342 F: FnOnce(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()>,
343{
344 if async_ {
345 todo!()
346 }
347
348 let options = Options::new(
349 store.0.id(),
350 NonNull::new(memory),
351 NonNull::new(realloc),
352 string_encoding,
353 );
354
355 if !flags.may_leave() {
359 bail!("cannot leave component instance");
360 }
361
362 let args;
363 let ret_index;
364
365 let func_ty = &types[ty];
366 let param_tys = &types[func_ty.params];
367 let result_tys = &types[func_ty.results];
368 let mut cx = LiftContext::new(store.0, &options, types, instance);
369 cx.enter_call();
370 if let Some(param_count) = param_tys.abi.flat_count(MAX_FLAT_PARAMS) {
371 let mut iter =
373 mem::transmute::<&[MaybeUninit<ValRaw>], &[ValRaw]>(&storage[..param_count]).iter();
374 args = param_tys
375 .types
376 .iter()
377 .map(|ty| Val::lift(&mut cx, *ty, &mut iter))
378 .collect::<Result<Box<[_]>>>()?;
379 ret_index = param_count;
380 assert!(iter.next().is_none());
381 } else {
382 let mut offset =
383 validate_inbounds_dynamic(¶m_tys.abi, cx.memory(), storage[0].assume_init_ref())?;
384 args = param_tys
385 .types
386 .iter()
387 .map(|ty| {
388 let abi = types.canonical_abi(ty);
389 let size = usize::try_from(abi.size32).unwrap();
390 let memory = &cx.memory()[abi.next_field32_size(&mut offset)..][..size];
391 Val::load(&mut cx, *ty, memory)
392 })
393 .collect::<Result<Box<[_]>>>()?;
394 ret_index = 1;
395 };
396
397 let mut result_vals = Vec::with_capacity(result_tys.types.len());
398 for _ in result_tys.types.iter() {
399 result_vals.push(Val::Bool(false));
400 }
401 closure(store.as_context_mut(), &args, &mut result_vals)?;
402 flags.set_may_leave(false);
403
404 let mut cx = LowerContext::new(store, &options, types, instance);
405 if let Some(cnt) = result_tys.abi.flat_count(MAX_FLAT_RESULTS) {
406 let mut dst = storage[..cnt].iter_mut();
407 for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) {
408 val.lower(&mut cx, *ty, &mut dst)?;
409 }
410 assert!(dst.next().is_none());
411 } else {
412 let ret_ptr = storage[ret_index].assume_init_ref();
413 let mut ptr = validate_inbounds_dynamic(&result_tys.abi, cx.as_slice_mut(), ret_ptr)?;
414 for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) {
415 let offset = types.canonical_abi(ty).next_field32_size(&mut ptr);
416 val.store(&mut cx, *ty, offset)?;
417 }
418 }
419
420 flags.set_may_leave(true);
421
422 cx.exit_call()?;
423
424 return Ok(());
425}
426
427fn validate_inbounds_dynamic(abi: &CanonicalAbiInfo, memory: &[u8], ptr: &ValRaw) -> Result<usize> {
428 let ptr = usize::try_from(ptr.get_u32())?;
430 if ptr % usize::try_from(abi.align32)? != 0 {
431 bail!("pointer not aligned");
432 }
433 let end = match ptr.checked_add(usize::try_from(abi.size32).unwrap()) {
434 Some(n) => n,
435 None => bail!("pointer size overflow"),
436 };
437 if end > memory.len() {
438 bail!("pointer out of bounds")
439 }
440 Ok(ptr)
441}
442
443extern "C" fn dynamic_entrypoint<T, F>(
444 cx: NonNull<VMOpaqueContext>,
445 data: NonNull<u8>,
446 ty: u32,
447 _caller_instance: u32,
448 flags: NonNull<VMGlobalDefinition>,
449 memory: *mut VMMemoryDefinition,
450 realloc: *mut VMFuncRef,
451 string_encoding: u8,
452 async_: u8,
453 storage: NonNull<MaybeUninit<ValRaw>>,
454 storage_len: usize,
455) -> bool
456where
457 F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static,
458{
459 let data = data.as_ptr() as *const F;
460 unsafe {
461 call_host_and_handle_result(cx, |instance, types, store| {
462 call_host_dynamic::<T, _>(
463 instance,
464 types,
465 store,
466 TypeFuncIndex::from_u32(ty),
467 InstanceFlags::from_raw(flags),
468 memory,
469 realloc,
470 StringEncoding::from_u8(string_encoding).unwrap(),
471 async_ != 0,
472 NonNull::slice_from_raw_parts(storage, storage_len).as_mut(),
473 |store, params, results| (*data)(store, params, results),
474 )
475 })
476 }
477}