wasmtime_test_util/
component_fuzz.rs

1//! This module generates test cases for the Wasmtime component model function APIs,
2//! e.g. `wasmtime::component::func::Func` and `TypedFunc`.
3//!
4//! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a
5//! result, and a component which exports a function and imports a function.  The exported function forwards its
6//! parameters to the imported one and forwards the result back to the caller.  This serves to exercise Wasmtime's
7//! lifting and lowering code and verify the values remain intact during both processes.
8
9use arbitrary::{Arbitrary, Unstructured};
10use indexmap::IndexSet;
11use proc_macro2::{Ident, TokenStream};
12use quote::{ToTokens, format_ident, quote};
13use std::borrow::Cow;
14use std::fmt::{self, Debug, Write};
15use std::hash::{Hash, Hasher};
16use std::iter;
17use std::ops::Deref;
18use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE};
19
// Canonical ABI flattening limits. Once a function's flat parameters or
// results exceed these counts they are passed indirectly through linear
// memory instead.

/// Flat parameter limit for synchronously lifted/lowered functions.
const MAX_FLAT_PARAMS: usize = 16;
/// Flat parameter limit for async-lowered imports.
const MAX_FLAT_ASYNC_PARAMS: usize = 4;
/// Flat result limit before a return pointer is used instead.
const MAX_FLAT_RESULTS: usize = 1;

/// The name of the imported host function which the generated component will call
pub const IMPORT_FUNCTION: &str = "echo-import";

/// The name of the exported guest function which the host should call
pub const EXPORT_FUNCTION: &str = "echo-export";

/// Wasmtime permits a type nesting depth of up to 100, so stay one below it.
pub const MAX_TYPE_DEPTH: u32 = 100 - 1;
32
/// `writeln!` for sinks that cannot fail (such as `String`): the
/// `fmt::Result` is unwrapped rather than returned.
macro_rules! uwriteln {
    ($($arg:tt)*) => {
        ::std::writeln!($($arg)*).unwrap()
    };
}

/// `write!` counterpart of `uwriteln!`, likewise unwrapping the result.
macro_rules! uwrite {
    ($($arg:tt)*) => {
        ::std::write!($($arg)*).unwrap()
    };
}
44
/// A core WebAssembly value type, as it appears in the flattened (lowered)
/// representation of a component model type.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CoreType {
    I32,
    I64,
    F32,
    F64,
}

impl CoreType {
    /// Merges two flat types that share a slot in a variant's flattened
    /// payload. This is the `join` operation specified in [the canonical
    /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening).
    ///
    /// Equal types join to themselves, `i32` with `f32` joins to `i32`, and any
    /// other mismatched pair widens to `i64`.
    fn join(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        match (self, other) {
            (Self::I32, Self::F32) | (Self::F32, Self::I32) => Self::I32,
            _ => Self::I64,
        }
    }
}

