Skip to main content

wasmtime_test_util/
component_fuzz.rs

1//! This module generates test cases for the Wasmtime component model function APIs,
2//! e.g. `wasmtime::component::func::Func` and `TypedFunc`.
3//!
4//! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a
5//! result, and a component which exports a function and imports a function.  The exported function forwards its
6//! parameters to the imported one and forwards the result back to the caller.  This serves to exercise Wasmtime's
7//! lifting and lowering code and verify the values remain intact during both processes.
8
9use arbitrary::{Arbitrary, Unstructured};
10use indexmap::IndexSet;
11use proc_macro2::{Ident, TokenStream};
12use quote::{ToTokens, format_ident, quote};
13use std::borrow::Cow;
14use std::fmt::{self, Debug, Write};
15use std::hash::{Hash, Hasher};
16use std::iter;
17use std::ops::Deref;
18use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE};
19
// Flattening limits for the canonical ABI.
// NOTE(review): these appear to mirror `MAX_FLAT_PARAMS`,
// `MAX_FLAT_ASYNC_PARAMS` and `MAX_FLAT_RESULTS` from the Component Model's
// Canonical ABI (presumably used to decide when a signature is passed flat
// versus indirectly through memory) — confirm against the spec before
// changing them.
/// Maximum number of flat core-wasm parameters for a synchronous lowering.
const MAX_FLAT_PARAMS: usize = 16;
/// Maximum number of flat core-wasm parameters for an async lowering.
const MAX_FLAT_ASYNC_PARAMS: usize = 4;
/// Maximum number of flat core-wasm results.
const MAX_FLAT_RESULTS: usize = 1;

/// The name of the imported host function which the generated component will call.
pub const IMPORT_FUNCTION: &str = "echo-import";

/// The name of the exported guest function which the host should call.
pub const EXPORT_FUNCTION: &str = "echo-export";

/// Maximum nesting depth for generated types.
///
/// Wasmtime allows a type depth of up to 100, so this stays just under that limit.
pub const MAX_TYPE_DEPTH: u32 = 99;
32
// `writeln!` that unwraps the resulting `fmt::Result`.
//
// Every call site visible in this module writes into a `String`, whose
// `fmt::Write` impl does not fail, so the unwrap only discards the `Result`.
macro_rules! uwriteln {
    ($($tokens:tt)*) => {
        ::std::writeln!($($tokens)*).unwrap()
    };
}

// `write!` counterpart of `uwriteln!`: formats without a trailing newline and
// unwraps the `fmt::Result`.
macro_rules! uwrite {
    ($($tokens:tt)*) => {
        ::std::write!($($tokens)*).unwrap()
    };
}
44
/// A core WebAssembly value type, as used in the flattened (lowered)
/// representation of component-model values.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CoreType {
    I32,
    I64,
    F32,
    F64,
}

impl CoreType {
    /// Computes the `join` of two flat types, as specified in [the canonical
    /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening)
    /// for variant types.
    ///
    /// Identical types join to themselves, an `i32`/`f32` pair (in either
    /// order) joins to `i32`, and every other combination widens to `i64`.
    fn join(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        let is_i32_f32_pair = matches!(
            (self, other),
            (Self::I32, Self::F32) | (Self::F32, Self::I32)
        );
        if is_i32_f32_pair { Self::I32 } else { Self::I64 }
    }
}

impl fmt::Display for CoreType {
    /// Writes the type's WebAssembly text-format keyword (`i32`, `i64`, ...).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.write_str(keyword)
    }
}
76
77/// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than
78/// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl
79#[derive(Debug, Clone)]
80pub struct VecInRange<T, const L: u32, const H: u32>(Vec<T>);
81
82impl<T, const L: u32, const H: u32> VecInRange<T, L, H> {
83    fn new<'a>(
84        input: &mut Unstructured<'a>,
85        fuel: &mut u32,
86        generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>,
87    ) -> arbitrary::Result<Self> {
88        let mut ret = Vec::new();
89        input.arbitrary_loop(Some(L), Some(H), |input| {
90            if *fuel > 0 {
91                *fuel = *fuel - 1;
92                ret.push(generate(input, fuel)?);
93                Ok(std::ops::ControlFlow::Continue(()))
94            } else {
95                Ok(std::ops::ControlFlow::Break(()))
96            }
97        })?;
98        Ok(Self(ret))
99    }
100}
101
102impl<T, const L: u32, const H: u32> Deref for VecInRange<T, L, H> {
103    type Target = [T];
104
105    fn deref(&self) -> &[T] {
106        self.0.deref()
107    }
108}
109
/// Represents a component model interface type.
///
/// Instances are produced by `Type::generate`. The `u32` payloads of `Enum`
/// and `Flags` are case/flag counts, not values.
#[expect(missing_docs, reason = "self-describing")]
#[derive(Debug, Clone)]
pub enum Type {
    Bool,
    S8,
    U8,
    S16,
    U16,
    S32,
    U32,
    S64,
    U64,
    Float32,
    Float64,
    Char,
    String,
    // A list whose elements all have the boxed type.
    List(Box<Type>),
    // A map with the first type as keys and the second as values.
    // `Type::generate` only picks keys via `generate_hashable_key`, which
    // excludes floats and compound types.
    Map(Box<Type>, Box<Type>),

    // Give records the ability to generate a generous amount of fields but
    // don't let the fuzzer go too wild since `wasmparser`'s validator currently
    // has hard limits in the 1000-ish range on the number of fields a record
    // may contain.
    Record(VecInRange<Type, 1, 200>),

    // Tuples can only have up to 16 type parameters in wasmtime right now for
    // the static API, but the standard library only supports `Debug` up to 11
    // elements, so compromise at an even 10.
    Tuple(VecInRange<Type, 1, 10>),

    // Like records, allow a good number of variants, but variants require at
    // least one case.
    Variant(VecInRange<Option<Type>, 1, 200>),
    // An enum with this many cases; `Type::generate` produces at most 257.
    Enum(u32),

    // An optional value of the boxed type.
    Option(Box<Type>),
    // A result whose `ok`/`err` sides may each be absent (no payload).
    Result {
        ok: Option<Box<Type>>,
        err: Option<Box<Type>>,
    },

