component_fuzz_util/
lib.rs

1//! This module generates test cases for the Wasmtime component model function APIs,
2//! e.g. `wasmtime::component::func::Func` and `TypedFunc`.
3//!
4//! Each case includes a list of arbitrary interface types to use as parameters, plus another one to use as a
5//! result, and a component which exports a function and imports a function.  The exported function forwards its
6//! parameters to the imported one and forwards the result back to the caller.  This serves to exercise Wasmtime's
7//! lifting and lowering code and verify the values remain intact during both processes.
8
9use arbitrary::{Arbitrary, Unstructured};
10use proc_macro2::{Ident, TokenStream};
11use quote::{format_ident, quote, ToTokens};
12use std::borrow::Cow;
13use std::fmt::{self, Debug, Write};
14use std::iter;
15use std::ops::Deref;
16use wasmtime_component_util::{DiscriminantSize, FlagsSize, REALLOC_AND_FREE};
17
/// Maximum number of flattened core wasm parameters passed directly; when a
/// signature flattens to more than this, `make_import_and_export` passes a
/// single `i32` (a pointer) instead of the individual values.
const MAX_FLAT_PARAMS: usize = 16;
/// Maximum number of flattened core wasm results returned directly; when a
/// result flattens to more than this, `make_import_and_export` allocates a
/// return area and passes/returns its address as an `i32` instead.
const MAX_FLAT_RESULTS: usize = 1;

/// The name of the imported host function which the generated component will call
pub const IMPORT_FUNCTION: &str = "echo-import";

/// The name of the exported guest function which the host should call
pub const EXPORT_FUNCTION: &str = "echo-export";

/// Wasmtime allows up to 100 type depth so limit this to just under that.
pub const MAX_TYPE_DEPTH: u32 = 99;
29
/// A core WebAssembly value type, as produced by flattening a component
/// model type for the canonical ABI.
#[derive(Copy, Clone, PartialEq, Eq)]
enum CoreType {
    I32,
    I64,
    F32,
    F64,
}

impl CoreType {
    /// Merges two flattened types that share a slot in a variant's payload,
    /// following the `join` operation from [the canonical
    /// ABI](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#flattening).
    ///
    /// Equal types are kept as-is, an `i32`/`f32` mix becomes `i32`, and any
    /// other mix widens to `i64`.
    fn join(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        match (self, other) {
            (Self::I32, Self::F32) | (Self::F32, Self::I32) => Self::I32,
            _ => Self::I64,
        }
    }
}

impl fmt::Display for CoreType {
    // Writes the WAT keyword for the type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::F32 => "f32",
            Self::F64 => "f64",
        };
        f.write_str(keyword)
    }
}
60}
61
62/// Wraps a `Box<[T]>` and provides an `Arbitrary` implementation that always generates slices of length less than
63/// or equal to the longest tuple for which Wasmtime generates a `ComponentType` impl
64#[derive(Debug, Clone)]
65pub struct VecInRange<T, const L: u32, const H: u32>(Vec<T>);
66
67impl<T, const L: u32, const H: u32> VecInRange<T, L, H> {
68    fn new<'a>(
69        input: &mut Unstructured<'a>,
70        fuel: &mut u32,
71        generate: impl Fn(&mut Unstructured<'a>, &mut u32) -> arbitrary::Result<T>,
72    ) -> arbitrary::Result<Self> {
73        let mut ret = Vec::new();
74        input.arbitrary_loop(Some(L), Some(H), |input| {
75            if *fuel > 0 {
76                *fuel = *fuel - 1;
77                ret.push(generate(input, fuel)?);
78                Ok(std::ops::ControlFlow::Continue(()))
79            } else {
80                Ok(std::ops::ControlFlow::Break(()))
81            }
82        })?;
83        Ok(Self(ret))
84    }
85}
86
87impl<T, const L: u32, const H: u32> Deref for VecInRange<T, L, H> {
88    type Target = [T];
89
90    fn deref(&self) -> &[T] {
91        self.0.deref()
92    }
93}
94
/// Represents a component model interface type
#[expect(missing_docs, reason = "self-describing")]
#[derive(Debug, Clone)]
pub enum Type {
    Bool,
    S8,
    U8,
    S16,
    U16,
    S32,
    U32,
    S64,
    U64,
    Float32,
    Float64,
    Char,
    String,
    // A list whose elements all have the boxed type.
    List(Box<Type>),

