wiggle_generate/types/
variant.rs1use crate::names;
2
3use proc_macro2::{Literal, TokenStream};
4use quote::quote;
5use witx::Layout;
6
7pub(super) fn define_variant(
8 name: &witx::Id,
9 v: &witx::Variant,
10 derive_std_error: bool,
11) -> TokenStream {
12 let ident = names::type_(name);
13 let size = v.mem_size_align().size as u32;
14 let align = v.mem_size_align().align;
15 let contents_offset = v.payload_offset() as u32;
16
17 let lifetime = quote!('a);
18 let tag_ty = super::int_repr_tokens(v.tag_repr);
19
20 let variants = v.cases.iter().map(|c| {
21 let var_name = names::enum_variant(&c.name);
22 if let Some(tref) = &c.tref {
23 let var_type = names::type_ref(&tref, lifetime.clone());
24 quote!(#var_name(#var_type))
25 } else {
26 quote!(#var_name)
27 }
28 });
29
30 let read_variant = v.cases.iter().enumerate().map(|(i, c)| {
31 let i = Literal::usize_unsuffixed(i);
32 let variantname = names::enum_variant(&c.name);
33 if let Some(tref) = &c.tref {
34 let varianttype = names::type_ref(tref, lifetime.clone());
35 quote! {
36 #i => {
37 let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
38 let variant_val = <#varianttype as wiggle::GuestType>::read(mem, variant_ptr.cast())?;
39 Ok(#ident::#variantname(variant_val))
40 }
41 }
42 } else {
43 quote! { #i => Ok(#ident::#variantname), }
44 }
45 });
46
47 let write_variant = v.cases.iter().enumerate().map(|(i, c)| {
48 let variantname = names::enum_variant(&c.name);
49 let write_tag = quote! {
50 mem.write(location.cast(), #i as #tag_ty)?;
51 };
52 if let Some(tref) = &c.tref {
53 let varianttype = names::type_ref(tref, lifetime.clone());
54 quote! {
55 #ident::#variantname(contents) => {
56 #write_tag
57 let variant_ptr = location.cast::<u8>().add(#contents_offset)?;
58 <#varianttype as wiggle::GuestType>::write(mem, variant_ptr.cast(), contents)?;
59 }
60 }
61 } else {
62 quote! {
63 #ident::#variantname => {
64 #write_tag
65 }
66 }
67 }
68 });
69
70 let mut extra_derive = quote!();
71 let enum_try_from = if v.cases.iter().all(|c| c.tref.is_none()) {
72 let tryfrom_repr_cases = v.cases.iter().enumerate().map(|(i, c)| {
73 let variant_name = names::enum_variant(&c.name);
74 let n = Literal::usize_unsuffixed(i);
75 quote!(#n => Ok(#ident::#variant_name))
76 });
77 let abi_ty = names::wasm_type(v.tag_repr.into());
78 extra_derive = quote!(, Copy);
79 quote! {
80 impl TryFrom<#tag_ty> for #ident {
81 type Error = wiggle::GuestError;
82 #[inline]
83 fn try_from(value: #tag_ty) -> Result<#ident, wiggle::GuestError> {
84 match value {
85 #(#tryfrom_repr_cases),*,
86 _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))),
87 }
88 }
89 }
90
91 impl TryFrom<#abi_ty> for #ident {
92 type Error = wiggle::GuestError;
93 #[inline]
94 fn try_from(value: #abi_ty) -> Result<#ident, wiggle::GuestError> {
95 #ident::try_from(#tag_ty::try_from(value)?)
96 }
97 }
98 }
99 } else {
100 quote!()
101 };
102
103 let enum_from = if v.cases.iter().all(|c| c.tref.is_none()) {
104 let from_repr_cases = v.cases.iter().enumerate().map(|(i, c)| {
105 let variant_name = names::enum_variant(&c.name);
106 let n = Literal::usize_unsuffixed(i);
107 quote!(#ident::#variant_name => #n)
108 });
109 quote! {
110 impl From<#ident> for #tag_ty {
111 #[inline]
112 fn from(v: #ident) -> #tag_ty {
113 match v {
114 #(#from_repr_cases),*,
115 }
116 }
117 }
118 }
119 } else {
120 quote!()
121 };
122
123 let extra_derive = quote!(, PartialEq #extra_derive);
124
125 let error_impls = if derive_std_error {
126 quote! {
127 impl std::fmt::Display for #ident {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 write!(f, "{:?}", self)
130 }
131 }
132 impl std::error::Error for #ident {}
133 }
134 } else {
135 quote!()
136 };
137
138 quote! {
139 #[derive(Clone, Debug #extra_derive)]
140 pub enum #ident {
141 #(#variants),*
142 }
143 #error_impls
144
145 #enum_try_from
146 #enum_from
147
148 impl wiggle::GuestType for #ident {
149 #[inline]
150 fn guest_size() -> u32 {
151 #size
152 }
153
154 #[inline]
155 fn guest_align() -> usize {
156 #align
157 }
158
159 fn read(mem: &wiggle::GuestMemory, location: wiggle::GuestPtr<Self>)
160 -> Result<Self, wiggle::GuestError>
161 {
162 let tag = mem.read(location.cast::<#tag_ty>())?;
163 match tag {
164 #(#read_variant)*
165 _ => Err(wiggle::GuestError::InvalidEnumValue(stringify!(#ident))),
166 }
167
168 }
169
170 fn write(mem: &mut wiggle::GuestMemory, location: wiggle::GuestPtr<Self>, val: Self)
171 -> Result<(), wiggle::GuestError>
172 {
173 match val {
174 #(#write_variant)*
175 }
176 Ok(())
177 }
178 }
179 }
180}
181
182impl super::WiggleType for witx::Variant {
183 fn impls_display(&self) -> bool {
184 false
185 }
186}