// wiggle_generate/types/variant.rs

1use crate::names;
2
3use proc_macro2::{Literal, TokenStream};
4use quote::quote;
5use witx::Layout;
6
7pub(super) fn define_variant(
8    name: &witx::Id,
9    v: &witx::Variant,
10    derive_std_error: bool,
11) -> TokenStream {
12    let ident = names::type_(name);
13    let size = v.mem_size_align().size as u32;
14    let align = v.mem_size_align().align;
15    let contents_offset = v.payload_offset() as u32;
16
17    let lifetime = quote!('a);
18    let tag_ty = super::int_repr_tokens(v.tag_repr);
19
20    let variants = v.cases.iter().map(|c| {
21        let var_name = names::enum_variant(&c.name);
22        if let Some(tref) = &c.tref {
23            let var_type = names::type_ref(&tref, lifetime.clone());
24            quote!(#var_name(#var_type))
25        } else {
26            quote!(#var_name)
27        }
28    });
29
30    let read_variant = v.cases.iter().enumerate().map(|(i, c)| {
31        let i = Literal::usize_unsuffixed(i);
32        let variantname = names::enum_variant(&c.name);
33        if let Some(tref) = &c.tref {
34            let varianttype = names::type_ref(tref, lifetime.clone());
35            quote! {
36                #i => {
37                    let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
38                    let variant_val = <#varianttype as wiggle::GuestType>::read(mem, variant_ptr.cast())?;
39                    Ok(#ident::#variantname(variant_val))
40                }
41            }
42        } else {
43            quote! { #i => Ok(#ident::#variantname), }
44        }
45    });
46
47    let write_variant = v.cases.iter().enumerate().map(|(i, c)| {
48        let variantname = names::enum_variant(&c.name);
49        let write_tag = quote! {
50            mem.write(location.cast(), #i as #tag_ty)?;
51        };
52        if let Some(tref) = &c.tref {
53            let varianttype = names::type_ref(tref, lifetime.clone());
54            quote! {
55                #ident::#variantname(contents) => {
56                    #write_tag
57                    let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
58                    <#varianttype as wiggle::GuestType>::write(mem, variant_ptr.cast(), contents)?;
59                }
60            }
61        } else {
62            quote! {
63                #ident::#variantname => {
64                    #write_tag
65                }
66            }
67        }
68    });
69
70    let mut extra_derive = quote!();
71    let enum_try_from = if v.cases.iter().all(|c| c.tref.is_none()) {
72        let tryfrom_repr_cases = v.cases.iter().enumerate().map(|(i, c)| {
73            let variant_name = names::enum_variant(&c.name);
74            let n = Literal::usize_unsuffixed(i);
75            quote!(#n => Ok(#ident::#variant_name))
76        });
77        let abi_ty = names::wasm_type(v.tag_repr.into());
78        extra_derive = quote!(, Copy);
79        quote! {
80            impl TryFrom<#tag_ty> for #ident {
81                type Error = wiggle::GuestError;
82                #[inline]
83                fn try_from(value: #tag_ty) -> Result<#ident, wiggle::GuestError> {
84                    match value {
85                        #(#tryfrom_repr_cases),*,
86                        _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))),
87                    }
88                }
89            }
90
91            impl TryFrom<#abi_ty> for #ident {
92                type Error = wiggle::GuestError;
93                #[inline]
94                fn try_from(value: #abi_ty) -> Result<#ident, wiggle::GuestError> {
95                    #ident::try_from(#tag_ty::try_from(value)?)
96                }
97            }
98        }
99    } else {
100        quote!()
101    };
102
103    let enum_from = if v.cases.iter().all(|c| c.tref.is_none()) {
104        let from_repr_cases = v.cases.iter().enumerate().map(|(i, c)| {
105            let variant_name = names::enum_variant(&c.name);
106            let n = Literal::usize_unsuffixed(i);
107            quote!(#ident::#variant_name => #n)
108        });
109        quote! {
110            impl From<#ident> for #tag_ty {
111                #[inline]
112                fn from(v: #ident) -> #tag_ty {
113                    match v {
114                        #(#from_repr_cases),*,
115                    }
116                }
117            }
118        }
119    } else {
120        quote!()
121    };
122
123    let extra_derive = quote!(, PartialEq #extra_derive);
124
125    let error_impls = if derive_std_error {
126        quote! {
127            impl std::fmt::Display for #ident {
128                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129                    write!(f, "{:?}", self)
130                }
131            }
132            impl std::error::Error for #ident {}
133        }
134    } else {
135        quote!()
136    };
137
138    quote! {
139        #[derive(Clone, Debug #extra_derive)]
140        pub enum #ident {
141            #(#variants),*
142        }
143        #error_impls
144
145        #enum_try_from
146        #enum_from
147
148        impl wiggle::GuestType for #ident {
149            #[inline]
150            fn guest_size() -> u32 {
151                #size
152            }
153
154            #[inline]
155            fn guest_align() -> usize {
156                #align
157            }
158
159            fn read(mem: &wiggle::GuestMemory, location: wiggle::GuestPtr<Self>)
160                -> Result<Self, wiggle::GuestError>
161            {
162                let tag = mem.read(location.cast::<#tag_ty>())?;
163                match tag {
164                    #(#read_variant)*
165                    _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))),
166                }
167
168            }
169
170            fn write(mem:  &mut wiggle::GuestMemory, location: wiggle::GuestPtr<Self>, val: Self)
171                -> Result<(), wiggle::GuestError>
172            {
173                match val {
174                    #(#write_variant)*
175                }
176                Ok(())
177            }
178        }
179    }
180}
181
182impl super::WiggleType for witx::Variant {
183    fn impls_display(&self) -> bool {
184        false
185    }
186}