    // A flags type with this many flags. `Type::generate` produces 1..=32,
    // and `store_flat`/`load_flat` do not support more than 32.
    Flags(u32),
}
154
155impl Type {
156    pub fn generate(
157        u: &mut Unstructured<'_>,
158        depth: u32,
159        fuel: &mut u32,
160    ) -> arbitrary::Result<Type> {
161        *fuel = fuel.saturating_sub(1);
162        let max = if depth == 0 || *fuel == 0 { 12 } else { 21 };
163        Ok(match u.int_in_range(0..=max)? {
164            0 => Type::Bool,
165            1 => Type::S8,
166            2 => Type::U8,
167            3 => Type::S16,
168            4 => Type::U16,
169            5 => Type::S32,
170            6 => Type::U32,
171            7 => Type::S64,
172            8 => Type::U64,
173            9 => Type::Float32,
174            10 => Type::Float64,
175            11 => Type::Char,
176            12 => Type::String,
177            // ^-- if you add something here update the `depth == 0` case above
178            13 => Type::List(Box::new(Type::generate(u, depth - 1, fuel)?)),
179            14 => Type::Record(Type::generate_list(u, depth - 1, fuel)?),
180            15 => Type::Tuple(Type::generate_list(u, depth - 1, fuel)?),
181            16 => Type::Variant(VecInRange::new(u, fuel, |u, fuel| {
182                Type::generate_opt(u, depth - 1, fuel)
183            })?),
184            17 => {
185                let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
186                *fuel -= amt;
187                Type::Enum(amt)
188            }
189            18 => Type::Option(Box::new(Type::generate(u, depth - 1, fuel)?)),
190            19 => Type::Result {
191                ok: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
192                err: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
193            },
194            20 => {
195                let amt = u.int_in_range(1..=(*fuel).min(32))?;
196                *fuel -= amt;
197                Type::Flags(amt)
198            }
199            21 => Type::Map(
200                Box::new(Type::generate_hashable_key(u, fuel)?),
201                Box::new(Type::generate(u, depth - 1, fuel)?),
202            ),
203            // ^-- if you add something here update the `depth != 0` case above
204            _ => unreachable!(),
205        })
206    }
207
208    /// Generate a type that can be used as a HashMap key (implements Hash + Eq).
209    /// This excludes floats and complex types that might contain floats.
210    fn generate_hashable_key(u: &mut Unstructured<'_>, fuel: &mut u32) -> arbitrary::Result<Type> {
211        *fuel = fuel.saturating_sub(1);
212        // Only generate types that implement Hash and Eq:
213        // - No Float32/Float64 (NaN comparison issues)
214        // - No complex types (Record, Tuple, Variant, etc.) as they might contain floats
215        // - String is allowed as it implements Hash + Eq
216        Ok(match u.int_in_range(0..=11)? {
217            0 => Type::Bool,
218            1 => Type::S8,
219            2 => Type::U8,
220            3 => Type::S16,
221            4 => Type::U16,
222            5 => Type::S32,
223            6 => Type::U32,
224            7 => Type::S64,
225            8 => Type::U64,
226            9 => Type::Char,
227            10 => Type::String,
228            11 => {
229                let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
230                *fuel = fuel.saturating_sub(amt);
231                Type::Enum(amt)
232            }
233            _ => unreachable!(),
234        })
235    }
236
237    fn generate_opt(
238        u: &mut Unstructured<'_>,
239        depth: u32,
240        fuel: &mut u32,
241    ) -> arbitrary::Result<Option<Type>> {
242        Ok(if u.arbitrary()? {
243            Some(Type::generate(u, depth, fuel)?)
244        } else {
245            None
246        })
247    }
248
249    fn generate_list<const L: u32, const H: u32>(
250        u: &mut Unstructured<'_>,
251        depth: u32,
252        fuel: &mut u32,
253    ) -> arbitrary::Result<VecInRange<Type, L, H>> {
254        VecInRange::new(u, fuel, |u, fuel| Type::generate(u, depth, fuel))
255    }
256
257    /// Generates text format wasm into `s` to store a value of this type, in
258    /// its flat representation stored in the `locals` provided, to the local
259    /// named `ptr` at the `offset` provided.
260    ///
261    /// This will register helper functions necessary in `helpers`. The
262    /// `locals` iterator will be advanced for all locals consumed by this
263    /// store operation.
264    fn store_flat<'a>(
265        &'a self,
266        s: &mut String,
267        ptr: &str,
268        offset: u32,
269        locals: &mut dyn Iterator<Item = FlatSource>,
270        helpers: &mut IndexSet<Helper<'a>>,
271    ) {
272        enum Kind {
273            Primitive(&'static str),
274            PointerPair,
275            Helper,
276        }
277        let kind = match self {
278            Type::Bool | Type::S8 | Type::U8 => Kind::Primitive("i32.store8"),
279            Type::S16 | Type::U16 => Kind::Primitive("i32.store16"),
280            Type::S32 | Type::U32 | Type::Char => Kind::Primitive("i32.store"),
281            Type::S64 | Type::U64 => Kind::Primitive("i64.store"),
282            Type::Float32 => Kind::Primitive("f32.store"),
283            Type::Float64 => Kind::Primitive("f64.store"),
284            Type::String | Type::List(_) | Type::Map(_, _) => Kind::PointerPair,
285            Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.store8"),
286            Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.store16"),
287            Type::Enum(_) => Kind::Primitive("i32.store"),
288            Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.store8"),
289            Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.store16"),
290            Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.store"),
291            Type::Flags(_) => unreachable!(),
292            Type::Record(_)
293            | Type::Tuple(_)
294            | Type::Variant(_)
295            | Type::Option(_)
296            | Type::Result { .. } => Kind::Helper,
297        };
298
299        match kind {
300            Kind::Primitive(op) => uwriteln!(
301                s,
302                "({op} offset={offset} (local.get {ptr}) {})",
303                locals.next().unwrap()
304            ),
305            Kind::PointerPair => {
306                let abi_ptr = locals.next().unwrap();
307                let abi_len = locals.next().unwrap();
308                uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_ptr})",);
309                let offset = offset + 4;
310                uwriteln!(s, "(i32.store offset={offset} (local.get {ptr}) {abi_len})",);
311            }
312            Kind::Helper => {
313                let (index, _) = helpers.insert_full(Helper(self));
314                uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
315                for _ in 0..self.lowered().len() {
316                    let i = locals.next().unwrap();
317                    uwriteln!(s, "{i}");
318                }
319                uwriteln!(s, "call $store_helper_{index}");
320            }
321        }
322    }
323
324    /// Generates a text-format wasm function which takes a pointer and this
325    /// type's flat representation as arguments and then stores this value in
326    /// the first argument.
327    ///
328    /// This is used to store records/variants to cut down on the size of final
329    /// functions and make codegen here a bit easier.
330    fn store_flat_helper<'a>(
331        &'a self,
332        s: &mut String,
333        i: usize,
334        helpers: &mut IndexSet<Helper<'a>>,
335    ) {
336        uwrite!(s, "(func $store_helper_{i} (param i32)");
337        let lowered = self.lowered();
338        for ty in &lowered {
339            uwrite!(s, " (param {ty})");
340        }
341        s.push_str("\n");
342        let locals = (0..lowered.len() as u32).map(|i| i + 1).collect::<Vec<_>>();
343        let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
344            let mut locals = locals.iter().cloned().map(FlatSource::Local);
345            for (offset, ty) in record_field_offsets(types) {
346                ty.store_flat(s, "0", offset, &mut locals, helpers);
347            }
348            assert!(locals.next().is_none());
349        };
350        let variant = |s: &mut String,
351                       helpers: &mut IndexSet<Helper<'a>>,
352                       types: &[Option<&'a Type>]| {
353            let (size, offset) = variant_memory_info(types.iter().cloned());
354            // One extra block for out-of-bounds discriminants.
355            for _ in 0..types.len() + 1 {
356                s.push_str("block\n");
357            }
358
359            // Store the discriminant in memory, then branch on it to figure
360            // out which case we're in.
361            let store = match size {
362                DiscriminantSize::Size1 => "i32.store8",
363                DiscriminantSize::Size2 => "i32.store16",
364                DiscriminantSize::Size4 => "i32.store",
365            };
366            uwriteln!(s, "({store} (local.get 0) (local.get 1))");
367            s.push_str("local.get 1\n");
368            s.push_str("br_table");
369            for i in 0..types.len() + 1 {
370                uwrite!(s, " {i}");
371            }
372            s.push_str("\nend\n");
373
374            // Store each payload individually while converting locals from
375            // their source types to the precise type necessary for this
376            // variant.
377            for ty in types {
378                if let Some(ty) = ty {
379                    let ty_lowered = ty.lowered();
380                    let mut locals = locals[1..].iter().zip(&lowered[1..]).zip(&ty_lowered).map(
381                        |((i, from), to)| FlatSource::LocalConvert {
382                            local: *i,
383                            from: *from,
384                            to: *to,
385                        },
386                    );
387                    ty.store_flat(s, "0", offset, &mut locals, helpers);
388                }
389                s.push_str("return\n");
390                s.push_str("end\n");
391            }
392
393            // Catch-all result which is for out-of-bounds discriminants.
394            s.push_str("unreachable\n");
395        };
396        match self {
397            Type::Bool
398            | Type::S8
399            | Type::U8
400            | Type::S16
401            | Type::U16
402            | Type::S32
403            | Type::U32
404            | Type::Char
405            | Type::S64
406            | Type::U64
407            | Type::Float32
408            | Type::Float64
409            | Type::String
410            | Type::List(_)
411            | Type::Map(_, _)
412            | Type::Flags(_)
413            | Type::Enum(_) => unreachable!(),
414
415            Type::Record(r) => record(s, helpers, r),
416            Type::Tuple(t) => record(s, helpers, t),
417            Type::Variant(v) => variant(
418                s,
419                helpers,
420                &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
421            ),
422            Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
423            Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
424        };
425        s.push_str(")\n");
426    }
427
428    /// Same as `store_flat`, except loads the flat values from `ptr+offset`.
429    ///
430    /// Results are placed directly on the wasm stack.
431    fn load_flat<'a>(
432        &'a self,
433        s: &mut String,
434        ptr: &str,
435        offset: u32,
436        helpers: &mut IndexSet<Helper<'a>>,
437    ) {
438        enum Kind {
439            Primitive(&'static str),
440            PointerPair,
441            Helper,
442        }
443        let kind = match self {
444            Type::Bool | Type::U8 => Kind::Primitive("i32.load8_u"),
445            Type::S8 => Kind::Primitive("i32.load8_s"),
446            Type::U16 => Kind::Primitive("i32.load16_u"),
447            Type::S16 => Kind::Primitive("i32.load16_s"),
448            Type::U32 | Type::S32 | Type::Char => Kind::Primitive("i32.load"),
449            Type::U64 | Type::S64 => Kind::Primitive("i64.load"),
450            Type::Float32 => Kind::Primitive("f32.load"),
451            Type::Float64 => Kind::Primitive("f64.load"),
452            Type::String | Type::List(_) | Type::Map(_, _) => Kind::PointerPair,
453            Type::Enum(n) if *n <= (1 << 8) => Kind::Primitive("i32.load8_u"),
454            Type::Enum(n) if *n <= (1 << 16) => Kind::Primitive("i32.load16_u"),
455            Type::Enum(_) => Kind::Primitive("i32.load"),
456            Type::Flags(n) if *n <= 8 => Kind::Primitive("i32.load8_u"),
457            Type::Flags(n) if *n <= 16 => Kind::Primitive("i32.load16_u"),
458            Type::Flags(n) if *n <= 32 => Kind::Primitive("i32.load"),
459            Type::Flags(_) => unreachable!(),
460
461            Type::Record(_)
462            | Type::Tuple(_)
463            | Type::Variant(_)
464            | Type::Option(_)
465            | Type::Result { .. } => Kind::Helper,
466        };
467        match kind {
468            Kind::Primitive(op) => uwriteln!(s, "({op} offset={offset} (local.get {ptr}))"),
469            Kind::PointerPair => {
470                uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
471                let offset = offset + 4;
472                uwriteln!(s, "(i32.load offset={offset} (local.get {ptr}))",);
473            }
474            Kind::Helper => {
475                let (index, _) = helpers.insert_full(Helper(self));
476                uwriteln!(s, "(i32.add (local.get {ptr}) (i32.const {offset}))");
477                uwriteln!(s, "call $load_helper_{index}");
478            }
479        }
480    }
481
482    /// Same as `store_flat_helper` but for loading the flat representation.
483    fn load_flat_helper<'a>(
484        &'a self,
485        s: &mut String,
486        i: usize,
487        helpers: &mut IndexSet<Helper<'a>>,
488    ) {
489        uwrite!(s, "(func $load_helper_{i} (param i32)");
490        let lowered = self.lowered();
491        for ty in &lowered {
492            uwrite!(s, " (result {ty})");
493        }
494        s.push_str("\n");
495        let record = |s: &mut String, helpers: &mut IndexSet<Helper<'a>>, types: &'a [Type]| {
496            for (offset, ty) in record_field_offsets(types) {
497                ty.load_flat(s, "0", offset, helpers);
498            }
499        };
500        let variant = |s: &mut String,
501                       helpers: &mut IndexSet<Helper<'a>>,
502                       types: &[Option<&'a Type>]| {
503            let (size, offset) = variant_memory_info(types.iter().cloned());
504
505            // Destination locals where the flat representation will be stored.
506            // These are automatically zero which handles unused fields too.
507            for (i, ty) in lowered.iter().enumerate() {
508                uwriteln!(s, " (local $r{i} {ty})");
509            }
510
511            // Return block each case jumps to after setting all locals.
512            s.push_str("block $r\n");
513
514            // One extra block for "out of bounds discriminant".
515            for _ in 0..types.len() + 1 {
516                s.push_str("block\n");
517            }
518
519            // Load the discriminant and branch on it, storing it in
520            // `$r0` as well which is the first flat local representation.
521            let load = match size {
522                DiscriminantSize::Size1 => "i32.load8_u",
523                DiscriminantSize::Size2 => "i32.load16",
524                DiscriminantSize::Size4 => "i32.load",
525            };
526            uwriteln!(s, "({load} (local.get 0))");
527            s.push_str("local.tee $r0\n");
528            s.push_str("br_table");
529            for i in 0..types.len() + 1 {
530                uwrite!(s, " {i}");
531            }
532            s.push_str("\nend\n");
533
534            // For each payload, which is in its own block, load payloads from
535            // memory as necessary and convert them into the final locals.
536            for ty in types {
537                if let Some(ty) = ty {
538                    let ty_lowered = ty.lowered();
539                    ty.load_flat(s, "0", offset, helpers);
540                    for (i, (from, to)) in ty_lowered.iter().zip(&lowered[1..]).enumerate().rev() {
541                        let i = i + 1;
542                        match (from, to) {
543                            (CoreType::F32, CoreType::I32) => {
544                                s.push_str("i32.reinterpret_f32\n");
545                            }
546                            (CoreType::I32, CoreType::I64) => {
547                                s.push_str("i64.extend_i32_u\n");
548                            }
549                            (CoreType::F32, CoreType::I64) => {
550                                s.push_str("i32.reinterpret_f32\n");
551                                s.push_str("i64.extend_i32_u\n");
552                            }
553                            (CoreType::F64, CoreType::I64) => {
554                                s.push_str("i64.reinterpret_f64\n");
555                            }
556                            (a, b) if a == b => {}
557                            _ => unimplemented!("convert {from:?} to {to:?}"),
558                        }
559                        uwriteln!(s, "local.set $r{i}");
560                    }
561                }
562                s.push_str("br $r\n");
563                s.push_str("end\n");
564            }
565
566            // The catch-all block for out-of-bounds discriminants.
567            s.push_str("unreachable\n");
568            s.push_str("end\n");
569            for i in 0..lowered.len() {
570                uwriteln!(s, " local.get $r{i}");
571            }
572        };
573
574        match self {
575            Type::Bool
576            | Type::S8
577            | Type::U8
578            | Type::S16
579            | Type::U16
580            | Type::S32
581            | Type::U32
582            | Type::Char
583            | Type::S64
584            | Type::U64
585            | Type::Float32
586            | Type::Float64
587            | Type::String
588            | Type::List(_)
589            | Type::Map(_, _)
590            | Type::Flags(_)
591            | Type::Enum(_) => unreachable!(),
592
593            Type::Record(r) => record(s, helpers, r),
594            Type::Tuple(t) => record(s, helpers, t),
595            Type::Variant(v) => variant(
596                s,
597                helpers,
598                &v.iter().map(|t| t.as_ref()).collect::<Vec<_>>(),
599            ),
600            Type::Option(o) => variant(s, helpers, &[None, Some(&**o)]),
601            Type::Result { ok, err } => variant(s, helpers, &[ok.as_deref(), err.as_deref()]),
602        };
603        s.push_str(")\n");
604    }
605}
606
607#[derive(Clone)]
608enum FlatSource {
609    Local(u32),
610    LocalConvert {
611        local: u32,
612        from: CoreType,
613        to: CoreType,
614    },
615}
616
617impl fmt::Display for FlatSource {
618    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
619        match self {
620            FlatSource::Local(i) => write!(f, "(local.get {i})"),
621            FlatSource::LocalConvert { local, from, to } => {
622                match (from, to) {
623                    (a, b) if a == b => write!(f, "(local.get {local})"),
624                    (CoreType::I32, CoreType::F32) => {
625                        write!(f, "(f32.reinterpret_i32 (local.get {local}))")
626                    }
627                    (CoreType::I64, CoreType::I32) => {
628                        write!(f, "(i32.wrap_i64 (local.get {local}))")
629                    }
630                    (CoreType::I64, CoreType::F64) => {
631                        write!(f, "(f64.reinterpret_i64 (local.get {local}))")
632                    }
633                    (CoreType::I64, CoreType::F32) => {
634                        write!(
635                            f,
636                            "(f32.reinterpret_i32 (i32.wrap_i64 (local.get {local})))"
637                        )
638                    }
639                    _ => unimplemented!("convert {from:?} to {to:?}"),
640                }
641                // ..
642            }
643        }
644    }
645}
646
647fn lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>) {
648    for ty in types {
649        ty.lower(vec);
650    }
651}
652
653fn lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>) {
654    vec.push(CoreType::I32);
655    let offset = vec.len();
656    for ty in types {
657        let ty = match ty {
658            Some(ty) => ty,
659            None => continue,
660        };
661        for (index, ty) in ty.lowered().iter().enumerate() {
662            let index = offset + index;
663            if index < vec.len() {
664                vec[index] = vec[index].join(*ty);
665            } else {
666                vec.push(*ty)
667            }
668        }
669    }
670}
671
672fn u32_count_from_flag_count(count: usize) -> usize {
673    match FlagsSize::from_count(count) {
674        FlagsSize::Size0 => 0,
675        FlagsSize::Size1 | FlagsSize::Size2 => 1,
676        FlagsSize::Size4Plus(n) => n.into(),
677    }
678}
679
/// Byte size and alignment of a type's in-memory representation.
struct SizeAndAlignment {
    /// Number of bytes the value occupies.
    size: usize,
    /// Required alignment in bytes; expected to be a power of two, as
    /// `align_to` requires.
    alignment: u32,
}
684
685impl Type {
686    fn lowered(&self) -> Vec<CoreType> {
687        let mut vec = Vec::new();
688        self.lower(&mut vec);
689        vec
690    }
691
692    fn lower(&self, vec: &mut Vec<CoreType>) {
693        match self {
694            Type::Bool
695            | Type::U8
696            | Type::S8
697            | Type::S16
698            | Type::U16
699            | Type::S32
700            | Type::U32
701            | Type::Char
702            | Type::Enum(_) => vec.push(CoreType::I32),
703            Type::S64 | Type::U64 => vec.push(CoreType::I64),
704            Type::Float32 => vec.push(CoreType::F32),
705            Type::Float64 => vec.push(CoreType::F64),
706            Type::String | Type::List(_) | Type::Map(_, _) => {
707                vec.push(CoreType::I32);
708                vec.push(CoreType::I32);
709            }
710            Type::Record(types) => lower_record(types.iter(), vec),
711            Type::Tuple(types) => lower_record(types.0.iter(), vec),
712            Type::Variant(types) => lower_variant(types.0.iter().map(|t| t.as_ref()), vec),
713            Type::Option(ty) => lower_variant([None, Some(&**ty)].into_iter(), vec),
714            Type::Result { ok, err } => {
715                lower_variant([ok.as_deref(), err.as_deref()].into_iter(), vec)
716            }
717            Type::Flags(count) => vec.extend(
718                iter::repeat(CoreType::I32).take(u32_count_from_flag_count(*count as usize)),
719            ),
720        }
721    }
722
723    fn size_and_alignment(&self) -> SizeAndAlignment {
724        match self {
725            Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
726                size: 1,
727                alignment: 1,
728            },
729
730            Type::S16 | Type::U16 => SizeAndAlignment {
731                size: 2,
732                alignment: 2,
733            },
734
735            Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
736                size: 4,
737                alignment: 4,
738            },
739
740            Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
741                size: 8,
742                alignment: 8,
743            },
744
745            Type::String | Type::List(_) | Type::Map(_, _) => SizeAndAlignment {
746                size: 8,
747                alignment: 4,
748            },
749
750            Type::Record(types) => record_size_and_alignment(types.iter()),
751
752            Type::Tuple(types) => record_size_and_alignment(types.0.iter()),
753
754            Type::Variant(types) => variant_size_and_alignment(types.0.iter().map(|t| t.as_ref())),
755
756            Type::Enum(count) => variant_size_and_alignment((0..*count).map(|_| None)),
757
758            Type::Option(ty) => variant_size_and_alignment([None, Some(&**ty)].into_iter()),
759
760            Type::Result { ok, err } => {
761                variant_size_and_alignment([ok.as_deref(), err.as_deref()].into_iter())
762            }
763
764            Type::Flags(count) => match FlagsSize::from_count(*count as usize) {
765                FlagsSize::Size0 => SizeAndAlignment {
766                    size: 0,
767                    alignment: 1,
768                },
769                FlagsSize::Size1 => SizeAndAlignment {
770                    size: 1,
771                    alignment: 1,
772                },
773                FlagsSize::Size2 => SizeAndAlignment {
774                    size: 2,
775                    alignment: 2,
776                },
777                FlagsSize::Size4Plus(n) => SizeAndAlignment {
778                    size: usize::from(n) * 4,
779                    alignment: 4,
780                },
781            },
782        }
783    }
784}
785
/// Rounds `a` up to the next multiple of `align`, which must be a power of two.
fn align_to(a: usize, align: u32) -> usize {
    let mask = align as usize - 1;
    (a + mask) & !mask
}
790
791fn record_field_offsets<'a>(
792    types: impl IntoIterator<Item = &'a Type>,
793) -> impl Iterator<Item = (u32, &'a Type)> {
794    let mut offset = 0;
795    types.into_iter().map(move |ty| {
796        let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
797        let ret = align_to(offset, alignment);
798        offset = ret + size;
799        (ret as u32, ty)
800    })
801}
802
803fn record_size_and_alignment<'a>(types: impl IntoIterator<Item = &'a Type>) -> SizeAndAlignment {
804    let mut offset = 0;
805    let mut align = 1;
806    for ty in types {
807        let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
808        offset = align_to(offset, alignment) + size;
809        align = align.max(alignment);
810    }
811
812    SizeAndAlignment {
813        size: align_to(offset, align),
814        alignment: align,
815    }
816}
817
818fn variant_size_and_alignment<'a>(
819    types: impl ExactSizeIterator<Item = Option<&'a Type>>,
820) -> SizeAndAlignment {
821    let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
822    let mut alignment = u32::from(discriminant_size);
823    let mut size = 0;
824    for ty in types {
825        if let Some(ty) = ty {
826            let size_and_alignment = ty.size_and_alignment();
827            alignment = alignment.max(size_and_alignment.alignment);
828            size = size.max(size_and_alignment.size);
829        }
830    }
831
832    SizeAndAlignment {
833        size: align_to(
834            align_to(usize::from(discriminant_size), alignment) + size,
835            alignment,
836        ),
837        alignment,
838    }
839}
840
841fn variant_memory_info<'a>(
842    types: impl ExactSizeIterator<Item = Option<&'a Type>>,
843) -> (DiscriminantSize, u32) {
844    let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
845    let mut alignment = u32::from(discriminant_size);
846    for ty in types {
847        if let Some(ty) = ty {
848            let size_and_alignment = ty.size_and_alignment();
849            alignment = alignment.max(size_and_alignment.alignment);
850        }
851    }
852
853    (
854        discriminant_size,
855        align_to(usize::from(discriminant_size), alignment) as u32,
856    )
857}
858
859/// Generates the internals of a core wasm module which imports a single
860/// component function `IMPORT_FUNCTION` and exports a single component
861/// function `EXPORT_FUNCTION`.
862///
863/// The component function takes `params` as arguments and optionally returns
864/// `result`. The `lift_abi` and `lower_abi` fields indicate the ABI in-use for
865/// this operation.
866fn make_import_and_export(
867    params: &[&Type],
868    result: Option<&Type>,
869    lift_abi: LiftAbi,
870    lower_abi: LowerAbi,
871) -> String {
872    let params_lowered = params
873        .iter()
874        .flat_map(|ty| ty.lowered())
875        .collect::<Box<[_]>>();
876    let result_lowered = result.map(|t| t.lowered()).unwrap_or(Vec::new());
877
878    let mut wat = String::new();
879
880    enum Location {
881        Flat,
882        Indirect(u32),
883    }
884
885    // Generate the core wasm type corresponding to the imported function being
886    // lowered with `lower_abi`.
887    wat.push_str(&format!("(type $import (func"));
888    let max_import_params = match lower_abi {
889        LowerAbi::Sync => MAX_FLAT_PARAMS,
890        LowerAbi::Async => MAX_FLAT_ASYNC_PARAMS,
891    };
892    let (import_params_loc, nparams) = push_params(&mut wat, &params_lowered, max_import_params);
893    let import_results_loc = match lower_abi {
894        LowerAbi::Sync => {
895            push_result_or_retptr(&mut wat, &result_lowered, nparams, MAX_FLAT_RESULTS)
896        }
897        LowerAbi::Async => {
898            let loc = if result.is_none() {
899                Location::Flat
900            } else {
901                wat.push_str(" (param i32)"); // result pointer
902                Location::Indirect(nparams)
903            };
904            wat.push_str(" (result i32)"); // status code
905            loc
906        }
907    };
908    wat.push_str("))\n");
909
910    // Generate the import function.
911    wat.push_str(&format!(
912        r#"(import "host" "{IMPORT_FUNCTION}" (func $host (type $import)))"#
913    ));
914
915    // Do the same as above for the exported function's type which is lifted
916    // with `lift_abi`.
917    //
918    // Note that `export_results_loc` being `None` means that `task.return` is
919    // used to communicate results.
920    wat.push_str(&format!("(type $export (func"));
921    let (export_params_loc, _nparams) = push_params(&mut wat, &params_lowered, MAX_FLAT_PARAMS);
922    let export_results_loc = match lift_abi {
923        LiftAbi::Sync => Some(push_group(&mut wat, "result", &result_lowered, MAX_FLAT_RESULTS).0),
924        LiftAbi::AsyncCallback => {
925            wat.push_str(" (result i32)"); // status code
926            None
927        }
928        LiftAbi::AsyncStackful => None,
929    };
930    wat.push_str("))\n");
931
932    // If the export is async, generate `task.return` as an import as well
933    // which is necesary to communicate the results.
934    if export_results_loc.is_none() {
935        wat.push_str(&format!("(type $task.return (func"));
936        push_params(&mut wat, &result_lowered, MAX_FLAT_PARAMS);
937        wat.push_str("))\n");
938        wat.push_str(&format!(
939            r#"(import "" "task.return" (func $task.return (type $task.return)))"#
940        ));
941    }
942
943    wat.push_str(&format!(
944        r#"
945(func (export "{EXPORT_FUNCTION}") (type $export)
946    (local $retptr i32)
947    (local $argptr i32)
948        "#
949    ));
950    let mut store_helpers = IndexSet::new();
951    let mut load_helpers = IndexSet::new();
952
953    match (export_params_loc, import_params_loc) {
954        // flat => flat is just moving locals around
955        (Location::Flat, Location::Flat) => {
956            for (index, _) in params_lowered.iter().enumerate() {
957                uwrite!(wat, "local.get {index}\n");
958            }
959        }
960
961        // indirect => indirect is just moving locals around
962        (Location::Indirect(i), Location::Indirect(j)) => {
963            assert_eq!(j, 0);
964            uwrite!(wat, "local.get {i}\n");
965        }
966
967        // flat => indirect means that all parameters are stored in memory as
968        // if it was a record of all the parameters.
969        (Location::Flat, Location::Indirect(_)) => {
970            let SizeAndAlignment { size, alignment } =
971                record_size_and_alignment(params.iter().cloned());
972            wat.push_str(&format!(
973                r#"
974                    (local.set $argptr
975                        (call $realloc
976                            (i32.const 0)
977                            (i32.const 0)
978                            (i32.const {alignment})
979                            (i32.const {size})))
980                    local.get $argptr
981                "#
982            ));
983            let mut locals = (0..params_lowered.len() as u32).map(FlatSource::Local);
984            for (offset, ty) in record_field_offsets(params.iter().cloned()) {
985                ty.store_flat(&mut wat, "$argptr", offset, &mut locals, &mut store_helpers);
986            }
987            assert!(locals.next().is_none());
988        }
989
990        (Location::Indirect(_), Location::Flat) => unreachable!(),
991    }
992
993    // Pass a return-pointer if necessary.
994    match import_results_loc {
995        Location::Flat => {}
996        Location::Indirect(_) => {
997            let SizeAndAlignment { size, alignment } = result.unwrap().size_and_alignment();
998
999            wat.push_str(&format!(
1000                r#"
1001                    (local.set $retptr
1002                        (call $realloc
1003                            (i32.const 0)
1004                            (i32.const 0)
1005                            (i32.const {alignment})
1006                            (i32.const {size})))
1007                    local.get $retptr
1008                "#
1009            ));
1010        }
1011    }
1012
1013    wat.push_str("call $host\n");
1014
1015    // Assert the lowered call is ready if an async code was returned.
1016    //
1017    // TODO: handle when the import isn't ready yet
1018    if let LowerAbi::Async = lower_abi {
1019        wat.push_str("i32.const 2\n");
1020        wat.push_str("i32.ne\n");
1021        wat.push_str("if unreachable end\n");
1022    }
1023
1024    // TODO: conditionally inject a yield here
1025
1026    match (import_results_loc, export_results_loc) {
1027        // flat => flat results involves nothing, the results are already on
1028        // the stack.
1029        (Location::Flat, Some(Location::Flat)) => {}
1030
1031        // indirect => indirect results requires returning the `$retptr` the
1032        // host call filled in.
1033        (Location::Indirect(_), Some(Location::Indirect(_))) => {
1034            wat.push_str("local.get $retptr\n");
1035        }
1036
1037        // indirect => flat requires loading the result from the return pointer
1038        (Location::Indirect(_), Some(Location::Flat)) => {
1039            result
1040                .unwrap()
1041                .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1042        }
1043
1044        // flat => task.return is easy, the results are already there so just
1045        // call the function.
1046        (Location::Flat, None) => {
1047            wat.push_str("call $task.return\n");
1048        }
1049
1050        // indirect => task.return needs to forward `$retptr` if the results
1051        // are indirect, or otherwise it must be loaded from memory to a flat
1052        // representation.
1053        (Location::Indirect(_), None) => {
1054            if result_lowered.len() <= MAX_FLAT_PARAMS {
1055                result
1056                    .unwrap()
1057                    .load_flat(&mut wat, "$retptr", 0, &mut load_helpers);
1058            } else {
1059                wat.push_str("local.get $retptr\n");
1060            }
1061            wat.push_str("call $task.return\n");
1062        }
1063
1064        (Location::Flat, Some(Location::Indirect(_))) => unreachable!(),
1065    }
1066
1067    if let LiftAbi::AsyncCallback = lift_abi {
1068        wat.push_str("i32.const 0\n"); // completed status code
1069    }
1070
1071    wat.push_str(")\n");
1072
1073    // Generate a `callback` function for the callback ABI.
1074    //
1075    // TODO: fill this in
1076    if let LiftAbi::AsyncCallback = lift_abi {
1077        wat.push_str(
1078            r#"
1079(func (export "callback") (param i32 i32 i32) (result i32) unreachable)
1080            "#,
1081        );
1082    }
1083
1084    // Fill out all store/load helpers that were needed during generation
1085    // above. This is a fix-point-loop since each helper may end up requiring
1086    // more helpers.
1087    let mut i = 0;
1088    while i < store_helpers.len() {
1089        let ty = store_helpers[i].0;
1090        ty.store_flat_helper(&mut wat, i, &mut store_helpers);
1091        i += 1;
1092    }
1093    i = 0;
1094    while i < load_helpers.len() {
1095        let ty = load_helpers[i].0;
1096        ty.load_flat_helper(&mut wat, i, &mut load_helpers);
1097        i += 1;
1098    }
1099
1100    return wat;
1101
1102    fn push_params(wat: &mut String, params: &[CoreType], max_flat: usize) -> (Location, u32) {
1103        push_group(wat, "param", params, max_flat)
1104    }
1105
1106    fn push_group(
1107        wat: &mut String,
1108        name: &str,
1109        params: &[CoreType],
1110        max_flat: usize,
1111    ) -> (Location, u32) {
1112        let mut nparams = 0;
1113        let loc = if params.is_empty() {
1114            // nothing to emit...
1115            Location::Flat
1116        } else if params.len() <= max_flat {
1117            wat.push_str(&format!(" ({name}"));
1118            for ty in params {
1119                wat.push_str(&format!(" {ty}"));
1120                nparams += 1;
1121            }
1122            wat.push_str(")");
1123            Location::Flat
1124        } else {
1125            wat.push_str(&format!(" ({name} i32)"));
1126            nparams += 1;
1127            Location::Indirect(0)
1128        };
1129        (loc, nparams)
1130    }
1131
1132    fn push_result_or_retptr(
1133        wat: &mut String,
1134        results: &[CoreType],
1135        nparams: u32,
1136        max_flat: usize,
1137    ) -> Location {
1138        if results.is_empty() {
1139            // nothing to emit...
1140            Location::Flat
1141        } else if results.len() <= max_flat {
1142            wat.push_str(" (result");
1143            for ty in results {
1144                wat.push_str(&format!(" {ty}"));
1145            }
1146            wat.push_str(")");
1147            Location::Flat
1148        } else {
1149            wat.push_str(" (param i32)");
1150            Location::Indirect(nparams)
1151        }
1152    }
1153}
1154
1155struct Helper<'a>(&'a Type);
1156
1157impl Hash for Helper<'_> {
1158    fn hash<H: Hasher>(&self, h: &mut H) {
1159        std::ptr::hash(self.0, h);
1160    }
1161}
1162
1163impl PartialEq for Helper<'_> {
1164    fn eq(&self, other: &Self) -> bool {
1165        std::ptr::eq(self.0, other.0)
1166    }
1167}
1168
1169impl Eq for Helper<'_> {}
1170
1171fn make_rust_name(name_counter: &mut u32) -> Ident {
1172    let name = format_ident!("Foo{name_counter}");
1173    *name_counter += 1;
1174    name
1175}
1176
/// Generate a [`TokenStream`] containing the rust type name for a type.
///
/// The `name_counter` parameter is used to generate names for each recursively visited type.  The `declarations`
/// parameter is used to accumulate declarations for each recursively visited type.
///
/// Primitives, `string`, `list`, `map`, `tuple`, `option`, and `result` map onto existing Rust types (`Box<str>`,
/// `Vec`, `std::collections::HashMap`, tuples, `Option`, `Result`). `record`, `variant`, `enum`, and `flags` instead
/// get a freshly named `Foo{n}` item appended to `declarations`, and the returned tokens name that item.
///
/// For records and variants, nested field/payload types are visited (and named) before the enclosing type calls
/// `make_rust_name`, so inner types receive smaller numbers than the type containing them.
pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream {
    match ty {
        Type::Bool => quote!(bool),
        Type::S8 => quote!(i8),
        Type::U8 => quote!(u8),
        Type::S16 => quote!(i16),
        Type::U16 => quote!(u16),
        Type::S32 => quote!(i32),
        Type::U32 => quote!(u32),
        Type::S64 => quote!(i64),
        Type::U64 => quote!(u64),
        // NOTE(review): `Float32`/`Float64` are not defined in this function; presumably the code that consumes
        // these tokens brings wrapper types with those names into scope.
        Type::Float32 => quote!(Float32),
        Type::Float64 => quote!(Float64),
        Type::Char => quote!(char),
        Type::String => quote!(Box<str>),
        Type::List(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Vec<#ty>)
        }
        Type::Map(key_ty, value_ty) => {
            let key_ty = rust_type(key_ty, name_counter, declarations);
            let value_ty = rust_type(value_ty, name_counter, declarations);
            quote!(std::collections::HashMap<#key_ty, #value_ty>)
        }
        Type::Record(types) => {
            // Fields are named positionally (`f0`, `f1`, ...) to match the `(field "f{index}" ...)` names that the
            // WAT type declarations use.
            let fields = types
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("f{index}");
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#name: #ty,)
                })
                .collect::<TokenStream>();

            // Allocated only after all field types have been visited (post-order naming).
            let name = make_rust_name(name_counter);

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(record)]
                struct #name {
                    #fields
                }
            });

            quote!(#name)
        }
        Type::Tuple(types) => {
            // Each element is followed by a comma, so a one-element tuple correctly becomes `(T,)`.
            let fields = types
                .0
                .iter()
                .map(|ty| {
                    let ty = rust_type(ty, name_counter, declarations);
                    quote!(#ty,)
                })
                .collect::<TokenStream>();

            quote!((#fields))
        }
        Type::Variant(types) => {
            // Cases are named positionally (`C0`, `C1`, ...); payload-less cases become unit variants.
            let cases = types
                .0
                .iter()
                .enumerate()
                .map(|(index, ty)| {
                    let name = format_ident!("C{index}");
                    let ty = match ty {
                        Some(ty) => {
                            let ty = rust_type(ty, name_counter, declarations);
                            quote!((#ty))
                        }
                        None => quote!(),
                    };
                    quote!(#name #ty,)
                })
                .collect::<TokenStream>();

            // Allocated only after all payload types have been visited (post-order naming).
            let name = make_rust_name(name_counter);
            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
                #[component(variant)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Enum(count) => {
            let cases = (0..*count)
                .map(|index| {
                    let name = format_ident!("E{index}");
                    quote!(#name,)
                })
                .collect::<TokenStream>();

            let name = make_rust_name(name_counter);
            // Pick the Rust `#[repr]` to match the discriminant width `DiscriminantSize::from_count` reports for
            // this number of cases.
            let repr = match DiscriminantSize::from_count(*count as usize).unwrap() {
                DiscriminantSize::Size1 => quote!(u8),
                DiscriminantSize::Size2 => quote!(u16),
                DiscriminantSize::Size4 => quote!(u32),
            };

            declarations.extend(quote! {
                #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Hash, Debug, Copy, Clone, Arbitrary)]
                #[component(enum)]
                #[repr(#repr)]
                enum #name {
                    #cases
                }
            });

            quote!(#name)
        }
        Type::Option(ty) => {
            let ty = rust_type(ty, name_counter, declarations);
            quote!(Option<#ty>)
        }
        Type::Result { ok, err } => {
            // A missing `ok`/`err` payload is represented by the unit type.
            let ok = match ok {
                Some(ok) => rust_type(ok, name_counter, declarations),
                None => quote!(()),
            };
            let err = match err {
                Some(err) => rust_type(err, name_counter, declarations),
                None => quote!(()),
            };
            quote!(Result<#ok, #err>)
        }
        Type::Flags(count) => {
            let type_name = make_rust_name(name_counter);

            // `flags` holds the `const F{i};` declarations for the `flags!` macro; `names` holds the matching
            // `TypeName::F{i},` paths used by the generated `Arbitrary` impl.
            let mut flags = TokenStream::new();
            let mut names = TokenStream::new();

            for index in 0..*count {
                let name = format_ident!("F{index}");
                flags.extend(quote!(const #name;));
                names.extend(quote!(#type_name::#name,))
            }

            // The `flags!` type gets a hand-written `Arbitrary` impl that sets each flag independently, based on one
            // arbitrary `bool` per flag.
            declarations.extend(quote! {
                wasmtime::component::flags! {
                    #type_name {
                        #flags
                    }
                }

                impl<'a> arbitrary::Arbitrary<'a> for #type_name {
                    fn arbitrary(input: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
                        let mut flags = #type_name::default();
                        for flag in [#names] {
                            if input.arbitrary()? {
                                flags |= flag;
                            }
                        }
                        Ok(flags)
                    }
                }
            });

            quote!(#type_name)
        }
    }
}
1346
/// Accumulates the WAT type declarations needed to reference a set of
/// [`Type`]s by name.
///
/// Compound types are assigned sequential `$t{idx}` names by `write_ref`, and
/// the type itself is queued so its declaration can be produced later with
/// `write_decl`.
#[derive(Default)]
struct TypesBuilder<'a> {
    /// Index that `write_ref` will assign to the next compound type it sees.
    next: u32,
    /// Compound types that have been referenced as `$t{idx}`, paired with that
    /// index. Presumably the caller drains this and emits each entry via
    /// `write_decl` (TODO confirm at the call site).
    worklist: Vec<(u32, &'a Type)>,
}
1352
1353impl<'a> TypesBuilder<'a> {
1354    fn write_ref(&mut self, ty: &'a Type, dst: &mut String) {
1355        match ty {
1356            // Primitive types can be referenced directly
1357            Type::Bool => dst.push_str("bool"),
1358            Type::S8 => dst.push_str("s8"),
1359            Type::U8 => dst.push_str("u8"),
1360            Type::S16 => dst.push_str("s16"),
1361            Type::U16 => dst.push_str("u16"),
1362            Type::S32 => dst.push_str("s32"),
1363            Type::U32 => dst.push_str("u32"),
1364            Type::S64 => dst.push_str("s64"),
1365            Type::U64 => dst.push_str("u64"),
1366            Type::Float32 => dst.push_str("float32"),
1367            Type::Float64 => dst.push_str("float64"),
1368            Type::Char => dst.push_str("char"),
1369            Type::String => dst.push_str("string"),
1370
1371            // Otherwise emit a reference to the type and remember to generate
1372            // the corresponding type alias later.
1373            Type::List(_)
1374            | Type::Map(_, _)
1375            | Type::Record(_)
1376            | Type::Tuple(_)
1377            | Type::Variant(_)
1378            | Type::Enum(_)
1379            | Type::Option(_)
1380            | Type::Result { .. }
1381            | Type::Flags(_) => {
1382                let idx = self.next;
1383                self.next += 1;
1384                uwrite!(dst, "$t{idx}");
1385                self.worklist.push((idx, ty));
1386            }
1387        }
1388    }
1389
1390    fn write_decl(&mut self, idx: u32, ty: &'a Type) -> String {
1391        let mut decl = format!("(type $t{idx}' ");
1392        match ty {
1393            Type::Bool
1394            | Type::S8
1395            | Type::U8
1396            | Type::S16
1397            | Type::U16
1398            | Type::S32
1399            | Type::U32
1400            | Type::S64
1401            | Type::U64
1402            | Type::Float32
1403            | Type::Float64
1404            | Type::Char
1405            | Type::String => unreachable!(),
1406
1407            Type::List(ty) => {
1408                decl.push_str("(list ");
1409                self.write_ref(ty, &mut decl);
1410                decl.push_str(")");
1411            }
1412            Type::Map(key_ty, value_ty) => {
1413                decl.push_str("(map ");
1414                self.write_ref(key_ty, &mut decl);
1415                decl.push_str(" ");
1416                self.write_ref(value_ty, &mut decl);
1417                decl.push_str(")");
1418            }
1419            Type::Record(types) => {
1420                decl.push_str("(record");
1421                for (index, ty) in types.iter().enumerate() {
1422                    uwrite!(decl, r#" (field "f{index}" "#);
1423                    self.write_ref(ty, &mut decl);
1424                    decl.push_str(")");
1425                }
1426                decl.push_str(")");
1427            }
1428            Type::Tuple(types) => {
1429                decl.push_str("(tuple");
1430                for ty in types.iter() {
1431                    decl.push_str(" ");
1432                    self.write_ref(ty, &mut decl);
1433                }
1434                decl.push_str(")");
1435            }
1436            Type::Variant(types) => {
1437                decl.push_str("(variant");
1438                for (index, ty) in types.iter().enumerate() {
1439                    uwrite!(decl, r#" (case "C{index}""#);
1440                    if let Some(ty) = ty {
1441                        decl.push_str(" ");
1442                        self.write_ref(ty, &mut decl);
1443                    }
1444                    decl.push_str(")");
1445                }
1446                decl.push_str(")");
1447            }
1448            Type::Enum(count) => {
1449                decl.push_str("(enum");
1450                for index in 0..*count {
1451                    uwrite!(decl, r#" "E{index}""#);
1452                }
1453                decl.push_str(")");
1454            }
1455            Type::Option(ty) => {
1456                decl.push_str("(option ");
1457                self.write_ref(ty, &mut decl);
1458                decl.push_str(")");
1459            }
1460            Type::Result { ok, err } => {
1461                decl.push_str("(result");
1462                if let Some(ok) = ok {
1463                    decl.push_str(" ");
1464                    self.write_ref(ok, &mut decl);
1465                }
1466                if let Some(err) = err {
1467                    decl.push_str(" (error ");
1468                    self.write_ref(err, &mut decl);
1469                    decl.push_str(")");
1470                }
1471                decl.push_str(")");
1472            }
1473            Type::Flags(count) => {
1474                decl.push_str("(flags");
1475                for index in 0..*count {
1476                    uwrite!(decl, r#" "F{index}""#);
1477                }
1478                decl.push_str(")");
1479            }
1480        }
1481        decl.push_str(")\n");
1482        uwriteln!(decl, "(import \"t{idx}\" (type $t{idx} (eq $t{idx}')))");
1483        decl
1484    }
1485}
1486
/// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s
///
/// The fragments are spliced textually into the template built by `make_component`.
#[derive(Debug)]
pub struct Declarations {
    /// Type declarations (if any) referenced by `params` and/or `result`
    pub types: Cow<'static, str>,
    /// Types to thread through when instantiating sub-components.
    pub type_instantiation_args: Cow<'static, str>,
    /// Parameter declarations used for the imported and exported functions
    pub params: Cow<'static, str>,
    /// Result declaration used for the imported and exported functions
    pub results: Cow<'static, str>,
    /// Implementation of the "caller" component, which invokes the `callee`
    /// composed component.
    pub caller_module: Cow<'static, str>,
    /// Implementation of the "callee" component, which invokes the host.
    pub callee_module: Cow<'static, str>,
    /// Options used for caller/callee ABI/etc.
    pub options: TestCaseOptions,
}
1506
1507impl Declarations {
1508    /// Generate a complete WAT file based on the specified fragments.
1509    pub fn make_component(&self) -> Box<str> {
1510        let Self {
1511            types,
1512            type_instantiation_args,
1513            params,
1514            results,
1515            caller_module,
1516            callee_module,
1517            options,
1518        } = self;
1519        let mk_component = |name: &str,
1520                            module: &str,
1521                            import_async: bool,
1522                            export_async: bool,
1523                            encoding: StringEncoding,
1524                            lift_abi: LiftAbi,
1525                            lower_abi: LowerAbi| {
1526            let import_async = if import_async { "async" } else { "" };
1527            let export_async = if export_async { "async" } else { "" };
1528            let lower_async_option = match lower_abi {
1529                LowerAbi::Sync => "",
1530                LowerAbi::Async => "async",
1531            };
1532            let lift_async_option = match lift_abi {
1533                LiftAbi::Sync => "",
1534                LiftAbi::AsyncStackful => "async",
1535                LiftAbi::AsyncCallback => "async (callback (func $i \"callback\"))",
1536            };
1537
1538            let mut intrinsic_defs = String::new();
1539            let mut intrinsic_imports = String::new();
1540
1541            match lift_abi {
1542                LiftAbi::Sync => {}
1543                LiftAbi::AsyncCallback | LiftAbi::AsyncStackful => {
1544                    intrinsic_defs.push_str(&format!(
1545                        r#"
1546(core func $task.return (canon task.return {results}
1547    (memory $libc "memory") string-encoding={encoding}))
1548                        "#,
1549                    ));
1550                    intrinsic_imports.push_str(
1551                        r#"
1552(with "" (instance (export "task.return" (func $task.return))))
1553                        "#,
1554                    );
1555                }
1556            }
1557
1558            format!(
1559                r#"
1560(component ${name}
1561    {types}
1562    (type $import_sig (func {import_async} {params} {results}))
1563    (type $export_sig (func {export_async} {params} {results}))
1564    (import "{IMPORT_FUNCTION}" (func $f (type $import_sig)))
1565
1566    (core instance $libc (instantiate $libc))
1567
1568    (core func $f_lower (canon lower
1569        (func $f)
1570        (memory $libc "memory")
1571        (realloc (func $libc "realloc"))
1572        string-encoding={encoding}
1573        {lower_async_option}
1574    ))
1575
1576    {intrinsic_defs}
1577
1578    (core module $m
1579        (memory (import "libc" "memory") 1)
1580        (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32))
1581
1582        {module}
1583    )
1584
1585    (core instance $i (instantiate $m
1586        (with "libc" (instance $libc))
1587        (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower))))
1588        {intrinsic_imports}
1589    ))
1590
1591    (func (export "{EXPORT_FUNCTION}") (type $export_sig)
1592        (canon lift
1593            (core func $i "{EXPORT_FUNCTION}")
1594            (memory $libc "memory")
1595            (realloc (func $libc "realloc"))
1596            string-encoding={encoding}
1597            {lift_async_option}
1598        )
1599    )
1600)
1601            "#
1602            )
1603        };
1604
1605        let c1 = mk_component(
1606            "callee",
1607            &callee_module,
1608            options.host_async,
1609            options.guest_callee_async,
1610            options.callee_encoding,
1611            options.callee_lift_abi,
1612            options.callee_lower_abi,
1613        );
1614        let c2 = mk_component(
1615            "caller",
1616            &caller_module,
1617            options.guest_callee_async,
1618            options.guest_caller_async,
1619            options.caller_encoding,
1620            options.caller_lift_abi,
1621            options.caller_lower_abi,
1622        );
1623        let host_async = if options.host_async { "async" } else { "" };
1624
1625        format!(
1626            r#"
1627            (component
1628                (core module $libc
1629                    (memory (export "memory") 1)
1630                    {REALLOC_AND_FREE}
1631                )
1632
1633
1634                {types}
1635
1636                (type $host_sig (func {host_async} {params} {results}))
1637                (import "{IMPORT_FUNCTION}" (func $f (type $host_sig)))
1638
1639                {c1}
1640                {c2}
1641                (instance $c1 (instantiate $callee
1642                    {type_instantiation_args}
1643                    (with "{IMPORT_FUNCTION}" (func $f))
1644                ))
1645                (instance $c2 (instantiate $caller
1646                    {type_instantiation_args}
1647                    (with "{IMPORT_FUNCTION}" (func $c1 "{EXPORT_FUNCTION}"))
1648                ))
1649                (export "{EXPORT_FUNCTION}" (func $c2 "{EXPORT_FUNCTION}"))
1650            )"#,
1651        )
1652        .into()
1653    }
1654}
1655
/// Represents a test case for calling a component function
#[derive(Debug)]
pub struct TestCase<'a> {
    /// The types of parameters to pass to the function
    pub params: Vec<&'a Type>,
    /// The result type of the function, or `None` if it returns nothing
    pub result: Option<&'a Type>,
    /// ABI options to use for this test case.
    pub options: TestCaseOptions,
}
1666
/// Collection of options which configure how the caller/callee/etc ABIs are
/// all configured.
///
/// The composed call chain is: entrypoint export (caller component) -> callee
/// component -> host import. `TestCase::generate` adjusts arbitrary values so
/// that async-ness cascades toward the entrypoint:
/// - `host_async` forces `guest_callee_async` on;
/// - `guest_callee_async` forces `guest_caller_async` on.
#[derive(Debug, Arbitrary, Copy, Clone)]
pub struct TestCaseOptions {
    /// Whether or not the guest caller component (the entrypoint) is using an
    /// `async` function type.
    pub guest_caller_async: bool,
    /// Whether or not the guest callee component (what the entrypoint calls)
    /// is using an `async` function type.
    pub guest_callee_async: bool,
    /// Whether or not the host is using an async function type (what the
    /// guest callee calls).
    pub host_async: bool,
    /// The string encoding of the caller component.
    pub caller_encoding: StringEncoding,
    /// The string encoding of the callee component.
    pub callee_encoding: StringEncoding,
    /// The ABI that the caller component is using to lift its export (the main
    /// entrypoint).
    pub caller_lift_abi: LiftAbi,
    /// The ABI that the callee component is using to lift its export (called
    /// by the caller).
    pub callee_lift_abi: LiftAbi,
    /// The ABI that the caller component is using to lower its import (the
    /// callee's export).
    pub caller_lower_abi: LowerAbi,
    /// The ABI that the callee component is using to lower its import (the
    /// host function).
    pub callee_lower_abi: LowerAbi,
}
1697
1698#[derive(Debug, Arbitrary, Copy, Clone)]
1699pub enum LiftAbi {
1700    Sync,
1701    AsyncStackful,
1702    AsyncCallback,
1703}
1704
1705#[derive(Debug, Arbitrary, Copy, Clone)]
1706pub enum LowerAbi {
1707    Sync,
1708    Async,
1709}
1710
1711impl<'a> TestCase<'a> {
1712    pub fn generate(types: &'a [Type], u: &mut Unstructured<'_>) -> arbitrary::Result<Self> {
1713        let max_params = if types.len() > 0 { 5 } else { 0 };
1714        let params = (0..u.int_in_range(0..=max_params)?)
1715            .map(|_| u.choose(&types))
1716            .collect::<arbitrary::Result<Vec<_>>>()?;
1717        let result = if types.len() > 0 && u.arbitrary()? {
1718            Some(u.choose(&types)?)
1719        } else {
1720            None
1721        };
1722
1723        let mut options = u.arbitrary::<TestCaseOptions>()?;
1724
1725        // Sync tasks cannot call async functions via a sync lower, nor can they
1726        // block in other ways (e.g. by calling `waitable-set.wait`, returning
1727        // `CALLBACK_CODE_WAIT`, etc.) prior to returning.  Therefore,
1728        // async-ness cascades to the callers:
1729        if options.host_async {
1730            options.guest_callee_async = true;
1731        }
1732        if options.guest_callee_async {
1733            options.guest_caller_async = true;
1734        }
1735
1736        Ok(Self {
1737            params,
1738            result,
1739            options,
1740        })
1741    }
1742
1743    /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case.
1744    pub fn declarations(&self) -> Declarations {
1745        let mut builder = TypesBuilder::default();
1746
1747        let mut params = String::new();
1748        for (i, ty) in self.params.iter().enumerate() {
1749            params.push_str(&format!(" (param \"p{i}\" "));
1750            builder.write_ref(ty, &mut params);
1751            params.push_str(")");
1752        }
1753
1754        let mut results = String::new();
1755        if let Some(ty) = self.result {
1756            results.push_str(&format!(" (result "));
1757            builder.write_ref(ty, &mut results);
1758            results.push_str(")");
1759        }
1760
1761        let caller_module = make_import_and_export(
1762            &self.params,
1763            self.result,
1764            self.options.caller_lift_abi,
1765            self.options.caller_lower_abi,
1766        );
1767        let callee_module = make_import_and_export(
1768            &self.params,
1769            self.result,
1770            self.options.callee_lift_abi,
1771            self.options.callee_lower_abi,
1772        );
1773
1774        let mut type_decls = Vec::new();
1775        let mut type_instantiation_args = String::new();
1776        while let Some((idx, ty)) = builder.worklist.pop() {
1777            type_decls.push(builder.write_decl(idx, ty));
1778            uwriteln!(type_instantiation_args, "(with \"t{idx}\" (type $t{idx}))");
1779        }
1780
1781        // Note that types are printed here in reverse order since they were
1782        // pushed onto `type_decls` as they were referenced meaning the last one
1783        // is the "base" one.
1784        let mut types = String::new();
1785        for decl in type_decls.into_iter().rev() {
1786            types.push_str(&decl);
1787            types.push_str("\n");
1788        }
1789
1790        Declarations {
1791            types: types.into(),
1792            type_instantiation_args: type_instantiation_args.into(),
1793            params: params.into(),
1794            results: results.into(),
1795            caller_module: caller_module.into(),
1796            callee_module: callee_module.into(),
1797            options: self.options,
1798        }
1799    }
1800}
1801
1802#[derive(Copy, Clone, Debug, Arbitrary)]
1803pub enum StringEncoding {
1804    Utf8,
1805    Utf16,
1806    Latin1OrUtf16,
1807}
1808
1809impl fmt::Display for StringEncoding {
1810    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1811        match self {
1812            StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f),
1813            StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f),
1814            StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f),
1815        }
1816    }
1817}
1818
1819impl ToTokens for TestCaseOptions {
1820    fn to_tokens(&self, tokens: &mut TokenStream) {
1821        let TestCaseOptions {
1822            guest_caller_async,
1823            guest_callee_async,
1824            host_async,
1825            caller_encoding,
1826            callee_encoding,
1827            caller_lift_abi,
1828            callee_lift_abi,
1829            caller_lower_abi,
1830            callee_lower_abi,
1831        } = self;
1832        tokens.extend(quote!(wasmtime_test_util::component_fuzz::TestCaseOptions {
1833            guest_caller_async: #guest_caller_async,
1834            guest_callee_async: #guest_callee_async,
1835            host_async: #host_async,
1836            caller_encoding: #caller_encoding,
1837            callee_encoding: #callee_encoding,
1838            caller_lift_abi: #caller_lift_abi,
1839            callee_lift_abi: #callee_lift_abi,
1840            caller_lower_abi: #caller_lower_abi,
1841            callee_lower_abi: #callee_lower_abi,
1842        }));
1843    }
1844}
1845
1846impl ToTokens for LowerAbi {
1847    fn to_tokens(&self, tokens: &mut TokenStream) {
1848        let me = match self {
1849            LowerAbi::Sync => quote!(Sync),
1850            LowerAbi::Async => quote!(Async),
1851        };
1852        tokens.extend(quote!(wasmtime_test_util::component_fuzz::LowerAbi::#me));
1853    }
1854}
1855
1856impl ToTokens for LiftAbi {
1857    fn to_tokens(&self, tokens: &mut TokenStream) {
1858        let me = match self {
1859            LiftAbi::Sync => quote!(Sync),
1860            LiftAbi::AsyncCallback => quote!(AsyncCallback),
1861            LiftAbi::AsyncStackful => quote!(AsyncStackful),
1862        };
1863        tokens.extend(quote!(wasmtime_test_util::component_fuzz::LiftAbi::#me));
1864    }
1865}
1866
1867impl ToTokens for StringEncoding {
1868    fn to_tokens(&self, tokens: &mut TokenStream) {
1869        let me = match self {
1870            StringEncoding::Utf8 => quote!(Utf8),
1871            StringEncoding::Utf16 => quote!(Utf16),
1872            StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16),
1873        };
1874        tokens.extend(quote!(wasmtime_test_util::component_fuzz::StringEncoding::#me));
1875    }
1876}
1877
#[cfg(test)]
mod tests {
    use super::*;

    /// Text to show when validation fails. This is a disassembly of `wasm`
    /// with offsets and operand stacks, or `fallback` if printing fails.
    fn describe_invalid(wasm: &[u8], fallback: &str) -> String {
        let mut wat = String::new();
        let printed_ok = wasmprinter::Config::new()
            .print_offsets(true)
            .print_operand_stack(true)
            .print(wasm, &mut wasmprinter::PrintFmtWrite(&mut wat))
            .is_ok();
        if printed_ok { wat } else { fallback.to_string() }
    }

    /// Every generated test case must produce a component that parses as
    /// WAT and passes validation with all wasm features enabled.
    #[test]
    fn arbtest() {
        let check = |u: &mut arbitrary::Unstructured<'_>| {
            let mut fuel = 100;
            let mut types = Vec::new();
            for _ in 0..5 {
                types.push(Type::generate(u, 3, &mut fuel)?);
            }

            let case = TestCase::generate(&types, u)?;
            let component = case.declarations().make_component();

            let wasm = match wat::parse_str(&component) {
                Ok(bytes) => bytes,
                Err(e) => {
                    panic!("failed to parse generated component as wat: {e}\n\n{component}")
                }
            };

            let features = wasmparser::WasmFeatures::all();
            if let Err(e) = wasmparser::Validator::new_with_features(features).validate_all(&wasm)
            {
                let to_print = describe_invalid(&wasm, &component[..]);
                panic!("generated component is not valid wasm: {e}\n\n{to_print}");
            }
            Ok(())
        };

        // To reproduce a specific failure, chain e.g. `.seed(0x3c9050d4000000e9)`.
        arbtest::arbtest(check).budget_ms(1_000);
    }
}