1use crate::component::func::{LiftContext, LowerContext, Options};
2use crate::component::matching::InstanceType;
3use crate::component::storage::slice_to_storage_mut;
4use crate::component::{ComponentNamedList, ComponentType, Instance, Lift, Lower, Val};
5use crate::prelude::*;
6use crate::runtime::vm::component::{
7 ComponentInstance, InstanceFlags, VMComponentContext, VMLowering, VMLoweringCallee,
8};
9use crate::runtime::vm::{VMFuncRef, VMGlobalDefinition, VMMemoryDefinition, VMOpaqueContext};
10use crate::{AsContextMut, CallHook, StoreContextMut, ValRaw};
11use alloc::sync::Arc;
12use core::any::Any;
13use core::mem::{self, MaybeUninit};
14use core::ptr::NonNull;
15use wasmtime_environ::component::{
16 CanonicalAbiInfo, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, StringEncoding,
17 TypeFuncIndex,
18};
19
/// A host-defined function that can be imported into a component.
///
/// Stored type-erased: the statically-typed closure (or dynamic `Val`-based
/// closure) lives behind `func`, and `entrypoint` is the matching raw
/// trampoline that knows how to cast it back.
pub struct HostFunc {
    /// Raw `extern "C"` callee invoked by lowered component code; paired with
    /// `func` below, which it receives as an untyped data pointer (see
    /// `HostFunc::lowering`).
    entrypoint: VMLoweringCallee,
    /// Validates at instantiation time that the component's function type is
    /// compatible with this host function's parameter/result types.
    typecheck: Box<dyn (Fn(TypeFuncIndex, &InstanceType<'_>) -> Result<()>) + Send + Sync>,
    /// The type-erased host closure; `entrypoint` casts the raw pointer to
    /// this allocation back to the concrete closure type.
    func: Box<dyn Any + Send + Sync>,
}
25
26impl core::fmt::Debug for HostFunc {
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 f.debug_struct("HostFunc").finish_non_exhaustive()
29 }
30}
31
impl HostFunc {
    /// Creates a `HostFunc` from a statically-typed closure whose parameters
    /// `P` and results `R` are component-model value lists.
    pub(crate) fn from_closure<T, F, P, R>(func: F) -> Arc<HostFunc>
    where
        F: Fn(StoreContextMut<T>, P) -> Result<R> + Send + Sync + 'static,
        P: ComponentNamedList + Lift + 'static,
        R: ComponentNamedList + Lower + 'static,
        T: 'static,
    {
        // Monomorphize the raw trampoline for this exact `T`/`F`/`P`/`R`
        // combination so that `entrypoint` can soundly cast its untyped
        // `data` pointer back to `F`.
        let entrypoint = Self::entrypoint::<T, F, P, R>;
        Arc::new(HostFunc {
            entrypoint,
            typecheck: Box::new(typecheck::<P, R>),
            func: Box::new(func),
        })
    }

    /// Raw `extern "C"` entry point invoked by lowered component code for a
    /// statically-typed host function.
    ///
    /// All arguments arrive as raw pointers/integers from the runtime:
    /// `data` points at the `F` closure boxed in `HostFunc::func`, `storage`
    /// is the flat parameter/result area of `storage_len` `ValRaw` slots, and
    /// the remaining arguments are canonical-ABI call options. The boolean
    /// result comes from `call_host_and_handle_result` — presumably a
    /// success/trap indicator recorded by the runtime; confirm against
    /// `catch_unwind_and_record_trap`.
    extern "C" fn entrypoint<T, F, P, R>(
        cx: NonNull<VMOpaqueContext>,
        data: NonNull<u8>,
        ty: u32,
        _caller_instance: u32,
        flags: NonNull<VMGlobalDefinition>,
        memory: *mut VMMemoryDefinition,
        realloc: *mut VMFuncRef,
        string_encoding: u8,
        async_: u8,
        storage: NonNull<MaybeUninit<ValRaw>>,
        storage_len: usize,
    ) -> bool
    where
        F: Fn(StoreContextMut<T>, P) -> Result<R>,
        P: ComponentNamedList + Lift + 'static,
        R: ComponentNamedList + Lower + 'static,
        T: 'static,
    {
        // `data` was produced from `&*self.func` in `lowering()`, so it
        // points at a live `F` for the duration of this call.
        let data = data.as_ptr() as *const F;
        unsafe {
            call_host_and_handle_result::<T>(cx, |store, instance| {
                call_host(
                    store,
                    instance,
                    TypeFuncIndex::from_u32(ty),
                    InstanceFlags::from_raw(flags),
                    memory,
                    realloc,
                    // NOTE(review): unwrap assumes the runtime only ever
                    // passes a valid encoding discriminant — TODO confirm.
                    StringEncoding::from_u8(string_encoding).unwrap(),
                    async_ != 0,
                    NonNull::slice_from_raw_parts(storage, storage_len).as_mut(),
                    |store, args| (*data)(store, args),
                )
            })
        }
    }