    // Give records the ability to generate a generous amount of fields but
    // don't let the fuzzer go too wild since `wasmparser`'s validator currently
    // has hard limits in the 1000-ish range on the number of fields a record
    // may contain.
    Record(VecInRange<Type, 1, 200>),

    // Tuples can only have up to 16 type parameters in wasmtime right now for
    // the static API, but the standard library only implements `Debug` for
    // tuples of up to 12 elements, so compromise at an even 10.
    Tuple(VecInRange<Type, 1, 10>),

    // Like records, allow a good number of variants, but variants require at
    // least one case. A `None` case carries no payload.
    Variant(VecInRange<Option<Type>, 1, 200>),
    // An enum with this many cases, named `E0`, `E1`, ...
    Enum(u32),

    Option(Box<Type>),
    // A result whose `ok` and `err` payloads are each optional.
    Result {
        ok: Option<Box<Type>>,
        err: Option<Box<Type>>,
    },

    // A flags type with this many flags, named `F0`, `F1`, ...
    Flags(u32),
}
138
139impl Type {
140    pub fn generate(
141        u: &mut Unstructured<'_>,
142        depth: u32,
143        fuel: &mut u32,
144    ) -> arbitrary::Result<Type> {
145        *fuel = fuel.saturating_sub(1);
146        let max = if depth == 0 || *fuel == 0 { 12 } else { 20 };
147        Ok(match u.int_in_range(0..=max)? {
148            0 => Type::Bool,
149            1 => Type::S8,
150            2 => Type::U8,
151            3 => Type::S16,
152            4 => Type::U16,
153            5 => Type::S32,
154            6 => Type::U32,
155            7 => Type::S64,
156            8 => Type::U64,
157            9 => Type::Float32,
158            10 => Type::Float64,
159            11 => Type::Char,
160            12 => Type::String,
161            // ^-- if you add something here update the `depth == 0` case above
162            13 => Type::List(Box::new(Type::generate(u, depth - 1, fuel)?)),
163            14 => Type::Record(Type::generate_list(u, depth - 1, fuel)?),
164            15 => Type::Tuple(Type::generate_list(u, depth - 1, fuel)?),
165            16 => Type::Variant(VecInRange::new(u, fuel, |u, fuel| {
166                Type::generate_opt(u, depth - 1, fuel)
167            })?),
168            17 => {
169                let amt = u.int_in_range(1..=(*fuel).max(1).min(257))?;
170                *fuel -= amt;
171                Type::Enum(amt)
172            }
173            18 => Type::Option(Box::new(Type::generate(u, depth - 1, fuel)?)),
174            19 => Type::Result {
175                ok: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
176                err: Type::generate_opt(u, depth - 1, fuel)?.map(Box::new),
177            },
178            20 => {
179                let amt = u.int_in_range(1..=(*fuel).min(32))?;
180                *fuel -= amt;
181                Type::Flags(amt)
182            }
183            // ^-- if you add something here update the `depth != 0` case above
184            _ => unreachable!(),
185        })
186    }
187
188    fn generate_opt(
189        u: &mut Unstructured<'_>,
190        depth: u32,
191        fuel: &mut u32,
192    ) -> arbitrary::Result<Option<Type>> {
193        Ok(if u.arbitrary()? {
194            Some(Type::generate(u, depth, fuel)?)
195        } else {
196            None
197        })
198    }
199
200    fn generate_list<const L: u32, const H: u32>(
201        u: &mut Unstructured<'_>,
202        depth: u32,
203        fuel: &mut u32,
204    ) -> arbitrary::Result<VecInRange<Type, L, H>> {
205        VecInRange::new(u, fuel, |u, fuel| Type::generate(u, depth, fuel))
206    }
207}
208
209fn lower_record<'a>(types: impl Iterator<Item = &'a Type>, vec: &mut Vec<CoreType>) {
210    for ty in types {
211        ty.lower(vec);
212    }
213}
214
215fn lower_variant<'a>(types: impl Iterator<Item = Option<&'a Type>>, vec: &mut Vec<CoreType>) {
216    vec.push(CoreType::I32);
217    let offset = vec.len();
218    for ty in types {
219        let ty = match ty {
220            Some(ty) => ty,
221            None => continue,
222        };
223        for (index, ty) in ty.lowered().iter().enumerate() {
224            let index = offset + index;
225            if index < vec.len() {
226                vec[index] = vec[index].join(*ty);
227            } else {
228                vec.push(*ty)
229            }
230        }
231    }
232}
233
234fn u32_count_from_flag_count(count: usize) -> usize {
235    match FlagsSize::from_count(count) {
236        FlagsSize::Size0 => 0,
237        FlagsSize::Size1 | FlagsSize::Size2 => 1,
238        FlagsSize::Size4Plus(n) => n.into(),
239    }
240}
241
/// The in-memory layout of a component model type, as computed by
/// `Type::size_and_alignment`.
struct SizeAndAlignment {
    /// Size in bytes.
    size: usize,
    /// Required alignment in bytes; consumed by `align_to`.
    alignment: u32,
}
246
247impl Type {
248    fn lowered(&self) -> Vec<CoreType> {
249        let mut vec = Vec::new();
250        self.lower(&mut vec);
251        vec
252    }
253
254    fn lower(&self, vec: &mut Vec<CoreType>) {
255        match self {
256            Type::Bool
257            | Type::U8
258            | Type::S8
259            | Type::S16
260            | Type::U16
261            | Type::S32
262            | Type::U32
263            | Type::Char
264            | Type::Enum(_) => vec.push(CoreType::I32),
265            Type::S64 | Type::U64 => vec.push(CoreType::I64),
266            Type::Float32 => vec.push(CoreType::F32),
267            Type::Float64 => vec.push(CoreType::F64),
268            Type::String | Type::List(_) => {
269                vec.push(CoreType::I32);
270                vec.push(CoreType::I32);
271            }
272            Type::Record(types) => lower_record(types.iter(), vec),
273            Type::Tuple(types) => lower_record(types.0.iter(), vec),
274            Type::Variant(types) => lower_variant(types.0.iter().map(|t| t.as_ref()), vec),
275            Type::Option(ty) => lower_variant([None, Some(&**ty)].into_iter(), vec),
276            Type::Result { ok, err } => {
277                lower_variant([ok.as_deref(), err.as_deref()].into_iter(), vec)
278            }
279            Type::Flags(count) => vec.extend(
280                iter::repeat(CoreType::I32).take(u32_count_from_flag_count(*count as usize)),
281            ),
282        }
283    }
284
285    fn size_and_alignment(&self) -> SizeAndAlignment {
286        match self {
287            Type::Bool | Type::S8 | Type::U8 => SizeAndAlignment {
288                size: 1,
289                alignment: 1,
290            },
291
292            Type::S16 | Type::U16 => SizeAndAlignment {
293                size: 2,
294                alignment: 2,
295            },
296
297            Type::S32 | Type::U32 | Type::Char | Type::Float32 => SizeAndAlignment {
298                size: 4,
299                alignment: 4,
300            },
301
302            Type::S64 | Type::U64 | Type::Float64 => SizeAndAlignment {
303                size: 8,
304                alignment: 8,
305            },
306
307            Type::String | Type::List(_) => SizeAndAlignment {
308                size: 8,
309                alignment: 4,
310            },
311
312            Type::Record(types) => record_size_and_alignment(types.iter()),
313
314            Type::Tuple(types) => record_size_and_alignment(types.0.iter()),
315
316            Type::Variant(types) => variant_size_and_alignment(types.0.iter().map(|t| t.as_ref())),
317
318            Type::Enum(count) => variant_size_and_alignment((0..*count).map(|_| None)),
319
320            Type::Option(ty) => variant_size_and_alignment([None, Some(&**ty)].into_iter()),
321
322            Type::Result { ok, err } => {
323                variant_size_and_alignment([ok.as_deref(), err.as_deref()].into_iter())
324            }
325
326            Type::Flags(count) => match FlagsSize::from_count(*count as usize) {
327                FlagsSize::Size0 => SizeAndAlignment {
328                    size: 0,
329                    alignment: 1,
330                },
331                FlagsSize::Size1 => SizeAndAlignment {
332                    size: 1,
333                    alignment: 1,
334                },
335                FlagsSize::Size2 => SizeAndAlignment {
336                    size: 2,
337                    alignment: 2,
338                },
339                FlagsSize::Size4Plus(n) => SizeAndAlignment {
340                    size: usize::from(n) * 4,
341                    alignment: 4,
342                },
343            },
344        }
345    }
346}
347
/// Rounds `a` up to the next multiple of `align`.
///
/// Assumes `align` is a nonzero power of two; every alignment produced by
/// `Type::size_and_alignment` is one of 1, 2, 4 or 8.
fn align_to(a: usize, align: u32) -> usize {
    let mask = align as usize - 1;
    (a + mask) & !mask
}
352
353fn record_size_and_alignment<'a>(types: impl Iterator<Item = &'a Type>) -> SizeAndAlignment {
354    let mut offset = 0;
355    let mut align = 1;
356    for ty in types {
357        let SizeAndAlignment { size, alignment } = ty.size_and_alignment();
358        offset = align_to(offset, alignment) + size;
359        align = align.max(alignment);
360    }
361
362    SizeAndAlignment {
363        size: align_to(offset, align),
364        alignment: align,
365    }
366}
367
368fn variant_size_and_alignment<'a>(
369    types: impl ExactSizeIterator<Item = Option<&'a Type>>,
370) -> SizeAndAlignment {
371    let discriminant_size = DiscriminantSize::from_count(types.len()).unwrap();
372    let mut alignment = u32::from(discriminant_size);
373    let mut size = 0;
374    for ty in types {
375        if let Some(ty) = ty {
376            let size_and_alignment = ty.size_and_alignment();
377            alignment = alignment.max(size_and_alignment.alignment);
378            size = size.max(size_and_alignment.size);
379        }
380    }
381
382    SizeAndAlignment {
383        size: align_to(
384            align_to(usize::from(discriminant_size), alignment) + size,
385            alignment,
386        ),
387        alignment,
388    }
389}
390
391fn make_import_and_export(params: &[&Type], result: Option<&Type>) -> String {
392    let params_lowered = params
393        .iter()
394        .flat_map(|ty| ty.lowered())
395        .collect::<Box<[_]>>();
396    let result_lowered = result.map(|t| t.lowered()).unwrap_or(Vec::new());
397
398    let mut core_params = String::new();
399    let mut gets = String::new();
400
401    if params_lowered.len() <= MAX_FLAT_PARAMS {
402        for (index, param) in params_lowered.iter().enumerate() {
403            write!(&mut core_params, " {param}").unwrap();
404            write!(&mut gets, "local.get {index} ").unwrap();
405        }
406    } else {
407        write!(&mut core_params, " i32").unwrap();
408        write!(&mut gets, "local.get 0 ").unwrap();
409    }
410
411    let maybe_core_params = if params_lowered.is_empty() {
412        String::new()
413    } else {
414        format!("(param{core_params})")
415    };
416
417    if result_lowered.len() <= MAX_FLAT_RESULTS {
418        let mut core_results = String::new();
419        for result in result_lowered.iter() {
420            write!(&mut core_results, " {result}").unwrap();
421        }
422
423        let maybe_core_results = if result_lowered.is_empty() {
424            String::new()
425        } else {
426            format!("(result{core_results})")
427        };
428
429        format!(
430            r#"
431            (func $f (import "host" "{IMPORT_FUNCTION}") {maybe_core_params} {maybe_core_results})
432
433            (func (export "{EXPORT_FUNCTION}") {maybe_core_params} {maybe_core_results}
434                {gets}
435
436                call $f
437            )"#
438        )
439    } else {
440        let SizeAndAlignment { size, alignment } = result.unwrap().size_and_alignment();
441
442        format!(
443            r#"
444            (func $f (import "host" "{IMPORT_FUNCTION}") (param{core_params} i32))
445
446            (func (export "{EXPORT_FUNCTION}") {maybe_core_params} (result i32)
447                (local $base i32)
448                (local.set $base
449                    (call $realloc
450                        (i32.const 0)
451                        (i32.const 0)
452                        (i32.const {alignment})
453                        (i32.const {size})))
454                {gets}
455                local.get $base
456
457                call $f
458
459                local.get $base
460            )"#
461        )
462    }
463}
464
465fn make_rust_name(name_counter: &mut u32) -> Ident {
466    let name = format_ident!("Foo{name_counter}");
467    *name_counter += 1;
468    name
469}
470
471/// Generate a [`TokenStream`] containing the rust type name for a type.
472///
473/// The `name_counter` parameter is used to generate names for each recursively visited type.  The `declarations`
474/// parameter is used to accumulate declarations for each recursively visited type.
475pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStream) -> TokenStream {
476    match ty {
477        Type::Bool => quote!(bool),
478        Type::S8 => quote!(i8),
479        Type::U8 => quote!(u8),
480        Type::S16 => quote!(i16),
481        Type::U16 => quote!(u16),
482        Type::S32 => quote!(i32),
483        Type::U32 => quote!(u32),
484        Type::S64 => quote!(i64),
485        Type::U64 => quote!(u64),
486        Type::Float32 => quote!(Float32),
487        Type::Float64 => quote!(Float64),
488        Type::Char => quote!(char),
489        Type::String => quote!(Box<str>),
490        Type::List(ty) => {
491            let ty = rust_type(ty, name_counter, declarations);
492            quote!(Vec<#ty>)
493        }
494        Type::Record(types) => {
495            let fields = types
496                .iter()
497                .enumerate()
498                .map(|(index, ty)| {
499                    let name = format_ident!("f{index}");
500                    let ty = rust_type(ty, name_counter, declarations);
501                    quote!(#name: #ty,)
502                })
503                .collect::<TokenStream>();
504
505            let name = make_rust_name(name_counter);
506
507            declarations.extend(quote! {
508                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
509                #[component(record)]
510                struct #name {
511                    #fields
512                }
513            });
514
515            quote!(#name)
516        }
517        Type::Tuple(types) => {
518            let fields = types
519                .0
520                .iter()
521                .map(|ty| {
522                    let ty = rust_type(ty, name_counter, declarations);
523                    quote!(#ty,)
524                })
525                .collect::<TokenStream>();
526
527            quote!((#fields))
528        }
529        Type::Variant(types) => {
530            let cases = types
531                .0
532                .iter()
533                .enumerate()
534                .map(|(index, ty)| {
535                    let name = format_ident!("C{index}");
536                    let ty = match ty {
537                        Some(ty) => {
538                            let ty = rust_type(ty, name_counter, declarations);
539                            quote!((#ty))
540                        }
541                        None => quote!(),
542                    };
543                    quote!(#name #ty,)
544                })
545                .collect::<TokenStream>();
546
547            let name = make_rust_name(name_counter);
548            declarations.extend(quote! {
549                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)]
550                #[component(variant)]
551                enum #name {
552                    #cases
553                }
554            });
555
556            quote!(#name)
557        }
558        Type::Enum(count) => {
559            let cases = (0..*count)
560                .map(|index| {
561                    let name = format_ident!("E{index}");
562                    quote!(#name,)
563                })
564                .collect::<TokenStream>();
565
566            let name = make_rust_name(name_counter);
567            let repr = match count.ilog2() {
568                0..=7 => quote!(u8),
569                8..=15 => quote!(u16),
570                _ => quote!(u32),
571            };
572
573            declarations.extend(quote! {
574                #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)]
575                #[component(enum)]
576                #[repr(#repr)]
577                enum #name {
578                    #cases
579                }
580            });
581
582            quote!(#name)
583        }
584        Type::Option(ty) => {
585            let ty = rust_type(ty, name_counter, declarations);
586            quote!(Option<#ty>)
587        }
588        Type::Result { ok, err } => {
589            let ok = match ok {
590                Some(ok) => rust_type(ok, name_counter, declarations),
591                None => quote!(()),
592            };
593            let err = match err {
594                Some(err) => rust_type(err, name_counter, declarations),
595                None => quote!(()),
596            };
597            quote!(Result<#ok, #err>)
598        }
599        Type::Flags(count) => {
600            let type_name = make_rust_name(name_counter);
601
602            let mut flags = TokenStream::new();
603            let mut names = TokenStream::new();
604
605            for index in 0..*count {
606                let name = format_ident!("F{index}");
607                flags.extend(quote!(const #name;));
608                names.extend(quote!(#type_name::#name,))
609            }
610
611            declarations.extend(quote! {
612                wasmtime::component::flags! {
613                    #type_name {
614                        #flags
615                    }
616                }
617
618                impl<'a> arbitrary::Arbitrary<'a> for #type_name {
619                    fn arbitrary(input: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
620                        let mut flags = #type_name::default();
621                        for flag in [#names] {
622                            if input.arbitrary()? {
623                                flags |= flag;
624                            }
625                        }
626                        Ok(flags)
627                    }
628                }
629            });
630
631            quote!(#type_name)
632        }
633    }
634}
635
636#[derive(Default)]
637struct TypesBuilder<'a> {
638    next: u32,
639    worklist: Vec<(u32, &'a Type)>,
640}
641
642impl<'a> TypesBuilder<'a> {
643    fn write_ref(&mut self, ty: &'a Type, dst: &mut String) {
644        match ty {
645            // Primitive types can be referenced directly
646            Type::Bool => dst.push_str("bool"),
647            Type::S8 => dst.push_str("s8"),
648            Type::U8 => dst.push_str("u8"),
649            Type::S16 => dst.push_str("s16"),
650            Type::U16 => dst.push_str("u16"),
651            Type::S32 => dst.push_str("s32"),
652            Type::U32 => dst.push_str("u32"),
653            Type::S64 => dst.push_str("s64"),
654            Type::U64 => dst.push_str("u64"),
655            Type::Float32 => dst.push_str("float32"),
656            Type::Float64 => dst.push_str("float64"),
657            Type::Char => dst.push_str("char"),
658            Type::String => dst.push_str("string"),
659
660            // Otherwise emit a reference to the type and remember to generate
661            // the corresponding type alias later.
662            Type::List(_)
663            | Type::Record(_)
664            | Type::Tuple(_)
665            | Type::Variant(_)
666            | Type::Enum(_)
667            | Type::Option(_)
668            | Type::Result { .. }
669            | Type::Flags(_) => {
670                let idx = self.next;
671                self.next += 1;
672                write!(dst, "$t{idx}").unwrap();
673                self.worklist.push((idx, ty));
674            }
675        }
676    }
677
678    fn write_decl(&mut self, idx: u32, ty: &'a Type) -> String {
679        let mut decl = format!("(type $t{idx}' ");
680        match ty {
681            Type::Bool
682            | Type::S8
683            | Type::U8
684            | Type::S16
685            | Type::U16
686            | Type::S32
687            | Type::U32
688            | Type::S64
689            | Type::U64
690            | Type::Float32
691            | Type::Float64
692            | Type::Char
693            | Type::String => unreachable!(),
694
695            Type::List(ty) => {
696                decl.push_str("(list ");
697                self.write_ref(ty, &mut decl);
698                decl.push_str(")");
699            }
700            Type::Record(types) => {
701                decl.push_str("(record");
702                for (index, ty) in types.iter().enumerate() {
703                    write!(decl, r#" (field "f{index}" "#).unwrap();
704                    self.write_ref(ty, &mut decl);
705                    decl.push_str(")");
706                }
707                decl.push_str(")");
708            }
709            Type::Tuple(types) => {
710                decl.push_str("(tuple");
711                for ty in types.iter() {
712                    decl.push_str(" ");
713                    self.write_ref(ty, &mut decl);
714                }
715                decl.push_str(")");
716            }
717            Type::Variant(types) => {
718                decl.push_str("(variant");
719                for (index, ty) in types.iter().enumerate() {
720                    write!(decl, r#" (case "C{index}""#).unwrap();
721                    if let Some(ty) = ty {
722                        decl.push_str(" ");
723                        self.write_ref(ty, &mut decl);
724                    }
725                    decl.push_str(")");
726                }
727                decl.push_str(")");
728            }
729            Type::Enum(count) => {
730                decl.push_str("(enum");
731                for index in 0..*count {
732                    write!(decl, r#" "E{index}""#).unwrap();
733                }
734                decl.push_str(")");
735            }
736            Type::Option(ty) => {
737                decl.push_str("(option ");
738                self.write_ref(ty, &mut decl);
739                decl.push_str(")");
740            }
741            Type::Result { ok, err } => {
742                decl.push_str("(result");
743                if let Some(ok) = ok {
744                    decl.push_str(" ");
745                    self.write_ref(ok, &mut decl);
746                }
747                if let Some(err) = err {
748                    decl.push_str(" (error ");
749                    self.write_ref(err, &mut decl);
750                    decl.push_str(")");
751                }
752                decl.push_str(")");
753            }
754            Type::Flags(count) => {
755                decl.push_str("(flags");
756                for index in 0..*count {
757                    write!(decl, r#" "F{index}""#).unwrap();
758                }
759                decl.push_str(")");
760            }
761        }
762        decl.push_str(")\n");
763        writeln!(decl, "(import \"t{idx}\" (type $t{idx} (eq $t{idx}')))").unwrap();
764        decl
765    }
766}
767
768/// Represents custom fragments of a WAT file which may be used to create a component for exercising [`TestCase`]s
769#[derive(Debug)]
770pub struct Declarations {
771    /// Type declarations (if any) referenced by `params` and/or `result`
772    pub types: Cow<'static, str>,
773    /// Types to thread through when instantiating sub-components.
774    pub type_instantiation_args: Cow<'static, str>,
775    /// Parameter declarations used for the imported and exported functions
776    pub params: Cow<'static, str>,
777    /// Result declaration used for the imported and exported functions
778    pub results: Cow<'static, str>,
779    /// A WAT fragment representing the core function import and export to use for testing
780    pub import_and_export: Cow<'static, str>,
781    /// String encoding to use for host -> component
782    pub encoding1: StringEncoding,
783    /// String encoding to use for component -> host
784    pub encoding2: StringEncoding,
785}
786
787impl Declarations {
788    /// Generate a complete WAT file based on the specified fragments.
789    pub fn make_component(&self) -> Box<str> {
790        let Self {
791            types,
792            type_instantiation_args,
793            params,
794            results,
795            import_and_export,
796            encoding1,
797            encoding2,
798        } = self;
799        let mk_component = |name: &str, encoding: StringEncoding| {
800            format!(
801                r#"
802                (component ${name}
803                    {types}
804                    (type $sig (func {params} {results}))
805                    (import "{IMPORT_FUNCTION}" (func $f (type $sig)))
806
807                    (core instance $libc (instantiate $libc))
808
809                    (core func $f_lower (canon lower
810                        (func $f)
811                        (memory $libc "memory")
812                        (realloc (func $libc "realloc"))
813                        string-encoding={encoding}
814                    ))
815
816                    (core instance $i (instantiate $m
817                        (with "libc" (instance $libc))
818                        (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower))))
819                    ))
820
821                    (func (export "{EXPORT_FUNCTION}") (type $sig)
822                        (canon lift
823                            (core func $i "{EXPORT_FUNCTION}")
824                            (memory $libc "memory")
825                            (realloc (func $libc "realloc"))
826                            string-encoding={encoding}
827                        )
828                    )
829                )
830            "#
831            )
832        };
833
834        let c1 = mk_component("c1", *encoding2);
835        let c2 = mk_component("c2", *encoding1);
836
837        format!(
838            r#"
839            (component
840                (core module $libc
841                    (memory (export "memory") 1)
842                    {REALLOC_AND_FREE}
843                )
844
845                (core module $m
846                    (memory (import "libc" "memory") 1)
847                    (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32))
848
849                    {import_and_export}
850                )
851
852                {types}
853
854                (type $sig (func {params} {results}))
855                (import "{IMPORT_FUNCTION}" (func $f (type $sig)))
856
857                {c1}
858                {c2}
859                (instance $c1 (instantiate $c1
860                    {type_instantiation_args}
861                    (with "{IMPORT_FUNCTION}" (func $f))
862                ))
863                (instance $c2 (instantiate $c2
864                    {type_instantiation_args}
865                    (with "{IMPORT_FUNCTION}" (func $c1 "{EXPORT_FUNCTION}"))
866                ))
867                (export "{EXPORT_FUNCTION}" (func $c2 "{EXPORT_FUNCTION}"))
868            )"#,
869        )
870        .into()
871    }
872}
873
874/// Represents a test case for calling a component function
875#[derive(Debug)]
876pub struct TestCase<'a> {
877    /// The types of parameters to pass to the function
878    pub params: Vec<&'a Type>,
879    /// The result types of the function
880    pub result: Option<&'a Type>,
881    /// String encoding to use from host-to-component.
882    pub encoding1: StringEncoding,
883    /// String encoding to use from component-to-host.
884    pub encoding2: StringEncoding,
885}
886
887impl TestCase<'_> {
888    /// Generate a `Declarations` for this `TestCase` which may be used to build a component to execute the case.
889    pub fn declarations(&self) -> Declarations {
890        let mut builder = TypesBuilder::default();
891
892        let mut params = String::new();
893        for (i, ty) in self.params.iter().enumerate() {
894            params.push_str(&format!(" (param \"p{i}\" "));
895            builder.write_ref(ty, &mut params);
896            params.push_str(")");
897        }
898
899        let mut results = String::new();
900        if let Some(ty) = self.result {
901            results.push_str(&format!(" (result "));
902            builder.write_ref(ty, &mut results);
903            results.push_str(")");
904        }
905
906        let import_and_export = make_import_and_export(&self.params, self.result);
907
908        let mut type_decls = Vec::new();
909        let mut type_instantiation_args = String::new();
910        while let Some((idx, ty)) = builder.worklist.pop() {
911            type_decls.push(builder.write_decl(idx, ty));
912            writeln!(type_instantiation_args, "(with \"t{idx}\" (type $t{idx}))").unwrap();
913        }
914
915        // Note that types are printed here in reverse order since they were
916        // pushed onto `type_decls` as they were referenced meaning the last one
917        // is the "base" one.
918        let mut types = String::new();
919        for decl in type_decls.into_iter().rev() {
920            types.push_str(&decl);
921            types.push_str("\n");
922        }
923
924        Declarations {
925            types: types.into(),
926            type_instantiation_args: type_instantiation_args.into(),
927            params: params.into(),
928            results: results.into(),
929            import_and_export: import_and_export.into(),
930            encoding1: self.encoding1,
931            encoding2: self.encoding2,
932        }
933    }
934}
935
936#[derive(Copy, Clone, Debug, Arbitrary)]
937pub enum StringEncoding {
938    Utf8,
939    Utf16,
940    Latin1OrUtf16,
941}
942
943impl fmt::Display for StringEncoding {
944    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
945        match self {
946            StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f),
947            StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f),
948            StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f),
949        }
950    }
951}
952
953impl ToTokens for StringEncoding {
954    fn to_tokens(&self, tokens: &mut TokenStream) {
955        let me = match self {
956            StringEncoding::Utf8 => quote!(Utf8),
957            StringEncoding::Utf16 => quote!(Utf16),
958            StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16),
959        };
960        tokens.extend(quote!(component_fuzz_util::StringEncoding::#me));
961    }
962}