impl fmt::Display for CoreType {
    /// Writes the WebAssembly text-format name of this type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.write_str(name)
    }
}
76
77/// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than
78/// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl
79#[derive(Debug, Clone)]
80pub struct VecInRange<T, const L: u32, const H: u32>(Vec<T>);
81
82impl<T, const L: u32, const H: u32> VecInRange<T, L, H> {
83    fn new<'a>(
84        input: &mut Unstructured<'a>,
85        fuel: &mut u32,
86        generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>,
87    ) -> arbitrary::Result<Self> {
88        let mut ret = Vec::new();
89        input.arbitrary_loop(Some(L), Some(H), |input| {
90            if *fuel > 0 {
91                *fuel = *fuel - 1;
92                ret.push(generate(input, fuel)?);
93                Ok(std::ops::ControlFlow::Continue(()))
94            } else {
95                Ok(std::ops::ControlFlow::Break(()))
96            }
97        })?;
98        Ok(Self(ret))
99    }
100}
101
102impl<T, const L: u32, const H: u32> Deref for VecInRange<T, L, H> {
103    type Target = [T];
104
105    fn deref(&self) -> &[T] {
106        self.0.deref()
107    }
108}
109
110/// Represents a component model interface type
111#[expect(missing_docs, reason = "self-describing")]
112#[derive(Debug, Clone)]
113pub enum Type {
114    Bool,
115    S8,
116    U8,
117    S16,
118    U16,
119    S32,
120    U32,
121    S64,
122    U64,
123    Float32,
124    Float64,
125    Char,
126    String,
127    List(Box<Type>),
128
129    // Give records the ability to generate a generous amount of fields but
130    // don't let the fuzzer go too wild since `wasmparser`'s validator currently
131    // has hard limits in the 1000-ish range on the number of fields a record
132    // may contain.
133    Record(VecInRange<Type, 1, 200>),
134
135    // Tuples can only have up to 16 type parameters in wasmtime right now for
136    // the static API, but the standard library only supports `Debug` up to 11
137    // elements, so compromise at an even 10.
138    Tuple(VecInRange<Type, 1, 10>),
139
140    // Like records, allow a good number of variants, but variants require at
141    // least one case.
142    Variant(VecInRange<Option<Type>, 1, 200>),
143    Enum(u32),
144
145    Option(Box<Type>),
146    Result {
147        ok: Option<Box<Type>>,
148        err: Option<Box<Type>>,
149    },
150
151    Flags(u32),
152}
153
154impl Type {
155    pub fn generate(
156        u: &mut Unstructured<'_>,
157        depth: u32,
158        fuel: &mut u32,
159    ) -> arbitrary::Result<Type> {
160        *fuel = fuel.saturating_sub(1);
161        let max = if depth == 0 || *fuel == 0 { 12 } else { 20 };
162        Ok(match u.int_in_range(0..=max)? {
163            0 => Type::Bool,
164            1 => Type::S8,
165            2 => Type::U8,
166            3 => Type::S16,
167            4 => Type::U16,
168            5 => Type::S32,
169            6 => Type::U32,
170            7 => Type::S64,
171            8 => Type::U64,
172            9 => Type::Float32,
173            10 => Type::Float64,
174            11 => Type::Char,
175            12 => Type::String,
176            // ^-- if you add something here update the `depth == 0` case above
177            13 => Type::List(Box::new(Type::generate(u, depth - 1, fuel)?)),
178            14 => Type::Record(Type::generate_list(u, depth - 1, fuel)?),
179            15 => Type::Tuple(Type::generate_list(u, depth - 1, fuel)?),
180            16 => Type::Variant(VecInRange::new(u, fuel, |u, fuel| {
181                Type::generate_opt(u, depth - 1, fuel)
182            })?),
183            17 => {
184                let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
185                *fuel -= amt;
186                Type::Enum(amt)
187            }
188            18 => Type::Option(Box::new(Type::generate(u, depth - 1, fuel)?)),
189            19 => Type::Result {
190                ok: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
191                err: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
192            },
193            20 => {
194                let amt = u.int_in_range(1..=(*fuel).min(32))?;
195                *fuel -= amt;
196                Type::Flags(amt)
197            }
198            // ^-- if you add something here update the `depth != 0` case above
199            _ => unreachable!(),
200        })
201    }
202
203    fn generate_opt(
204        u: &mut Unstructured<'_>,
205        depth: u32,
206        fuel: &mut u32,
207    ) -> arbitrary::Result<Option<Type>> {
208        Ok(if u.arbitrary()? {
209            Some(Type::generate(u, depth, fuel)?)
210        } else {
211            None
212        })
213    }
214
215    fn generate_list<const L: u32, const H: u32>(
216        u: &mut Unstructured<'_>,
217        depth: u32,
218        fuel: &mut u32,
219    ) -> arbitrary::Result<VecInRange<Type, L, H>> {
220        VecInRange::new(u, fuel, |u, fuel| Type::generate(u, depth, fuel))
221    }
222
223    /// Generates text format wasm into `s` to store a value of this type, in
224    /// its flat representation stored in the `locals` provided, to the local
225    /// named `ptr` at the `offset` provided.
226    ///
227    /// This will register helper functions necessary in `helpers`. The
228    /// `locals` iterator will be advanced for all locals consumed by this
229    /// store operation.
230    fn store_flat<'a>(
231        &'a self,
232        s: &mut String,
233        ptr: &str,
234        offset: u32,
235        locals: &mut dyn Iterator<Item = FlatSource>,
236        helpers: &mut IndexSet<Helper<'a>>,
237    ) {
238        enum Kind {
239            Primitive(&'static str),
240            PointerPair,
241            Helper,
242        }
243        let kind = match self {
244            Type::Bool | Type::S8 | Type::U8 => Kind::Primitive("i32.store8"),
245            Type::S16 | Type::U16 => Kind::Primitive("i32.store16"),
246            Type::S32 | Type::U32 | Type::Char => Kind::Primitive("i32.store"),
247            Type::S64 | Type::U64 => Kind::Primitive("i64.store"),
248            Type::Float32 => Kind::Primitive("f32.store"),
249            Type::Float64 => Kind::Primitive("f64.store"),
250            Type::String | Type::List(_) => Kind::PointerPair,
251            Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.store8"),
252            Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.store16"),
253            Type::Enum(_) => Kind::Primitive("i32.store"),
254            Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.store8"),
255            Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.store16"),
256            Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.store"),
257            Type::Flags(_) => unreachable!(),
258            Type::Record(_)
259            | Type::Tuple(_)
260            | Type::Variant(_)
261            | Type::Option(_)
262            | Type::Result { .. } => Kind::Helper,
263        };
264
265        match kind {
266            Kind::Primitive(op) => uwriteln!(
267                s,
268                "({op} offset={offset} (local.get {ptr}) {})",
269                locals.next().unwrap()
270            ),
271            Kind::PointerPair => {
272                let abi_ptr = locals.next().unwrap();
273                let abi_len = locals.next().unwrap();
274                uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_ptr})",);
275                let offset = offset + 4;
276                uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_len})",);
277            }
278            Kind::Helper => {
279                let (index, _) = helpers.insert_full(Helper(self));
280                uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
281                for _ in 0..self.lowered().len() {
282                    let i = locals.next().unwrap();
283                    uwriteln!(s, "{i}");
284                }
285                uwriteln!(s, "call $store_helper_{index}");
286            }
287        }
288    }
289
290    /// Generates a text-format wasm function which takes a pointer and this
291    /// type's flat representation as arguments and then stores this value in
292    /// the first argument.
293    ///
294    /// This is used to store records/variants to cut down on the size of final
295    /// functions and make codegen here a bit easier.
296    fn store_flat_helper<'a>(
297        &'a self,
298        s: &mut String,
299        i: usize,
300        helpers: &mut IndexSet<Helper<'a>>,
301    ) {
302        uwrite!(s, "(func $store_helper_{i} (param i32)");
303        let lowered = self.lowered();
304        for ty in &lowered {
305            uwrite!(s, " (param {ty})");
306        }
307        s.push_str("\n");
308        let locals = (0..lowered.len() as u32).map(|i| i + 1).collect::<Vec<_>>();
309        let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
310            let mut locals = locals.iter().cloned().map(FlatSource::Local);
311            for (offset, ty) in record_field_offsets(types) {
312                ty.store_flat(s, "0", offset, &mut locals, helpers);
313            }
314            assert!(locals.next().is_none());
315        };
316        let variant = |s: &mut String,
317                       helpers: &mut IndexSet<Helper<'a>>,
318                       types: &[Option<&'a Type>]| {
319            let (size, offset) = variant_memory_info(types.iter().cloned());
320            // One extra block for out-of-bounds discriminants.
321            for _ in 0..types.len() + 1 {
322                s.push_str("block\n");
323            }
324
325            // Store the discriminant in memory, then branch on it to figure
326            // out which case we're in.
327            let store = match size {
328                DiscriminantSize::Size1 => "i32.store8",
329                DiscriminantSize::Size2 => "i32.store16",
330                DiscriminantSize::Size4 => "i32.store",
331            };
332            uwriteln!(s, "({store} (local.get 0) (local.get 1))");
333            s.push_str("local.get 1\n");
334            s.push_str("br_table");
335            for i in 0..types.len() + 1 {
336                uwrite!(s, " {i}");
337            }
338            s.push_str("\nend\n");
339
340            // Store each payload individually while converting locals from
341            // their source types to the precise type necessary for this
342            // variant.
343            for ty in types {
344                if let Some(ty) = ty {
345                    let ty_lowered = ty.lowered();
346                    let mut locals = locals[1..].iter().zip(&lowered[1..]).zip(&ty_lowered).map(
347                        |((i, from), to)| FlatSource::LocalConvert {
348                            local: *i,
349                            from: *from,
350                            to: *to,
351                        },
352                    );
353                    ty.store_flat(s, "0", offset, &mut locals, helpers);
354                }
355                s.push_str("return\n");
356                s.push_str("end\n");
357            }
358
359            // Catch-all result which is for out-of-bounds discriminants.
360            s.push_str("unreachable\n");
361        };
362        match self {
363            Type::Bool
364            | Type::S8
365            | Type::U8
366            | Type::S16
367            | Type::U16
368            | Type::S32
369            | Type::U32
370            | Type::Char
371            | Type::S64
372            | Type::U64
373            | Type::Float32
374            | Type::Float64
375            | Type::String
376            | Type::List(_)
377            | Type::Flags(_)
378            | Type::Enum(_) => unreachable!(),
379
380            Type::Record(r) => record(s, helpers, r),
381            Type::Tuple(t) => record(s, helpers, t),
382            Type::Variant(v) => variant(
383                s,
384                helpers,
385                &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
386            ),
387            Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
388            Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
389        };
390        s.push_str(")\n");
391    }
392
393    /// Same as `store_flat`, except loads the flat values from `ptr+offset`.
394    ///
395    /// Results are placed directly on the wasm stack.
396    fn load_flat<'a>(
397        &'a self,
398        s: &mut String,
399        ptr: &str,
400        offset: u32,
401        helpers: &mut IndexSet<Helper<'a>>,
402    ) {
403        enum Kind {
404            Primitive(&'static str),
405            PointerPair,
406            Helper,
407        }
408        let kind = match self {
409            Type::Bool | Type::U8 => Kind::Primitive("i32.load8_u"),
410            Type::S8 => Kind::Primitive("i32.load8_s"),
411            Type::U16 => Kind::Primitive("i32.load16_u"),
412            Type::S16 => Kind::Primitive("i32.load16_s"),
413            Type::U32 | Type::S32 | Type::Char => Kind::Primitive("i32.load"),
414            Type::U64 | Type::S64 => Kind::Primitive("i64.load"),
415            Type::Float32 => Kind::Primitive("f32.load"),
416            Type::Float64 => Kind::Primitive("f64.load"),
417            Type::String | Type::List(_) => Kind::PointerPair,
418            Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.load8_u"),
419            Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.load16_u"),
420            Type::Enum(_) => Kind::Primitive("i32.load"),
421            Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.load8_u"),
422            Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.load16_u"),
423            Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.load"),
424            Type::Flags(_) => unreachable!(),
425
426            Type::Record(_)
427            | Type::Tuple(_)
428            | Type::Variant(_)
429            | Type::Option(_)
430            | Type::Result { .. } => Kind::Helper,
431        };
432        match kind {
433            Kind::Primitive(op) => uwriteln!(s, "({op} offset={offset} (local.get {ptr}))"),
434            Kind::PointerPair => {
435                uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
436                let offset = offset + 4;
437                uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
438            }
439            Kind::Helper => {
440                let (index, _) = helpers.insert_full(Helper(self));
441                uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
442                uwriteln!(s, "call $load_helper_{index}");
443            }
444        }
445    }
446
447    /// Same as `store_flat_helper` but for loading the flat representation.
448    fn load_flat_helper<'a>(
449        &'a self,
450        s: &mut String,
451        i: usize,
452        helpers: &mut IndexSet<Helper<'a>>,
453    ) {
454        uwrite!(s, "(func $load_helper_{i} (param i32)");
455        let lowered = self.lowered();
456        for ty in &lowered {
457            uwrite!(s, " (result {ty})");
458        }
459        s.push_str("\n");
460        let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
461            for (offset, ty) in record_field_offsets(types) {
462                ty.load_flat(s, "0", offset, helpers);
463            }
464        };
465        let variant = |s: &mut String,
466                       helpers: &mut IndexSet<Helper<'a>>,
467                       types: &[Option<&'a Type>]| {
468            let (size, offset) = variant_memory_info(types.iter().cloned());
469
470            // Destination locals where the flat representation will be stored.
471            // These are automatically zero which handles unused fields too.
472            for (i, ty) in lowered.iter().enumerate() {
473                uwriteln!(s, " (local $r{i} {ty})");
474            }
475
476            // Return block each case jumps to after setting all locals.
477            s.push_str("block $r\n");
478
479            // One extra block for "out of bounds discriminant".
480            for _ in 0..types.len() + 1 {
481                s.push_str("block\n");
482            }
483
484            // Load the discriminant and branch on it, storing it in
485            // `$r0` as well which is the first flat local representation.
486            let load = match size {
487                DiscriminantSize::Size1 => "i32.load8_u",
488                DiscriminantSize::Size2 => "i32.load16",
489                DiscriminantSize::Size4 => "i32.load",
490            };
491            uwriteln!(s, "({load} (local.get 0))");
492            s.push_str("local.tee $r0\n");
493            s.push_str("br_table");
494            for i in 0..types.len() + 1 {
495                uwrite!(s, " {i}");
496            }
497            s.push_str("\nend\n");
498
499            // For each payload, which is in its own block, load payloads from
500            // memory as necessary and convert them into the final locals.
501            for ty in types {
502                if let Some(ty) = ty {
503                    let ty_lowered = ty.lowered();
504                    ty.load_flat(s, "0", offset, helpers);
505                    for (i, (from, to)) in ty_lowered.iter().zip(&lowered[1..]).enumerate().rev() {
506                        let i = i + 1;
507                        match (from, to) {
508                            (CoreType::F32, CoreType::I32) => {
509                                s.push_str("i32.reinterpret_f32\n");
510                            }
511                            (CoreType::I32, CoreType::I64) => {
512                                s.push_str("i64.extend_i32_u\n");
513                            }
514                            (CoreType::F32, CoreType::I64) => {
515                                s.push_str("i32.reinterpret_f32\n");
516                                s.push_str("i64.extend_i32_u\n");
517                            }
518                            (CoreType::F64, CoreType::I64) => {
519                                s.push_str("i64.reinterpret_f64\n");
520                            }
521                            (a, b) if a == b => {}
522                            _ => unimplemented!("convert {from:?} to {to:?}"),
523                        }
524                        uwriteln!(s, "local.set $r{i}");
525                    }
526                }
527                s.push_str("br $r\n");
528                s.push_str("end\n");
529            }
530
531            // The catch-all block for out-of-bounds discriminants.
532            s.push_str("unreachable\n");
533            s.push_str("end\n");
534            for i in 0..lowered.len() {
535                uwriteln!(s, " local.get $r{i}");
536            }
537        };
538
539        match self {
540            Type::Bool
541            | Type::S8
542            | Type::U8
543            | Type::S16
544            | Type::U16
545            | Type::S32
546            | Type::U32
547            | Type::Char
548            | Type::S64
549            | Type::U64
550            | Type::Float32
551            | Type::Float64
552            | Type::String
553            | Type::List(_)
554            | Type::Flags(_)
555            | Type::Enum(_) => unreachable!(),
556
557            Type::Record(r) => record(s, helpers, r),
558            Type::Tuple(t) => record(s, helpers, t),
559            Type::Variant(v) => variant(
560                s,
561                helpers,
562                &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
563            ),
564            Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
565            Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
566        };
567        s.push_str(")\n");
568    }
569}
570
571#[derive(Clone)]
572enum FlatSource {
573    Local(u32),
574    LocalConvert {
575        local: u32,
576        from: CoreType,
577        to: CoreType,
578    },
579}
580
581impl fmt::Display for FlatSource {
582    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
583        match self {
584            FlatSource::Local(i) => write!(f, "(local.get {i})"),
585            FlatSource::LocalConvert { local, from, to } => {
586                match (from, to) {
587                    (a, b) if a == b => write!(f, "(local.get {local})"),
588                    (CoreType::I32, CoreType::F32) => {
589                        write!(f, "(f32.reinterpret_i32 (local.get {local}))")
590                    }
591                    (CoreType::I64, CoreType::I32) => {
592                        write!(f, "(i32.wrap_i64 (local.get {local}))")
593                    }
594                    (CoreType::I64, CoreType::F64) => {
595                        write!(f, "(f64.reinterpret_i64 (local.get {local}))")
596                    }
597                    (CoreType::I64, CoreType::F32) => {
598                        write!(
599                            f,
600                            "(f32.reinterpret_i32 (i32.wrap_i64 (local.get {local})))"
601                        )
602                    }
603                    _ => unimplemented!("convert {from:?} to {to:?}"),
604                }
605                // ..
606            }
607        }
608    }
609}
610
611fn lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>) {
612    for ty in types {
613        ty.lower(vec);
614    }
615}
616
617fn lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>) {
618    vec.push(CoreType::I32);
619    let offset = vec.len();
620    for ty in types {
621        let ty = match ty {
622            Some(ty) => ty,
623            None => continue,
624        };
625        for (index, ty) in ty.lowered().iter().enumerate() {
626            let index = offset + index;
627            if index < vec.len() {
628                vec[index] = vec[index].join(*ty);
629            } else {
630                vec.push(*ty)
631            }
632        }
633    }
634}
635
636fn u32_count_from_flag_count(count: usize) -> usize {
637    match FlagsSize::from_count(count) {
638        FlagsSize::Size0 => 0,
639        FlagsSize::Size1 | FlagsSize::Size2 => 1,
640        FlagsSize::Size4Plus(n) => n.into(),
641    }
642}
643
/// Memory layout of a type's canonical-ABI representation in linear memory.
struct SizeAndAlignment {
    /// Required alignment, in bytes.
    alignment: u32,
    /// Size, in bytes.
    size: usize,
}
648
649impl Type {
650    fn lowered(&self) -> Vec<CoreType> {
651        let mut vec = Vec::new();
652        self.lower(&mut vec);
653        vec
654    }
655
656    fn lower(&self, vec: &mut Vec<CoreType>) {
657        match self {
658            Type::Bool
659            | Type::U8
660            | Type::S8
661            | Type::S16
662            | Type::U16
663            | Type::S32
664            | Type::U32
665            | Type::Char
666            | Type::Enum(_) => vec.push(CoreType::I32),
667            Type::S64 | Type::U64 => vec.push(CoreType::I64),
668            Type::Float32 => vec.push(CoreType::F32),
669            Type::Float64 => vec.push(CoreType::F64),
670            Type::String | Type::List(_) => {
671                vec.push(CoreType::I32);
672                vec.push(CoreType::I32);
673            }
674            Type::Record(types) => lower_record(types.iter(), vec),
675            Type::Tuple(types) => lower_record(types.0.iter(), vec),
676            Type::Variant(types) => lower_variant(types.0.iter().map(|t| t.as_ref()), vec),
677            Type::Option(ty) => lower_variant([None, Some(&**ty)].into_iter(), vec),
678            Type::Result { ok, err } => {
679                lower_variant([ok.as_deref(), err.as_deref()].into_iter(), vec)
680            }
681            Type::Flags(count) => vec.extend(
682                iter::repeat(CoreType::I32).take(u32_count_from_flag_count(*count as usize)),
683            ),
684        }
685    }
686
687    fn size_and_alignment(&self) -> SizeAndAlignment {
688        match self {
689            Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
690                size: 1,
691                alignment: 1,
692            },
693
694            Type::S16 | Type::U16 => SizeAndAlignment {
695                size: 2,
696                alignment: 2,
697            },
698
699            Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
700                size: 4,
701                alignment: 4,
702            },
703
704            Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
705                size: 8,
706                alignment: 8,
707            },
708
709            Type::String | Type::List(_) => SizeAndAlignment {
710                size: 8,
711                alignment: 4,
712            },
713
714            Type::Record(types) => record_size_and_alignment(types.iter()),
715
716            Type::Tuple(types) => record_size_and_alignment(types.0.iter()),
717
718            Type::Variant(types) => variant_size_and_alignment(types.0.iter().map(|t| t.as_ref())),
719
720            Type::Enum(count) => variant_size_and_alignment((0..*count).map(|_| None)),
721
722            Type::Option(ty) => variant_size_and_alignment([None, Some(&**ty)].into_iter()),
723
724            Type::Result { ok, err } => {
725                variant_size_and_alignment([ok.as_deref(), err.as_deref()].into_iter())
726            }
727
728            Type::Flags(count) => match FlagsSize::from_count(*count as usize) {
729                FlagsSize::Size0 => SizeAndAlignment {
730                    size: 0,
731                    alignment: 1,
732                },
733                FlagsSize::Size1 => SizeAndAlignment {
734                    size: 1,
735                    alignment: 1,
736                },
737                FlagsSize::Size2 => SizeAndAlignment {
738                    size: 2,
739                    alignment: 2,
740                },
741                FlagsSize::Size4Plus(n) => SizeAndAlignment {
742                    size: usize::from(n) * 4,
743                    alignment: 4,
744                },
745            },
746        }
747    }
748}
749
/// Rounds `a` up to the next multiple of `align`.
///
/// This uses a bit mask, so `align` must be a power of two.
fn align_to(a: usize, align: u32) -> usize {
    let mask = align as usize - 1;
    (a + mask) & !mask
}
754
755fn record_field_offsets<'a>(
756    types: impl IntoIterator<Item = &'a Type>,
757) -> impl Iterator<Item = (u32, &'a Type)> {
758    let mut offset = 0;
759    types.into_iter().map(move |ty| {
760        let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
761        let ret = align_to(offset, alignment);
762        offset = ret + size;
763        (ret as u32, ty)
764    })
765}
766
767fn record_size_and_alignment<'a>(types: impl IntoIterator<Item = &'a Type>) -> SizeAndAlignment {
768    let mut offset = 0;
769    let mut align = 1;
770    for ty in types {
771        let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
772        offset = align_to(offset, alignment) + size;
773        align = align.max(alignment);
774    }
775
776    SizeAndAlignment {
777        size: align_to(offset, align),
778        alignment: align,
779    }
780}
781
782fn variant_size_and_alignment<'a>(
783    types: impl ExactSizeIterator<Item = Option<&'a Type>>,
784) -> SizeAndAlignment {
785    let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
786    let mut alignment = u32::from(discriminant_size);
787    let mut size = 0;
788    for ty in types {
789        if let Some(ty) = ty {
790            let size_and_alignment = ty.size_and_alignment();
791            alignment = alignment.max(size_and_alignment.alignment);
792            size = size.max(size_and_alignment.size);
793        }
794    }
795
796    SizeAndAlignment {
797        size: align_to(
798            align_to(usize::from(discriminant_size), alignment) + size,
799            alignment,
800        ),
801        alignment,
802    }
803}
804
805fn variant_memory_info<'a>(
806    types: impl ExactSizeIterator<Item = Option<&'a Type>>,
807) -> (DiscriminantSize, u32) {
808    let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
809    let mut alignment = u32::from(discriminant_size);
810    for ty in types {
811        if let Some(ty) = ty {
812            let size_and_alignment = ty.size_and_alignment();
813            alignment = alignment.max(size_and_alignment.alignment);
814        }
815    }
816
817    (
818        discriminant_size,
819        align_to(usize::from(discriminant_size), alignment) as u32,
820    )
821}
822
823/// Generates the internals of a core wasm module which imports a single
824/// component function `IMPORT_FUNCTION` and exports a single component
825/// function `EXPORT_FUNCTION`.
826///
827/// The component function takes `params` as arguments and optionally returns
828/// `result`. The `lift_abi` and `lower_abi` fields indicate the ABI in-use for
829/// this operation.
830fn make_import_and_export(
831    params: &[&Type],
832    result: Option<&Type>,
833    lift_abi: LiftAbi,
834    lower_abi: LowerAbi,
835) -> String {
836    let params_lowered = params
837        .iter()
838        .flat_map(|ty| ty.lowered())
839        .collect::<Box<[_]>>();
840    let result_lowered = result.map(|t| t.lowered()).unwrap_or(Vec::new());
841
842    let mut wat = String::new();
843
844    enum Location {
845        Flat,
846        Indirect(u32),
847    }
848
849    // Generate the core wasm type corresponding to the imported function being
850    // lowered with `lower_abi`.
851    wat.push_str(&format!("(type $import (func"));
852    let max_import_params = match lower_abi {
853        LowerAbi::Sync => MAX_FLAT_PARAMS,
854        LowerAbi::Async => MAX_FLAT_ASYNC_PARAMS,
855    };
856    let (import_params_loc, nparams) = push_params(&mut wat, &params_lowered, max_import_params);
857    let import_results_loc = match lower_abi {
858        LowerAbi::Sync => {
859            push_result_or_retptr(&mut wat, &result_lowered, nparams, MAX_FLAT_RESULTS)
860        }
861        LowerAbi::Async => {
862            let loc = if result.is_none() {
863                Location::Flat
864            } else {
865                wat.push_str(" (param i32)"); // result pointer
866                Location::Indirect(nparams)
867            };
868            wat.push_str(" (result i32)"); // status code
869            loc
870        }
871    };
872    wat.push_str("))\n");
873
874    // Generate the import function.
875    wat.push_str(&format!(
876        r#"(import "host" "{IMPORT_FUNCTION}" (func $host (type $import)))"#
877    ));
878
879    // Do the same as above for the exported function's type which is lifted
880    // with `lift_abi`.
881    //
882    // Note that `export_results_loc` being `None` means that `task.return` is
883    // used to communicate results.
884    wat.push_str(&format!("(type $export (func"));
885    let (export_params_loc, _nparams) = push_params(&mut wat, &params_lowered, MAX_FLAT_PARAMS);
886    let export_results_loc = match lift_abi {
887        LiftAbi::Sync => Some(push_group(&mut wat, "result", &result_lowered, MAX_FLAT_RESULTS).0),
888        LiftAbi::AsyncCallback => {
889            wat.push_str(" (result i32)"); // status code
890            None
891        }
892        LiftAbi::AsyncStackful => None,
893    };
894    wat.push_str("))\n");
895
896    // If the export is async, generate `task.return` as an import as well
897    // which is necesary to communicate the results.
898    if export_results_loc.is_none() {
899        wat.push_str(&format!("(type $task.return (func"));
900        push_params(&mut wat, &result_lowered, MAX_FLAT_PARAMS);
901        wat.push_str("))\n");
902        wat.push_str(&format!(
903            r#"(import "" "task.return" (func $task.return (type $task.return)))"#
904        ));
905    }
906
907    wat.push_str(&format!(
908        r#"
909(func (export "{EXPORT_FUNCTION}") (type $export)
910    (local $retptr i32)
911    (local $argptr i32)
912        "#
913    ));
914    let mut store_helpers = IndexSet::new();
915    let mut load_helpers = IndexSet::new();
916
917    match (export_params_loc, import_params_loc) {
918        // flat => flat is just moving locals around
919        (Location::Flat, Location::Flat) => {
920            for (index, _) in params_lowered.iter().enumerate() {
921                uwrite!(wat, "local.get {index}\n");
922            }
923        }
924
925        // indirect => indirect is just moving locals around
926        (Location::Indirect(i), Location::Indirect(j)) => {
927            assert_eq!(j, 0);
928            uwrite!(wat, "local.get {i}\n");
929        }
930
931        // flat => indirect means that all parameters are stored in memory as
932        // if it was a record of all the parameters.
933        (Location::Flat, Location::Indirect(_)) => {
934            let SizeAndAlignment { size, alignment } =
935                record_size_and_alignment(params.iter().cloned());
936            wat.push_str(&format!(
937                r#"
938                    (local.set $argptr
939                        (call $realloc
940                            (i32.const 0)
941                            (i32.const 0)
942                            (i32.const {alignment})
943                            (i32.const {size})))
944                    local.get $argptr
945                "#
946            ));
947            let mut locals = (0..params_lowered.len() as u32).map(FlatSource::Local);
948            for (offset, ty) in record_field_offsets(params.iter().cloned()) {
949                ty.store_flat(&mut wat, "$argptr", offset, &mut locals, &mut store_helpers);
950            }
951            assert!(locals.next().is_none());
952        }
953
954        (Location::Indirect(_), Location::Flat) => unreachable!(),
955    }
956
957    // Pass a return-pointer if necessary.
958    match import_results_loc {
959        Location::Flat => {}
960        Location::Indirect(_) => {
961            let SizeAndAlignment { size, alignment } = result.unwrap().size_and_alignment();
962
963            wat.push_str(&format!(
964                r#"
965                    (local.set $retptr
966                        (call $realloc
967                            (i32.const 0)
968                            (i32.const 0)
969                            (i32.const {alignment})
970                            (i32.const {size})))
971                    local.get $retptr
972                "#
973            ));
974        }
975    }
976
977    wat.push_str("call $host\n");
978
979    // Assert the lowered call is ready if an async code was returned.
980    //
981    // TODO: handle when the import isn't ready yet
982    if let LowerAbi::Async = lower_abi {
983        wat.push_str("i32.const 2\n");
984        wat.push_str("i32.ne\n");
985        wat.push_str("if unreachable end\n");
986    }
987
988    // TODO: conditionally inject a yield here
989
990    match (import_results_loc, export_results_loc) {
991        // flat => flat results involves nothing, the results are already on
992        // the stack.
993        (Location::Flat, Some(Location::Flat)) => {}
994
995        // indirect => indirect results requires returning the `$retptr` the
996        // host call filled in.
997        (Location::Indirect(_), Some(Location::Indirect(_))) => {
998            wat.push_str("local.get $retptr\n");
999        }
1000
1001        // indirect => flat requires loading the result from the return pointer
1002        (Location::Indirect(_), Some(Location::Flat)) => {
1003            result
1004                .unwrap()
1005                .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1006        }
1007
1008        // flat => task.return is easy, the results are already there so just
1009        // call the function.
1010        (Location::Flat, None) => {
1011            wat.push_str("call $task.return\n");
1012        }
1013
1014        // indirect => task.return needs to forward `$retptr` if the results
1015        // are indirect, or otherwise it must be loaded from memory to a flat
1016        // representation.
1017        (Location::Indirect(_), None) => {
1018            if result_lowered.len() <= MAX_FLAT_PARAMS {
1019                result
1020                    .unwrap()
1021                    .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1022            } else {
1023                wat.push_str("local.get $retptr\n");
1024            }
1025            wat.push_str("call $task.return\n");
1026        }
1027
1028        (Location::Flat, Some(Location::Indirect(_))) => unreachable!(),
1029    }
1030
1031    if let LiftAbi::AsyncCallback = lift_abi {
1032        wat.push_str("i32.const 0\n"); // completed status code
1033    }
1034
1035    wat.push_str(")\n");
1036
1037    // Generate a `callback` function for the callback ABI.
1038    //
1039    // TODO: fill this in
1040    if let LiftAbi::AsyncCallback = lift_abi {
1041        wat.push_str(
1042            r#"
1043(func (export "callback") (param i32 i32 i32) (result i32) unreachable)
1044            "#,
1045        );
1046    }
1047
1048    // Fill out all store/load helpers that were needed during generation
1049    // above. This is a fix-point-loop since each helper may end up requiring
1050    // more helpers.
1051    let mut i = 0;
1052    while i < store_helpers.len() {
1053        let ty = store_helpers[i].0;
1054        ty.store_flat_helper(&mut wat, i, &mut store_helpers);
1055        i += 1;
1056    }
1057    i = 0;
1058    while i < load_helpers.len() {
1059        let ty = load_helpers[i].0;
1060        ty.load_flat_helper(&mut wat, i, &mut load_helpers);
1061        i += 1;
1062    }
1063
1064    return wat;
1065
1066    fn push_params(wat: &mut String, params: &[CoreType], max_flat: usize) -> (Location, u32) {
1067        push_group(wat, "param", params, max_flat)
1068    }
1069
1070    fn push_group(
1071        wat: &mut String,
1072        name: &str,
1073        params: &[CoreType],
1074        max_flat: usize,
1075    ) -> (Location, u32) {
1076        let mut nparams = 0;
1077        let loc = if params.is_empty() {
1078            // nothing to emit...
1079            Location::Flat
1080        } else if params.len() <= max_flat {
1081            wat.push_str(&format!(" ({name}"));
1082            for ty in params {
1083                wat.push_str(&format!(" {ty}"));
1084                nparams += 1;
1085            }
1086            wat.push_str(")");
1087            Location::Flat
1088        } else {
1089            wat.push_str(&format!(" ({name} i32)"));
1090            nparams += 1;
1091            Location::Indirect(0)
1092        };
1093        (loc, nparams)
1094    }
1095
1096    fn push_result_or_retptr(
1097        wat: &mut String,
1098        results: &[CoreType],
1099        nparams: u32,
1100        max_flat: usize,
1101    ) -> Location {
1102        if results.is_empty() {
1103            // nothing to emit...
1104            Location::Flat
1105        } else if results.len() <= max_flat {
1106            wat.push_str(" (result");
1107            for ty in results {
1108                wat.push_str(&format!(" {ty}"));
1109            }
1110            wat.push_str(")");
1111            Location::Flat
1112        } else {
1113            wat.push_str(" (param i32)");
1114            Location::Indirect(nparams)
1115        }
1116    }
1117}
1118
1119struct Helper<'a>(&'a Type);
1120
1121impl Hash for Helper<'_> {
1122    fn hash<H: Hasher>(&self, h: &mut H) {
1123        std::ptr::hash(self.0, h);
1124    }
1125}
1126
1127impl PartialEq for Helper<'_> {
1128    fn eq(&self, other: &Self) -> bool {
1129        std::ptr::eq(self.0, other.0)
1130    }
1131}
1132
1133impl Eq for Helper<'_> {}
1134
1135fn make_rust_name(name_counter: &mut u32) -> Ident {
1136    let name = format_ident!("Foo{name_counter}");
1137    *name_counter += 1;
1138    name
1139}
1140
/// Generate a [`TokenStream`] containing the rust type name for a type.
///
/// The `name_counter` parameter is used to generate names for each recursively visited type.  The `declarations`
/// parameter is used to accumulate declarations for each recursively visited type.
///
/// Primitive, list, tuple, option and result types map directly onto Rust
/// types (`u32`, `Vec<T>`, `(A, B,)`, `Option<T>`, `Result<T, E>`, ...).
/// Records, variants, enums and flags need a named Rust type. For those, a
/// fresh `FooN` identifier is taken from `name_counter` and the type's
/// definition is appended to `declarations`.
///
/// Nested types of a record or variant are visited *before* the enclosing type
/// takes its name. The numbering of names and the order of `declarations`
/// therefore follow this traversal order.
///
/// NOTE(review): the emitted code refers to names not defined here (the
/// `ComponentType`/`Lift`/`Lower`/`Arbitrary` derives, `Float32`/`Float64`,
/// `wasmtime::component::flags!`). Presumably it is compiled in a context where
/// those are in scope; confirm against the caller.
pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream {
    match ty {
        Type::Bool => quote!(bool),
        Type::S8 => quote!(i8),
        Type::U8 => quote!(u8),
        Type::S16 => quote!(i16),
        Type::U16 => quote!(u16),
        Type::S32 => quote!(i32),
        Type::U32 => quote!(u32),
        Type::S64 => quote!(i64),
        Type::U64 => quote!(u64),
        Type::Float32 => quote!(Float32),
        Type::Float64 => quote!(Float64),
        Type::Char => quote!(char),
        Type::String => quote!(Box<str>),
        Type::List(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Vec<#ty>)
        }
        Type::Record(types) => {
            // Fields are named `f0`, `f1`, ... and their types are generated
            // (possibly adding declarations) before this record takes a name.
            let fields = types
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("f{index}");
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#name: #ty,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(record)]
                struct #name {
                    #fields
                }
            });

            quote!(#name)
        }
        Type::Tuple(types) => {
            // Every element is followed by a comma, so a one-element tuple is
            // emitted as `(T,)` rather than a parenthesized type.
            let fields = types
                .0
                .iter()
                .map(|ty| {
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#ty,)
                })
                .collect::<TokenStream>();

            quote!((#fields))
        }
        Type::Variant(types) => {
            // Cases are named `C0`, `C1`, ...; a case with no payload becomes a
            // unit variant.
            let cases = types
                .0
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("C{index}");
                    let ty = match ty {
                        Some(ty) => {
                            let ty = rust_type(ty, name_counter, declarations);
                            quote!((#ty))
                        }
                        None => quote!(),
                    };
                    quote!(#name #ty,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(variant)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Enum(count) => {
            let cases = (0..*count)
                .map(|index| {
                    let name = format_ident!("E{index}");
                    quote!(#name,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            // The Rust `#[repr]` mirrors the width chosen by
            // `DiscriminantSize::from_count` for this many cases.
            let repr = match DiscriminantSize::from_count(*count as usize).unwrap() {
                DiscriminantSize::Size1 => quote!(u8),
                DiscriminantSize::Size2 => quote!(u16),
                DiscriminantSize::Size4 => quote!(u32),
            };

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)]
                #[component(enum)]
                #[repr(#repr)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Option(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Option<#ty>)
        }
        Type::Result { ok, err } => {
            // A missing `ok` or `err` payload is represented by `()`.
            let ok = match ok {
                Some(ok) => rust_type(ok, name_counter, declarations),
                None => quote!(()),
            };
            let err = match err {
                Some(err) => rust_type(err, name_counter, declarations),
                None => quote!(()),
            };
            quote!(Result<#ok, #err>)
        }
        Type::Flags(count) => {
            let type_name = make_rust_name(name_counter);

            // `flags` holds the `const F{i};` entries for the `flags!` macro;
            // `names` lists the same constants as values for the `Arbitrary`
            // impl below.
            let mut flags = TokenStream::new();
            let mut names = TokenStream::new();

            for index in 0..*count {
                let name = format_ident!("F{index}");
                flags.extend(quote!(const #name;));
                names.extend(quote!(#type_name::#name,))
            }

            // The generated `Arbitrary` impl starts from the default value and
            // sets each flag based on one arbitrary `bool`.
            declarations.extend(quote! {
                wasmtime::component::flags! {
                    #type_name {
                        #flags
                    }
                }

                impl<'a> arbitrary::Arbitrary<'a> for #type_name {
                    fn arbitrary(input: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
                        let mut flags = #type_name::default();
                        for flag in [#names] {
                            if input.arbitrary()? {
                                flags |= flag;
                            }
                        }
                        Ok(flags)
                    }
                }
            });

            quote!(#type_name)
        }
    }
}
1305
1306#[derive(Default)]
1307struct TypesBuilder<'a> {
1308    next: u32,
1309    worklist: Vec<(u32, &'a Type)>,
1310}
1311
1312impl<'a> TypesBuilder<'a> {
1313    fn write_ref(&mut self, ty: &'a Type, dst: &mut String) {
1314        match ty {
1315            // Primitive types can be referenced directly
1316            Type::Bool => dst.push_str("bool"),
1317            Type::S8 => dst.push_str("s8"),
1318            Type::U8 => dst.push_str("u8"),
1319            Type::S16 => dst.push_str("s16"),
1320            Type::U16 => dst.push_str("u16"),
1321            Type::S32 => dst.push_str("s32"),
1322            Type::U32 => dst.push_str("u32"),
1323            Type::S64 => dst.push_str("s64"),
1324            Type::U64 => dst.push_str("u64"),
1325            Type::Float32 => dst.push_str("float32"),
1326            Type::Float64 => dst.push_str("float64"),
1327            Type::Char => dst.push_str("char"),
1328            Type::String => dst.push_str("string"),
1329
1330            // Otherwise emit a reference to the type and remember to generate
1331            // the corresponding type alias later.
1332            Type::List(_)
1333            | Type::Record(_)
1334            | Type::Tuple(_)
1335            | Type::Variant(_)
1336            | Type::Enum(_)
1337            | Type::Option(_)
1338            | Type::Result { .. }
1339            | Type::Flags(_) => {
1340                let idx = self.next;
1341                self.next += 1;
1342                uwrite!(dst, "$t{idx}");
1343                self.worklist.push((idx, ty));
1344            }
1345        }
1346    }
1347
1348    fn write_decl(&mut self, idx: u32, ty: &'a Type) -> String {
1349        let mut decl = format!("(type $t{idx}' ");
1350        match ty {
1351            Type::Bool
1352            | Type::S8
1353            | Type::U8
1354            | Type::S16
1355            | Type::U16
1356            | Type::S32
1357            | Type::U32
1358            | Type::S64
1359            | Type::U64
1360            | Type::Float32
1361            | Type::Float64
1362            | Type::Char
1363            | Type::String => unreachable!(),
1364
1365            Type::List(ty) => {
1366                decl.push_str("(list ");
1367                self.write_ref(ty, &mut decl);
1368                decl.push_str(")");
1369            }
1370            Type::Record(types) => {
1371                decl.push_str("(record");
1372                for (index, ty) in types.iter().enumerate() {
1373                    uwrite!(decl, r#" (field "f{index}" "#);
1374                    self.write_ref(ty, &mut decl);
1375                    decl.push_str(")");
1376                }
1377                decl.push_str(")");
1378            }
1379            Type::Tuple(types) => {
1380                decl.push_str("(tuple");
1381                for ty in types.iter() {
1382                    decl.push_str(" ");
1383                    self.write_ref(ty, &mut decl);
1384                }
1385                decl.push_str(")");
1386            }
1387            Type::Variant(types) => {
1388                decl.push_str("(variant");
1389                for (index, ty) in types.iter().enumerate() {
1390                    uwrite!(decl, r#" (case "C{index}""#);
1391                    if let Some(ty) = ty {
1392                        decl.push_str(" ");
1393                        self.write_ref(ty, &mut decl);
1394                    }
1395                    decl.push_str(")");
1396                }
1397                decl.push_str(")");
1398            }
1399            Type::Enum(count) => {
1400                decl.push_str("(enum");
1401                for index in 0..*count {
1402                    uwrite!(decl, r#" "E{index}""#);
1403                }
1404                decl.push_str(")");
1405            }
1406            Type::Option(ty) => {
1407                decl.push_str("(option ");
1408                self.write_ref(ty, &mut decl);
1409                decl.push_str(")");
1410            }
1411            Type::Result { ok, err } => {
1412                decl.push_str("(result");
1413                if let Some(ok) = ok {
1414                    decl.push_str(" ");
1415                    self.write_ref(ok, &mut decl);
1416                }
1417                if let Some(err) = err {
1418                    decl.push_str(" (error ");
1419                    self.write_ref(err, &mut decl);
1420                    decl.push_str(")");
1421                }
1422                decl.push_str(")");
1423            }
1424            Type::Flags(count) => {
1425                decl.push_str("(flags");
1426                for index in 0..*count {
1427                    uwrite!(decl, r#" "F{index}""#);
1428                }
1429                decl.push_str(")");
1430            }
1431        }
1432        decl.push_str(")\n");
1433        uwriteln!(decl, "(import \"t{idx}\" (type $t{idx} (eq $t{idx}')))");
1434        decl
1435    }
1436}
1437
/// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s
///
/// [`Declarations::make_component`] stitches these fragments into a complete
/// component text.
#[derive(Debug)]
pub struct Declarations {
    /// Type declarations (if any) referenced by `params` and/or `results`
    pub types: Cow<'static, str>,
    /// Types to thread through when instantiating sub-components.
    pub type_instantiation_args: Cow<'static, str>,
    /// Parameter declarations used for the imported and exported functions
    pub params: Cow<'static, str>,
    /// Result declaration used for the imported and exported functions
    pub results: Cow<'static, str>,
    /// Implementation of the "caller" component, which invokes the `callee`
    /// composed component.
    pub caller_module: Cow<'static, str>,
    /// Implementation of the "callee" component, which invokes the host.
    pub callee_module: Cow<'static, str>,
    /// Options used for the caller/callee ABIs, string encodings, etc.
    pub options: TestCaseOptions,
}
1457
impl Declarations {
    /// Generate a complete WAT file based on the specified fragments.
    ///
    /// The resulting component:
    /// * defines a `$libc` core module with a memory and the
    ///   `REALLOC_AND_FREE` functions;
    /// * declares the shared `types` and imports the host function
    ///   `IMPORT_FUNCTION`;
    /// * defines two nested components, `$callee` (built from `callee_module`)
    ///   and `$caller` (built from `caller_module`);
    /// * instantiates `$callee` with the host import and `$caller` with the
    ///   callee's `EXPORT_FUNCTION`;
    /// * exports the caller's `EXPORT_FUNCTION`.
    ///
    /// So a call to the final export goes caller -> callee -> host.
    pub fn make_component(&self) -> Box<str> {
        let Self {
            types,
            type_instantiation_args,
            params,
            results,
            caller_module,
            callee_module,
            options,
        } = self;
        // Builds one nested component wrapping the core `module` text. The
        // `*_async` flags control the `async` keyword on the component-level
        // function types. `lift_abi`/`lower_abi` select the canonical options
        // used for the export's `canon lift` and the import's `canon lower`.
        let mk_component = |name: &str,
                            module: &str,
                            import_async: bool,
                            export_async: bool,
                            encoding: StringEncoding,
                            lift_abi: LiftAbi,
                            lower_abi: LowerAbi| {
            let import_async = if import_async { "async" } else { "" };
            let export_async = if export_async { "async" } else { "" };
            let lower_async_option = match lower_abi {
                LowerAbi::Sync => "",
                LowerAbi::Async => "async",
            };
            // The callback flavor points `canon lift` at the `callback` export
            // of the core instance `$i` defined below.
            let lift_async_option = match lift_abi {
                LiftAbi::Sync => "",
                LiftAbi::AsyncStackful => "async",
                LiftAbi::AsyncCallback => "async (callback (func $i \"callback\"))",
            };

            let mut intrinsic_defs = String::new();
            let mut intrinsic_imports = String::new();

            // Async lifts return results through the `task.return` intrinsic,
            // so define it here and provide it to the core module under
            // module "" / name "task.return".
            match lift_abi {
                LiftAbi::Sync => {}
                LiftAbi::AsyncCallback | LiftAbi::AsyncStackful => {
                    intrinsic_defs.push_str(&format!(
                        r#"
(core func $task.return (canon task.return {results}
    (memory $libc "memory") string-encoding={encoding}))
                        "#,
                    ));
                    intrinsic_imports.push_str(
                        r#"
(with "" (instance (export "task.return" (func $task.return))))
                        "#,
                    );
                }
            }

            // Each nested component instantiates its own `$libc`, so caller
            // and callee work with separate memories.
            format!(
                r#"
(component ${name}
    {types}
    (type $import_sig (func {import_async} {params} {results}))
    (type $export_sig (func {export_async} {params} {results}))
    (import "{IMPORT_FUNCTION}" (func $f (type $import_sig)))

    (core instance $libc (instantiate $libc))

    (core func $f_lower (canon lower
        (func $f)
        (memory $libc "memory")
        (realloc (func $libc "realloc"))
        string-encoding={encoding}
        {lower_async_option}
    ))

    {intrinsic_defs}

    (core module $m
        (memory (import "libc" "memory") 1)
        (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32))

        {module}
    )

    (core instance $i (instantiate $m
        (with "libc" (instance $libc))
        (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower))))
        {intrinsic_imports}
    ))

    (func (export "{EXPORT_FUNCTION}") (type $export_sig)
        (canon lift
            (core func $i "{EXPORT_FUNCTION}")
            (memory $libc "memory")
            (realloc (func $libc "realloc"))
            string-encoding={encoding}
            {lift_async_option}
        )
    )
)
            "#
            )
        };

        // The callee imports the host function, so its import side follows
        // `host_async`.
        let c1 = mk_component(
            "callee",
            &callee_module,
            options.host_async,
            options.guest_callee_async,
            options.callee_encoding,
            options.callee_lift_abi,
            options.callee_lower_abi,
        );
        // The caller imports the callee's export, so its import side follows
        // `guest_callee_async`.
        let c2 = mk_component(
            "caller",
            &caller_module,
            options.guest_callee_async,
            options.guest_caller_async,
            options.caller_encoding,
            options.caller_lift_abi,
            options.caller_lower_abi,
        );
        let host_async = if options.host_async { "async" } else { "" };

        format!(
            r#"
            (component
                (core module $libc
                    (memory (export "memory") 1)
                    {REALLOC_AND_FREE}
                )


                {types}

                (type $host_sig (func {host_async} {params} {results}))
                (import "{IMPORT_FUNCTION}" (func $f (type $host_sig)))

                {c1}
                {c2}
                (instance $c1 (instantiate $callee
                    {type_instantiation_args}
                    (with "{IMPORT_FUNCTION}" (func $f))
                ))
                (instance $c2 (instantiate $caller
                    {type_instantiation_args}
                    (with "{IMPORT_FUNCTION}" (func $c1 "{EXPORT_FUNCTION}"))
                ))
                (export "{EXPORT_FUNCTION}" (func $c2 "{EXPORT_FUNCTION}"))
            )"#,
        )
        .into()
    }
}
1606
/// Represents a test case for calling a component function
///
/// The same signature (`params` -> `result`) is used for the host import, for
/// the callee and caller components, and for the final export.
#[derive(Debug)]
pub struct TestCase<'a> {
    /// The types of parameters to pass to the function
    pub params: Vec<&'a Type>,
    /// The result type of the function, if it returns anything
    pub result: Option<&'a Type>,
    /// ABI options to use for this test case.
    pub options: TestCaseOptions,
}
1617
/// Collection of options which configure the ABIs, string encodings and
/// function-type `async`-ness used by the caller, callee and host.
///
/// NOTE(review): this derives `Arbitrary`, which presumably decodes fields in
/// declaration order. Reordering fields would then change which options a
/// given fuzz input produces. Keep the order stable.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub struct TestCaseOptions {
    /// Whether or not the guest caller component (the entrypoint) is using an
    /// `async` function type.
    pub guest_caller_async: bool,
    /// Whether or not the guest callee component (what the entrypoint calls)
    /// is using an `async` function type.
    pub guest_callee_async: bool,
    /// Whether or not the host is using an async function type (what the
    /// guest callee calls).
    pub host_async: bool,
    /// The string encoding of the caller component.
    pub caller_encoding: StringEncoding,
    /// The string encoding of the callee component.
    pub callee_encoding: StringEncoding,
    /// The ABI that the caller component is using to lift its export (the main
    /// entrypoint).
    pub caller_lift_abi: LiftAbi,
    /// The ABI that the callee component is using to lift its export (called
    /// by the caller).
    pub callee_lift_abi: LiftAbi,
    /// The ABI that the caller component is using to lower its import (the
    /// callee's export).
    pub caller_lower_abi: LowerAbi,
    /// The ABI that the callee component is using to lower its import (the
    /// host function).
    pub callee_lower_abi: LowerAbi,
}
1648
/// How a component lifts its exported function (`canon lift`).
///
/// NOTE(review): variant order presumably affects the derived `Arbitrary`
/// impl; keep it stable.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub enum LiftAbi {
    /// Synchronous lift: the core function returns the result directly, or as
    /// a pointer when it does not fit in the flat results.
    Sync,
    /// `async` lift without a callback: the core function returns nothing and
    /// delivers its result by calling `task.return`.
    AsyncStackful,
    /// `async` lift with a `callback` function: the core function returns an
    /// `i32` status code and delivers its result by calling `task.return`.
    AsyncCallback,
}
1655
/// How a component lowers its imported function (`canon lower`).
///
/// NOTE(review): variant order presumably affects the derived `Arbitrary`
/// impl; keep it stable.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub enum LowerAbi {
    /// Synchronous lower: no `async` canonical option.
    Sync,
    /// `async` lower: the core import takes a result pointer (when there is a
    /// result) and returns an `i32` status code.
    Async,
}
1661
1662impl<'a> TestCase<'a> {
1663    pub fn generate(types: &'a [Type], u: &mut Unstructured<'_>) -> arbitrary::Result<Self> {
1664        let max_params = if types.len() > 0 { 5 } else { 0 };
1665        let params = (0..u.int_in_range(0..=max_params)?)
1666            .map(|_| u.choose(&types))
1667            .collect::<arbitrary::Result<Vec<_>>>()?;
1668        let result = if types.len() > 0 && u.arbitrary()? {
1669            Some(u.choose(&types)?)
1670        } else {
1671            None
1672        };
1673
1674        let mut options = u.arbitrary::<TestCaseOptions>()?;
1675
1676        // Sync tasks cannot call async functions via a sync lower, nor can they
1677        // block in other ways (e.g. by calling `waitable-set.wait`, returning
1678        // `CALLBACK_CODE_WAIT`, etc.) prior to returning.  Therefore,
1679        // async-ness cascades to the callers:
1680        if options.host_async {
1681            options.guest_callee_async = true;
1682        }
1683        if options.guest_callee_async {
1684            options.guest_caller_async = true;
1685        }
1686
1687        Ok(Self {
1688            params,
1689            result,
1690            options,
1691        })
1692    }
1693
1694    /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case.
1695    pub fn declarations(&self) -> Declarations {
1696        let mut builder = TypesBuilder::default();
1697
1698        let mut params = String::new();
1699        for (i, ty) in self.params.iter().enumerate() {
1700            params.push_str(&format!(" (param \"p{i}\" "));
1701            builder.write_ref(ty, &mut params);
1702            params.push_str(")");
1703        }
1704
1705        let mut results = String::new();
1706        if let Some(ty) = self.result {
1707            results.push_str(&format!(" (result "));
1708            builder.write_ref(ty, &mut results);
1709            results.push_str(")");
1710        }
1711
1712        let caller_module = make_import_and_export(
1713            &self.params,
1714            self.result,
1715            self.options.caller_lift_abi,
1716            self.options.caller_lower_abi,
1717        );
1718        let callee_module = make_import_and_export(
1719            &self.params,
1720            self.result,
1721            self.options.callee_lift_abi,
1722            self.options.callee_lower_abi,
1723        );
1724
1725        let mut type_decls = Vec::new();
1726        let mut type_instantiation_args = String::new();
1727        while let Some((idx, ty)) = builder.worklist.pop() {
1728            type_decls.push(builder.write_decl(idx, ty));
1729            uwriteln!(type_instantiation_args, "(with \"t{idx}\" (type $t{idx}))");
1730        }
1731
1732        // Note that types are printed here in reverse order since they were
1733        // pushed onto `type_decls` as they were referenced meaning the last one
1734        // is the "base" one.
1735        let mut types = String::new();
1736        for decl in type_decls.into_iter().rev() {
1737            types.push_str(&decl);
1738            types.push_str("\n");
1739        }
1740
1741        Declarations {
1742            types: types.into(),
1743            type_instantiation_args: type_instantiation_args.into(),
1744            params: params.into(),
1745            results: results.into(),
1746            caller_module: caller_module.into(),
1747            callee_module: callee_module.into(),
1748            options: self.options,
1749        }
1750    }
1751}
1752
/// String encoding a generated component declares for its canonical options.
/// One is chosen per component via the `caller_encoding` / `callee_encoding`
/// fields of `TestCaseOptions`.
///
/// The `Display` impl renders these as `utf8`, `utf16` and `latin1+utf16`.
/// NOTE(review): presumably those strings are spliced into generated WAT as
/// a `string-encoding=...` option; confirm at the use site.
///
/// `Arbitrary` is derived so the fuzzer selects a variant from its input. Do
/// not reorder variants: that would change which input maps to which variant.
#[derive(Copy, Clone, Debug, Arbitrary)]
pub enum StringEncoding {
    /// Rendered as `utf8`.
    Utf8,
    /// Rendered as `utf16`.
    Utf16,
    /// Rendered as `latin1+utf16`.
    Latin1OrUtf16,
}
1759
1760impl fmt::Display for StringEncoding {
1761    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1762        match self {
1763            StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f),
1764            StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f),
1765            StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f),
1766        }
1767    }
1768}
1769
1770impl ToTokens for TestCaseOptions {
1771    fn to_tokens(&self, tokens: &mut TokenStream) {
1772        let TestCaseOptions {
1773            guest_caller_async,
1774            guest_callee_async,
1775            host_async,
1776            caller_encoding,
1777            callee_encoding,
1778            caller_lift_abi,
1779            callee_lift_abi,
1780            caller_lower_abi,
1781            callee_lower_abi,
1782        } = self;
1783        tokens.extend(quote!(wasmtime_test_util::component_fuzz::TestCaseOptions {
1784            guest_caller_async: #guest_caller_async,
1785            guest_callee_async: #guest_callee_async,
1786            host_async: #host_async,
1787            caller_encoding: #caller_encoding,
1788            callee_encoding: #callee_encoding,
1789            caller_lift_abi: #caller_lift_abi,
1790            callee_lift_abi: #callee_lift_abi,
1791            caller_lower_abi: #caller_lower_abi,
1792            callee_lower_abi: #callee_lower_abi,
1793        }));
1794    }
1795}
1796
1797impl ToTokens for LowerAbi {
1798    fn to_tokens(&self, tokens: &mut TokenStream) {
1799        let me = match self {
1800            LowerAbi::Sync => quote!(Sync),
1801            LowerAbi::Async => quote!(Async),
1802        };
1803        tokens.extend(quote!(wasmtime_test_util::component_fuzz::LowerAbi::#me));
1804    }
1805}
1806
1807impl ToTokens for LiftAbi {
1808    fn to_tokens(&self, tokens: &mut TokenStream) {
1809        let me = match self {
1810            LiftAbi::Sync => quote!(Sync),
1811            LiftAbi::AsyncCallback => quote!(AsyncCallback),
1812            LiftAbi::AsyncStackful => quote!(AsyncStackful),
1813        };
1814        tokens.extend(quote!(wasmtime_test_util::component_fuzz::LiftAbi::#me));
1815    }
1816}
1817
1818impl ToTokens for StringEncoding {
1819    fn to_tokens(&self, tokens: &mut TokenStream) {
1820        let me = match self {
1821            StringEncoding::Utf8 => quote!(Utf8),
1822            StringEncoding::Utf16 => quote!(Utf16),
1823            StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16),
1824        };
1825        tokens.extend(quote!(wasmtime_test_util::component_fuzz::StringEncoding::#me));
1826    }
1827}
1828
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `wasm` as text for a failure message, annotated with byte
    /// offsets and operand stacks. Falls back to the original `component`
    /// source if `wasmprinter` fails to print the binary.
    fn diagnostic_text(wasm: &[u8], component: &str) -> String {
        let mut printed = String::new();
        let printed_ok = wasmprinter::Config::new()
            .print_offsets(true)
            .print_operand_stack(true)
            .print(wasm, &mut wasmprinter::PrintFmtWrite(&mut printed))
            .is_ok();
        if printed_ok {
            printed
        } else {
            component.to_string()
        }
    }

    /// Property test: every generated test case must yield a component that
    /// parses as WAT and validates with all wasm features enabled.
    #[test]
    fn arbtest() {
        arbtest::arbtest(|input| {
            let mut fuel = 100;
            let mut types = Vec::with_capacity(5);
            for _ in 0..5 {
                types.push(Type::generate(input, 3, &mut fuel)?);
            }

            let case = TestCase::generate(&types, input)?;
            let component = case.declarations().make_component();

            let wasm = match wat::parse_str(&component) {
                Ok(bytes) => bytes,
                Err(e) => {
                    panic!("failed to parse generated component as wat: {e}\n\n{component}")
                }
            };

            let validation =
                wasmparser::Validator::new_with_features(wasmparser::WasmFeatures::all())
                    .validate_all(&wasm);
            if let Err(e) = validation {
                let to_print = diagnostic_text(&wasm, &component);
                panic!("generated component is not valid wasm: {e}\n\n{to_print}");
            }

            Ok(())
        })
        .budget_ms(1_000);
        // To replay a failing run deterministically, chain the seed reported by
        // arbtest onto the builder above, e.g. `.seed(0x3c9050d4000000e9)`.
    }
}
1869}