    /// Creates a `HostFunc` from a dynamically-typed closure operating on
    /// `Val` slices rather than static Rust types.
    pub(crate) fn new_dynamic<T, F>(func: F) -> Arc<HostFunc>
    where
        F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static,
        T: 'static,
    {
        Arc::new(HostFunc {
            entrypoint: dynamic_entrypoint::<T, F>,
            // Dynamic functions accept any component function type; values
            // are type-checked as they're lifted/lowered at call time.
            typecheck: Box::new(move |_expected_index, _expected_types| Ok(())),
            func: Box::new(func),
        })
    }

    /// Runs this host function's stored typecheck against the component
    /// function type `ty`.
    pub fn typecheck(&self, ty: TypeFuncIndex, types: &InstanceType<'_>) -> Result<()> {
        (self.typecheck)(ty, types)
    }

    /// Produces the raw `VMLowering` pair (callee + data pointer) handed to
    /// the runtime for this host function.
    pub fn lowering(&self) -> VMLowering {
        // Untyped pointer to the boxed closure; `entrypoint` casts it back.
        let data = NonNull::from(&*self.func).cast();
        VMLowering {
            // Function pointers are never null, so this unwrap cannot fail.
            callee: NonNull::new(self.entrypoint as *mut _).unwrap().into(),
            data: data.into(),
        }
    }
}
113
114fn typecheck<P, R>(ty: TypeFuncIndex, types: &InstanceType<'_>) -> Result<()>
115where
116 P: ComponentNamedList + Lift,
117 R: ComponentNamedList + Lower,
118{
119 let ty = &types.types[ty];
120 P::typecheck(&InterfaceType::Tuple(ty.params), types)
121 .context("type mismatch with parameters")?;
122 R::typecheck(&InterfaceType::Tuple(ty.results), types).context("type mismatch with results")?;
123 Ok(())
124}
125
/// Invokes a statically-typed host `closure` on behalf of a component call.
///
/// Lifts `Params` out of `storage` (or, for large types, out of linear
/// memory via a pointer stored there), runs the closure, then lowers the
/// `Return` value back the same way. `storage` is the flat argument/result
/// area provided by the lowered trampoline.
///
/// # Safety
/// `memory`, `realloc`, and `storage` must be the valid pointers supplied by
/// the runtime for this call, and `storage` must be laid out according to
/// the canonical ABI for `Params`/`Return` (see `Storage::new_sync`).
unsafe fn call_host<T, Params, Return, F>(
    mut store: StoreContextMut<'_, T>,
    instance: Instance,
    ty: TypeFuncIndex,
    mut flags: InstanceFlags,
    memory: *mut VMMemoryDefinition,
    realloc: *mut VMFuncRef,
    string_encoding: StringEncoding,
    async_: bool,
    storage: &mut [MaybeUninit<ValRaw>],
    closure: F,
) -> Result<()>
where
    Params: Lift,
    Return: Lower,
    F: FnOnce(StoreContextMut<'_, T>, Params) -> Result<Return>,
{
    // Async lowering is not implemented on this path yet.
    if async_ {
        todo!()
    }

    let options = Options::new(
        store.0.id(),
        NonNull::new(memory),
        NonNull::new(realloc),
        string_encoding,
    );

    // The instance's may_leave flag must be set for a call to exit the
    // component; otherwise this call is not permitted right now.
    if !flags.may_leave() {
        bail!("cannot leave component instance");
    }

    let types = instance.id().get(store.0).component().types().clone();
    let ty = &types[ty];
    let param_tys = InterfaceType::Tuple(ty.params);
    let result_tys = InterfaceType::Tuple(ty.results);

    let mut storage = Storage::<'_, Params, Return>::new_sync(storage);
    let mut lift = LiftContext::new(store.0, &options, &types, instance);
    lift.enter_call();
    let params = storage.lift_params(&mut lift, param_tys)?;

    let ret = closure(store.as_context_mut(), params)?;

    // Clear may_leave while lowering results so a realloc callback can't
    // re-enter/leave the instance mid-lower; restore it afterwards.
    // NOTE(review): presumably mirrors the canonical ABI's lowering rules —
    // confirm against the component-model spec.
    flags.set_may_leave(false);
    let mut lower = LowerContext::new(store, &options, &types, instance);
    storage.lower_results(&mut lower, result_tys, ret)?;
    flags.set_may_leave(true);
    lower.exit_call()?;

    return Ok(());

    // View over `storage` classifying how params (P) and results (R) travel:
    // `d` = direct (flattened into `ValRaw` slots), `i` = indirect (a pointer
    // into linear memory lives in the slot instead).
    enum Storage<'a, P: ComponentType, R: ComponentType> {
        // Params direct, results direct: params and results overlay the same
        // slots, hence a union (params are read before results are written).
        PdRd(&'a mut Union<P::Lower, MaybeUninit<R::Lower>>),
        // Params direct, results indirect: flat params followed by the
        // guest-provided result pointer.
        PdRi(&'a Pair<P::Lower, ValRaw>),
        // Params indirect, results direct: param pointer overlaid with the
        // flat result slots.
        PiRd(&'a mut Union<ValRaw, MaybeUninit<R::Lower>>),
        // Params indirect, results indirect: param pointer then result pointer.
        PiRi(&'a Pair<ValRaw, ValRaw>),
    }

    // `repr(C)` so field order/layout matches the ABI layout of `storage`.
    #[repr(C)]
    #[derive(Copy, Clone)]
    struct Pair<T, U> {
        a: T,
        b: U,
    }

    #[repr(C)]
    union Union<T: Copy, U: Copy> {
        a: T,
        b: U,
    }

    // Where parameters are lifted from.
    enum Src<'a, T> {
        // Flattened directly in the `ValRaw` slots.
        Direct(&'a T),
        // A pointer into linear memory stored in a slot.
        Indirect(&'a ValRaw),
    }

    // Where results are lowered to.
    enum Dst<'a, T> {
        // Written directly into the `ValRaw` slots.
        Direct(&'a mut MaybeUninit<T>),
        // Written to linear memory at a guest-provided pointer.
        Indirect(&'a ValRaw),
    }

    impl<P, R> Storage<'_, P, R>
    where
        P: ComponentType + Lift,
        R: ComponentType + Lower,
    {
        // Classifies `storage` per the canonical ABI's MAX_FLAT_* limits.
        //
        // SAFETY (assumed): the lowered trampoline sized `storage` to hold
        // whichever layout is selected here, and the guest initialized the
        // param portion — TODO confirm against the trampoline generator.
        unsafe fn new_sync(storage: &mut [MaybeUninit<ValRaw>]) -> Storage<'_, P, R> {
            unsafe {
                if P::flatten_count() <= MAX_FLAT_PARAMS {
                    if R::flatten_count() <= MAX_FLAT_RESULTS {
                        Storage::PdRd(slice_to_storage_mut(storage).assume_init_mut())
                    } else {
                        Storage::PdRi(slice_to_storage_mut(storage).assume_init_ref())
                    }
                } else {
                    if R::flatten_count() <= MAX_FLAT_RESULTS {
                        Storage::PiRd(slice_to_storage_mut(storage).assume_init_mut())
                    } else {
                        Storage::PiRi(slice_to_storage_mut(storage).assume_init_ref())
                    }
                }
            }
        }

        // Lifts `P` either from the flat slots or, after bounds/alignment
        // validation, from linear memory.
        fn lift_params(&self, cx: &mut LiftContext<'_>, ty: InterfaceType) -> Result<P> {
            match self.lift_src() {
                Src::Direct(storage) => P::linear_lift_from_flat(cx, ty, storage),
                Src::Indirect(ptr) => {
                    let ptr = validate_inbounds::<P>(cx.memory(), ptr)?;
                    P::linear_lift_from_memory(cx, ty, &cx.memory()[ptr..][..P::SIZE32])
                }
            }
        }

        fn lift_src(&self) -> Src<'_, P::Lower> {
            match self {
                // SAFETY (assumed): the `a` (params) side of the union is the
                // portion the guest initialized before the call.
                Storage::PdRd(storage) => unsafe { Src::Direct(&storage.a) },
                Storage::PdRi(storage) => Src::Direct(&storage.a),
                Storage::PiRd(storage) => unsafe { Src::Indirect(&storage.a) },
                Storage::PiRi(storage) => Src::Indirect(&storage.a),
            }
        }

        // Lowers `ret` either into the flat slots or, after validation, into
        // linear memory at the guest-provided result pointer.
        fn lower_results<T>(
            &mut self,
            cx: &mut LowerContext<'_, T>,
            ty: InterfaceType,
            ret: R,
        ) -> Result<()> {
            match self.lower_dst() {
                Dst::Direct(storage) => ret.linear_lower_to_flat(cx, ty, storage),
                Dst::Indirect(ptr) => {
                    let ptr = validate_inbounds::<R>(cx.as_slice_mut(), ptr)?;
                    ret.linear_lower_to_memory(cx, ty, ptr)
                }
            }
        }

        fn lower_dst(&mut self) -> Dst<'_, R::Lower> {
            match self {
                // SAFETY: taking `&mut` to the `b` (results) side of the
                // union as an uninitialized destination; params in `a` are no
                // longer needed by the time results are lowered.
                Storage::PdRd(storage) => unsafe { Dst::Direct(&mut storage.b) },
                Storage::PiRd(storage) => unsafe { Dst::Direct(&mut storage.b) },
                Storage::PdRi(storage) => Dst::Indirect(&storage.b),
                Storage::PiRi(storage) => Dst::Indirect(&storage.b),
            }
        }
    }
}
423
424fn validate_inbounds<T: ComponentType>(memory: &[u8], ptr: &ValRaw) -> Result<usize> {
425 let ptr = usize::try_from(ptr.get_u32())?;
427 if ptr % usize::try_from(T::ALIGN32)? != 0 {
428 bail!("pointer not aligned");
429 }
430 let end = match ptr.checked_add(T::SIZE32) {
431 Some(n) => n,
432 None => bail!("pointer size overflow"),
433 };
434 if end > memory.len() {
435 bail!("pointer out of bounds")
436 }
437 Ok(ptr)
438}
439
/// Shared tail for both entrypoints: recovers the store and instance from
/// the raw component context, fires the host-call hooks, and routes panics
/// and errors through the runtime's trap recording.
///
/// # Safety
/// `cx` must be a valid `VMComponentContext` for a live component instance
/// whose store's data type is `T`.
unsafe fn call_host_and_handle_result<T>(
    cx: NonNull<VMOpaqueContext>,
    func: impl FnOnce(StoreContextMut<'_, T>, Instance) -> Result<()>,
) -> bool
where
    T: 'static,
{
    let cx = VMComponentContext::from_opaque(cx);
    ComponentInstance::from_vmctx(cx, |store, instance| {
        // Relies on the caller's `T` matching the store's actual data type;
        // guaranteed by the monomorphized entrypoint that reached here.
        let mut store = store.unchecked_context_mut();

        // Panics and `Err` results are captured here and recorded as traps;
        // the returned bool reflects that outcome (see runtime::vm).
        crate::runtime::vm::catch_unwind_and_record_trap(|| {
            store.0.call_hook(CallHook::CallingHost)?;
            let res = func(store.as_context_mut(), instance);
            store.0.call_hook(CallHook::ReturningFromHost)?;
            res
        })
    })
}
459
/// Invokes a dynamically-typed host `closure` (operating on `Val` slices)
/// on behalf of a component call.
///
/// Mirrors `call_host` but lifts every parameter into a `Val` and lowers
/// every result from a `Val`, driven by the function type's runtime type
/// information rather than static Rust types.
///
/// # Safety
/// `memory`, `realloc`, and `storage` must be the valid pointers supplied by
/// the runtime for this call, with `storage` laid out per the canonical ABI
/// for this function type.
unsafe fn call_host_dynamic<T, F>(
    mut store: StoreContextMut<'_, T>,
    instance: Instance,
    ty: TypeFuncIndex,
    mut flags: InstanceFlags,
    memory: *mut VMMemoryDefinition,
    realloc: *mut VMFuncRef,
    string_encoding: StringEncoding,
    async_: bool,
    storage: &mut [MaybeUninit<ValRaw>],
    closure: F,
) -> Result<()>
where
    F: FnOnce(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()>,
    T: 'static,
{
    // Async lowering is not implemented on this path yet.
    if async_ {
        todo!()
    }

    let options = Options::new(
        store.0.id(),
        NonNull::new(memory),
        NonNull::new(realloc),
        string_encoding,
    );

    // The instance's may_leave flag must be set for a call to exit the
    // component; otherwise this call is not permitted right now.
    if !flags.may_leave() {
        bail!("cannot leave component instance");
    }

    let args;
    // Index of the `ValRaw` slot holding the result pointer, used only when
    // results are lowered indirectly (too many to flatten).
    let ret_index;

    let types = instance.id().get(store.0).component().types().clone();
    let func_ty = &types[ty];
    let param_tys = &types[func_ty.params];
    let result_tys = &types[func_ty.results];
    let mut cx = LiftContext::new(store.0, &options, &types, instance);
    cx.enter_call();
    if let Some(param_count) = param_tys.abi.flat_count(MAX_FLAT_PARAMS) {
        // Params fit flattened in the first `param_count` slots.
        // SAFETY (assumed): those slots were initialized by the guest before
        // the call — TODO confirm against the lowered trampoline.
        let mut iter =
            mem::transmute::<&[MaybeUninit<ValRaw>], &[ValRaw]>(&storage[..param_count]).iter();
        args = param_tys
            .types
            .iter()
            .map(|ty| Val::lift(&mut cx, *ty, &mut iter))
            .collect::<Result<Box<[_]>>>()?;
        // A result pointer, if needed, sits right after the flat params.
        ret_index = param_count;
        // Every flat slot must have been consumed by the lifts above.
        assert!(iter.next().is_none());
    } else {
        // Params are spilled to linear memory; slot 0 holds their pointer.
        let mut offset =
            validate_inbounds_dynamic(&param_tys.abi, cx.memory(), storage[0].assume_init_ref())?;
        args = param_tys
            .types
            .iter()
            .map(|ty| {
                let abi = types.canonical_abi(ty);
                let size = usize::try_from(abi.size32).unwrap();
                // Advance `offset` field-by-field per the canonical layout.
                let memory = &cx.memory()[abi.next_field32_size(&mut offset)..][..size];
                Val::load(&mut cx, *ty, memory)
            })
            .collect::<Result<Box<[_]>>>()?;
        // Slot 0 held the param pointer, so a result pointer would be slot 1.
        ret_index = 1;
    };

    // Placeholder results for the closure to overwrite; `Bool(false)` is
    // just an arbitrary default.
    let mut result_vals = Vec::with_capacity(result_tys.types.len());
    for _ in result_tys.types.iter() {
        result_vals.push(Val::Bool(false));
    }
    closure(store.as_context_mut(), &args, &mut result_vals)?;
    // Clear may_leave while lowering results (realloc may run); restored
    // below once lowering completes.
    flags.set_may_leave(false);

    let mut cx = LowerContext::new(store, &options, &types, instance);
    if let Some(cnt) = result_tys.abi.flat_count(MAX_FLAT_RESULTS) {
        // Results fit flattened into the first `cnt` slots.
        let mut dst = storage[..cnt].iter_mut();
        for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) {
            val.lower(&mut cx, *ty, &mut dst)?;
        }
        // All flat result slots must have been filled.
        assert!(dst.next().is_none());
    } else {
        // Results are spilled to linear memory at the guest-provided pointer.
        let ret_ptr = storage[ret_index].assume_init_ref();
        let mut ptr = validate_inbounds_dynamic(&result_tys.abi, cx.as_slice_mut(), ret_ptr)?;
        for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) {
            let offset = types.canonical_abi(ty).next_field32_size(&mut ptr);
            val.store(&mut cx, *ty, offset)?;
        }
    }

    flags.set_may_leave(true);

    cx.exit_call()?;

    return Ok(());
}
559
560fn validate_inbounds_dynamic(abi: &CanonicalAbiInfo, memory: &[u8], ptr: &ValRaw) -> Result<usize> {
561 let ptr = usize::try_from(ptr.get_u32())?;
563 if ptr % usize::try_from(abi.align32)? != 0 {
564 bail!("pointer not aligned");
565 }
566 let end = match ptr.checked_add(usize::try_from(abi.size32).unwrap()) {
567 Some(n) => n,
568 None => bail!("pointer size overflow"),
569 };
570 if end > memory.len() {
571 bail!("pointer out of bounds")
572 }
573 Ok(ptr)
574}
575
/// Raw `extern "C"` entry point invoked by lowered component code for a
/// dynamically-typed (`Val`-based) host function.
///
/// Mirrors `HostFunc::entrypoint`: `data` points at the `F` closure boxed in
/// `HostFunc::func`, `storage` is the flat argument/result area of
/// `storage_len` `ValRaw` slots, and the remaining arguments are
/// canonical-ABI call options forwarded to `call_host_dynamic`.
extern "C" fn dynamic_entrypoint<T, F>(
    cx: NonNull<VMOpaqueContext>,
    data: NonNull<u8>,
    ty: u32,
    _caller_instance: u32,
    flags: NonNull<VMGlobalDefinition>,
    memory: *mut VMMemoryDefinition,
    realloc: *mut VMFuncRef,
    string_encoding: u8,
    async_: u8,
    storage: NonNull<MaybeUninit<ValRaw>>,
    storage_len: usize,
) -> bool
where
    F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static,
    T: 'static,
{
    // `data` was produced from `&*self.func` in `HostFunc::lowering`, so it
    // points at a live `F` for the duration of this call.
    let data = data.as_ptr() as *const F;
    unsafe {
        call_host_and_handle_result(cx, |store, instance| {
            call_host_dynamic::<T, _>(
                store,
                instance,
                TypeFuncIndex::from_u32(ty),
                InstanceFlags::from_raw(flags),
                memory,
                realloc,
                // NOTE(review): unwrap assumes the runtime only ever passes a
                // valid encoding discriminant — TODO confirm.
                StringEncoding::from_u8(string_encoding).unwrap(),
                async_ != 0,
                NonNull::slice_from_raw_parts(storage, storage_len).as_mut(),
                |store, params, results| (*data)(store, params, results),
            )
        })
    }
}