1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use std::collections::HashSet;
4use std::fmt;
5use syn::parse::{Parse, ParseStream};
6use syn::punctuated::Punctuated;
7use syn::{braced, parse_quote, Data, DeriveInput, Error, Ident, Result, Token};
8use wasmtime_component_util::{DiscriminantSize, FlagsSize};
9
/// Custom keywords recognised inside `#[component(...)]` attributes and the
/// flags syntax parsed in this file.
mod kw {
    // Type-style selectors handled by `ComponentAttr::parse`.
    syn::custom_keyword!(record);
    syn::custom_keyword!(variant);
    // Only peeked for, in order to produce a dedicated error message.
    syn::custom_keyword!(flags);
    // Used by `find_rename` for `name = "..."`.
    syn::custom_keyword!(name);
    // Used for `wasmtime_crate = path`.
    syn::custom_keyword!(wasmtime_crate);
}
17
/// The kind of component type requested through a `#[component(...)]`
/// attribute.
#[derive(Clone, Copy, Debug)]
enum Style {
    Record,
    Enum,
    Variant,
}

impl fmt::Display for Style {
    /// Writes the lowercase keyword naming this style: `record`, `enum` or
    /// `variant`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match *self {
            Style::Record => "record",
            Style::Enum => "enum",
            Style::Variant => "variant",
        };
        f.write_str(keyword)
    }
}
34
35#[derive(Debug, Clone)]
36enum ComponentAttr {
37 Style(Style),
38 WasmtimeCrate(syn::Path),
39}
40
41impl Parse for ComponentAttr {
42 fn parse(input: ParseStream) -> Result<Self> {
43 let lookahead = input.lookahead1();
44 if lookahead.peek(kw::record) {
45 input.parse::<kw::record>()?;
46 Ok(ComponentAttr::Style(Style::Record))
47 } else if lookahead.peek(kw::variant) {
48 input.parse::<kw::variant>()?;
49 Ok(ComponentAttr::Style(Style::Variant))
50 } else if lookahead.peek(Token![enum]) {
51 input.parse::<Token![enum]>()?;
52 Ok(ComponentAttr::Style(Style::Enum))
53 } else if lookahead.peek(kw::wasmtime_crate) {
54 input.parse::<kw::wasmtime_crate>()?;
55 input.parse::<Token![=]>()?;
56 Ok(ComponentAttr::WasmtimeCrate(input.parse()?))
57 } else if input.peek(kw::flags) {
58 Err(input.error(
59 "`flags` not allowed here; \
60 use `wasmtime::component::flags!` macro to define `flags` types",
61 ))
62 } else {
63 Err(lookahead.error())
64 }
65 }
66}
67
68fn find_rename(attributes: &[syn::Attribute]) -> Result<Option<syn::LitStr>> {
69 let mut name = None;
70
71 for attribute in attributes {
72 if !attribute.path().is_ident("component") {
73 continue;
74 }
75 let name_literal = attribute.parse_args_with(|parser: ParseStream<'_>| {
76 parser.parse::<kw::name>()?;
77 parser.parse::<Token![=]>()?;
78 parser.parse::<syn::LitStr>()
79 })?;
80
81 if name.is_some() {
82 return Err(Error::new_spanned(
83 attribute,
84 "duplicate field rename attribute",
85 ));
86 }
87
88 name = Some(name_literal);
89 }
90
91 Ok(name)
92}
93
94fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn::Generics {
95 let mut generics = generics.clone();
96 for param in &mut generics.params {
97 if let syn::GenericParam::Type(ref mut type_param) = *param {
98 type_param.bounds.push(bound.clone());
99 }
100 }
101 generics
102}
103
/// One case of a Rust `enum` being expanded as a `variant` or `enum`
/// component type, borrowed from the parsed derive input.
pub struct VariantCase<'a> {
    /// Attributes attached to the Rust variant (searched for renames).
    attrs: &'a [syn::Attribute],
    /// Identifier of the Rust variant.
    ident: &'a syn::Ident,
    /// Payload type: `Some` for a variant with exactly one unnamed field,
    /// `None` for a unit variant.
    ty: Option<&'a syn::Type>,
}
109
/// Code generator for one derived trait; `expand` dispatches to one of these
/// methods according to the style named in the `#[component(...)]` attribute.
pub trait Expander {
    /// Generates code for a `record`: a struct called `name` with the given
    /// named `fields`. `wasmtime_crate` is the path used to refer to the
    /// `wasmtime` crate in the generated code.
    fn expand_record(
        &self,
        name: &syn::Ident,
        generics: &syn::Generics,
        fields: &[&syn::Field],
        wasmtime_crate: &syn::Path,
    ) -> Result<TokenStream>;

    /// Generates code for a `variant`: an enum called `name` whose cases may
    /// carry a payload. `discriminant_size` is derived from the number of
    /// cases.
    fn expand_variant(
        &self,
        name: &syn::Ident,
        generics: &syn::Generics,
        discriminant_size: DiscriminantSize,
        cases: &[VariantCase],
        wasmtime_crate: &syn::Path,
    ) -> Result<TokenStream>;

    /// Generates code for an `enum`: an enum called `name` that has already
    /// passed `validate_enum` (no generics, matching `repr`, no explicit
    /// discriminants).
    fn expand_enum(
        &self,
        name: &syn::Ident,
        discriminant_size: DiscriminantSize,
        cases: &[VariantCase],
        wasmtime_crate: &syn::Path,
    ) -> Result<TokenStream>;
}
136
137pub fn expand(expander: &dyn Expander, input: &DeriveInput) -> Result<TokenStream> {
138 let mut wasmtime_crate = None;
139 let mut style = None;
140
141 for attribute in &input.attrs {
142 if !attribute.path().is_ident("component") {
143 continue;
144 }
145 match attribute.parse_args()? {
146 ComponentAttr::WasmtimeCrate(c) => wasmtime_crate = Some(c),
147 ComponentAttr::Style(attr_style) => {
148 if style.is_some() {
149 return Err(Error::new_spanned(
150 attribute,
151 "duplicate `component` attribute",
152 ));
153 }
154 style = Some(attr_style);
155 }
156 }
157 }
158
159 let style = style.ok_or_else(|| Error::new_spanned(input, "missing `component` attribute"))?;
160 let wasmtime_crate = wasmtime_crate.unwrap_or_else(default_wasmtime_crate);
161 match style {
162 Style::Record => expand_record(expander, input, &wasmtime_crate),
163 Style::Enum | Style::Variant => expand_variant(expander, input, style, &wasmtime_crate),
164 }
165}
166
167fn default_wasmtime_crate() -> syn::Path {
168 Ident::new("wasmtime", Span::call_site()).into()
169}
170
171fn expand_record(
172 expander: &dyn Expander,
173 input: &DeriveInput,
174 wasmtime_crate: &syn::Path,
175) -> Result<TokenStream> {
176 let name = &input.ident;
177
178 let body = if let Data::Struct(body) = &input.data {
179 body
180 } else {
181 return Err(Error::new(
182 name.span(),
183 "`record` component types can only be derived for Rust `struct`s",
184 ));
185 };
186
187 match &body.fields {
188 syn::Fields::Named(fields) => expander.expand_record(
189 &input.ident,
190 &input.generics,
191 &fields.named.iter().collect::<Vec<_>>(),
192 wasmtime_crate,
193 ),
194
195 syn::Fields::Unnamed(_) | syn::Fields::Unit => Err(Error::new(
196 name.span(),
197 "`record` component types can only be derived for `struct`s with named fields",
198 )),
199 }
200}
201
202fn expand_variant(
203 expander: &dyn Expander,
204 input: &DeriveInput,
205 style: Style,
206 wasmtime_crate: &syn::Path,
207) -> Result<TokenStream> {
208 let name = &input.ident;
209
210 let body = if let Data::Enum(body) = &input.data {
211 body
212 } else {
213 return Err(Error::new(
214 name.span(),
215 format!("`{style}` component types can only be derived for Rust `enum`s"),
216 ));
217 };
218
219 if body.variants.is_empty() {
220 return Err(Error::new(
221 name.span(),
222 format!("`{style}` component types can only be derived for Rust `enum`s with at least one variant"),
223 ));
224 }
225
226 let discriminant_size = DiscriminantSize::from_count(body.variants.len()).ok_or_else(|| {
227 Error::new(
228 input.ident.span(),
229 "`enum`s with more than 2^32 variants are not supported",
230 )
231 })?;
232
233 let cases = body
234 .variants
235 .iter()
236 .map(
237 |syn::Variant {
238 attrs,
239 ident,
240 fields,
241 ..
242 }| {
243 Ok(VariantCase {
244 attrs,
245 ident,
246 ty: match fields {
247 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
248 Some(&fields.unnamed[0].ty)
249 }
250 syn::Fields::Unit => None,
251 _ => {
252 return Err(Error::new(
253 name.span(),
254 format!(
255 "`{}` component types can only be derived for Rust `enum`s \
256 containing variants with {}",
257 style,
258 match style {
259 Style::Variant => "at most one unnamed field each",
260 Style::Enum => "no fields",
261 Style::Record => unreachable!(),
262 }
263 ),
264 ))
265 }
266 },
267 })
268 },
269 )
270 .collect::<Result<Vec<_>>>()?;
271
272 match style {
273 Style::Variant => expander.expand_variant(
274 &input.ident,
275 &input.generics,
276 discriminant_size,
277 &cases,
278 wasmtime_crate,
279 ),
280 Style::Enum => {
281 validate_enum(input, &body, discriminant_size)?;
282 expander.expand_enum(&input.ident, discriminant_size, &cases, wasmtime_crate)
283 }
284 Style::Record => unreachable!(),
285 }
286}
287
288fn validate_enum(input: &DeriveInput, body: &syn::DataEnum, size: DiscriminantSize) -> Result<()> {
292 if !input.generics.params.is_empty() {
293 return Err(Error::new_spanned(
294 &input.generics.params,
295 "cannot have generics on an `enum`",
296 ));
297 }
298 if let Some(clause) = &input.generics.where_clause {
299 return Err(Error::new_spanned(
300 clause,
301 "cannot have a where clause on an `enum`",
302 ));
303 }
304 let expected_discr = match size {
305 DiscriminantSize::Size1 => "u8",
306 DiscriminantSize::Size2 => "u16",
307 DiscriminantSize::Size4 => "u32",
308 };
309 let mut found_repr = false;
310 for attr in input.attrs.iter() {
311 if !attr.meta.path().is_ident("repr") {
312 continue;
313 }
314 let list = attr.meta.require_list()?;
315 found_repr = true;
316 if list.tokens.to_string() != expected_discr {
317 return Err(Error::new_spanned(
318 &list.tokens,
319 format!(
320 "expected `repr({expected_discr})`, found `repr({})`",
321 list.tokens
322 ),
323 ));
324 }
325 }
326 if !found_repr {
327 return Err(Error::new_spanned(
328 &body.enum_token,
329 format!("missing required `#[repr({expected_discr})]`"),
330 ));
331 }
332
333 for case in body.variants.iter() {
334 if let Some((_, expr)) = &case.discriminant {
335 return Err(Error::new_spanned(
336 expr,
337 "cannot have an explicit discriminant",
338 ));
339 }
340 }
341
342 Ok(())
343}
344
345fn expand_record_for_component_type(
346 name: &syn::Ident,
347 generics: &syn::Generics,
348 fields: &[&syn::Field],
349 typecheck: TokenStream,
350 typecheck_argument: TokenStream,
351 wt: &syn::Path,
352) -> Result<TokenStream> {
353 let internal = quote!(#wt::component::__internal);
354
355 let mut lower_generic_params = TokenStream::new();
356 let mut lower_generic_args = TokenStream::new();
357 let mut lower_field_declarations = TokenStream::new();
358 let mut abi_list = TokenStream::new();
359 let mut unique_types = HashSet::new();
360
361 for (index, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {
362 let generic = format_ident!("T{}", index);
363
364 lower_generic_params.extend(quote!(#generic: Copy,));
365 lower_generic_args.extend(quote!(<#ty as #wt::component::ComponentType>::Lower,));
366
367 lower_field_declarations.extend(quote!(#ident: #generic,));
368
369 abi_list.extend(quote!(
370 <#ty as #wt::component::ComponentType>::ABI,
371 ));
372
373 unique_types.insert(ty);
374 }
375
376 let generics = add_trait_bounds(generics, parse_quote!(#wt::component::ComponentType));
377 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
378 let lower = format_ident!("Lower{}", name);
379
380 let expanded = quote! {
396 #[doc(hidden)]
397 #[derive(Clone, Copy)]
398 #[repr(C)]
399 pub struct #lower <#lower_generic_params> {
400 #lower_field_declarations
401 _align: [#wt::ValRaw; 0],
402 }
403
404 unsafe impl #impl_generics #wt::component::ComponentType for #name #ty_generics #where_clause {
405 type Lower = #lower <#lower_generic_args>;
406
407 const ABI: #internal::CanonicalAbiInfo =
408 #internal::CanonicalAbiInfo::record_static(&[#abi_list]);
409
410 #[inline]
411 fn typecheck(
412 ty: &#internal::InterfaceType,
413 types: &#internal::InstanceType<'_>,
414 ) -> #internal::anyhow::Result<()> {
415 #internal::#typecheck(ty, types, &[#typecheck_argument])
416 }
417 }
418 };
419
420 Ok(quote!(const _: () = { #expanded };))
421}
422
423fn quote(size: DiscriminantSize, discriminant: usize) -> TokenStream {
424 match size {
425 DiscriminantSize::Size1 => {
426 let discriminant = u8::try_from(discriminant).unwrap();
427 quote!(#discriminant)
428 }
429 DiscriminantSize::Size2 => {
430 let discriminant = u16::try_from(discriminant).unwrap();
431 quote!(#discriminant)
432 }
433 DiscriminantSize::Size4 => {
434 let discriminant = u32::try_from(discriminant).unwrap();
435 quote!(#discriminant)
436 }
437 }
438}
439
440pub struct LiftExpander;
441
442impl Expander for LiftExpander {
443 fn expand_record(
444 &self,
445 name: &syn::Ident,
446 generics: &syn::Generics,
447 fields: &[&syn::Field],
448 wt: &syn::Path,
449 ) -> Result<TokenStream> {
450 let internal = quote!(#wt::component::__internal);
451
452 let mut lifts = TokenStream::new();
453 let mut loads = TokenStream::new();
454
455 for (i, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {
456 let field_ty = quote!(ty.fields[#i].ty);
457 lifts.extend(quote!(#ident: <#ty as #wt::component::Lift>::lift(
458 cx, #field_ty, &src.#ident
459 )?,));
460
461 loads.extend(quote!(#ident: <#ty as #wt::component::Lift>::load(
462 cx, #field_ty,
463 &bytes
464 [<#ty as #wt::component::ComponentType>::ABI.next_field32_size(&mut offset)..]
465 [..<#ty as #wt::component::ComponentType>::SIZE32]
466 )?,));
467 }
468
469 let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lift));
470 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
471
472 let extract_ty = quote! {
473 let ty = match ty {
474 #internal::InterfaceType::Record(i) => &cx.types[i],
475 _ => #internal::bad_type_info(),
476 };
477 };
478
479 let expanded = quote! {
480 unsafe impl #impl_generics #wt::component::Lift for #name #ty_generics #where_clause {
481 #[inline]
482 fn lift(
483 cx: &mut #internal::LiftContext<'_>,
484 ty: #internal::InterfaceType,
485 src: &Self::Lower,
486 ) -> #internal::anyhow::Result<Self> {
487 #extract_ty
488 Ok(Self {
489 #lifts
490 })
491 }
492
493 #[inline]
494 fn load(
495 cx: &mut #internal::LiftContext<'_>,
496 ty: #internal::InterfaceType,
497 bytes: &[u8],
498 ) -> #internal::anyhow::Result<Self> {
499 #extract_ty
500 debug_assert!(
501 (bytes.as_ptr() as usize)
502 % (<Self as #wt::component::ComponentType>::ALIGN32 as usize)
503 == 0
504 );
505 let mut offset = 0;
506 Ok(Self {
507 #loads
508 })
509 }
510 }
511 };
512
513 Ok(expanded)
514 }
515
516 fn expand_variant(
517 &self,
518 name: &syn::Ident,
519 generics: &syn::Generics,
520 discriminant_size: DiscriminantSize,
521 cases: &[VariantCase],
522 wt: &syn::Path,
523 ) -> Result<TokenStream> {
524 let internal = quote!(#wt::component::__internal);
525
526 let mut lifts = TokenStream::new();
527 let mut loads = TokenStream::new();
528
529 for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
530 let index_u32 = u32::try_from(index).unwrap();
531
532 let index_quoted = quote(discriminant_size, index);
533
534 if let Some(ty) = ty {
535 let payload_ty = quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info));
536 lifts.extend(
537 quote!(#index_u32 => Self::#ident(<#ty as #wt::component::Lift>::lift(
538 cx, #payload_ty, unsafe { &src.payload.#ident }
539 )?),),
540 );
541
542 loads.extend(
543 quote!(#index_quoted => Self::#ident(<#ty as #wt::component::Lift>::load(
544 cx, #payload_ty, &payload[..<#ty as #wt::component::ComponentType>::SIZE32]
545 )?),),
546 );
547 } else {
548 lifts.extend(quote!(#index_u32 => Self::#ident,));
549
550 loads.extend(quote!(#index_quoted => Self::#ident,));
551 }
552 }
553
554 let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lift));
555 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
556
557 let from_bytes = match discriminant_size {
558 DiscriminantSize::Size1 => quote!(bytes[0]),
559 DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
560 DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
561 };
562
563 let extract_ty = quote! {
564 let ty = match ty {
565 #internal::InterfaceType::Variant(i) => &cx.types[i],
566 _ => #internal::bad_type_info(),
567 };
568 };
569
570 let expanded = quote! {
571 unsafe impl #impl_generics #wt::component::Lift for #name #ty_generics #where_clause {
572 #[inline]
573 fn lift(
574 cx: &mut #internal::LiftContext<'_>,
575 ty: #internal::InterfaceType,
576 src: &Self::Lower,
577 ) -> #internal::anyhow::Result<Self> {
578 #extract_ty
579 Ok(match src.tag.get_u32() {
580 #lifts
581 discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
582 })
583 }
584
585 #[inline]
586 fn load(
587 cx: &mut #internal::LiftContext<'_>,
588 ty: #internal::InterfaceType,
589 bytes: &[u8],
590 ) -> #internal::anyhow::Result<Self> {
591 let align = <Self as #wt::component::ComponentType>::ALIGN32;
592 debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
593 let discrim = #from_bytes;
594 let payload_offset = <Self as #internal::ComponentVariant>::PAYLOAD_OFFSET32;
595 let payload = &bytes[payload_offset..];
596 #extract_ty
597 Ok(match discrim {
598 #loads
599 discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim),
600 })
601 }
602 }
603 };
604
605 Ok(expanded)
606 }
607
608 fn expand_enum(
609 &self,
610 name: &syn::Ident,
611 discriminant_size: DiscriminantSize,
612 cases: &[VariantCase],
613 wt: &syn::Path,
614 ) -> Result<TokenStream> {
615 let internal = quote!(#wt::component::__internal);
616
617 let (from_bytes, discrim_ty) = match discriminant_size {
618 DiscriminantSize::Size1 => (quote!(bytes[0]), quote!(u8)),
619 DiscriminantSize::Size2 => (
620 quote!(u16::from_le_bytes(bytes[0..2].try_into()?)),
621 quote!(u16),
622 ),
623 DiscriminantSize::Size4 => (
624 quote!(u32::from_le_bytes(bytes[0..4].try_into()?)),
625 quote!(u32),
626 ),
627 };
628 let discrim_limit = proc_macro2::Literal::usize_unsuffixed(cases.len());
629
630 let extract_ty = quote! {
631 let ty = match ty {
632 #internal::InterfaceType::Enum(i) => &cx.types[i],
633 _ => #internal::bad_type_info(),
634 };
635 };
636
637 let expanded = quote! {
638 unsafe impl #wt::component::Lift for #name {
639 #[inline]
640 fn lift(
641 cx: &mut #internal::LiftContext<'_>,
642 ty: #internal::InterfaceType,
643 src: &Self::Lower,
644 ) -> #internal::anyhow::Result<Self> {
645 #extract_ty
646 let discrim = src.tag.get_u32();
647 if discrim >= #discrim_limit {
648 #internal::anyhow::bail!("unexpected discriminant: {discrim}");
649 }
650 Ok(unsafe {
651 #internal::transmute::<#discrim_ty, #name>(discrim as #discrim_ty)
652 })
653 }
654
655 #[inline]
656 fn load(
657 cx: &mut #internal::LiftContext<'_>,
658 ty: #internal::InterfaceType,
659 bytes: &[u8],
660 ) -> #internal::anyhow::Result<Self> {
661 let align = <Self as #wt::component::ComponentType>::ALIGN32;
662 debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0);
663 let discrim = #from_bytes;
664 if discrim >= #discrim_limit {
665 #internal::anyhow::bail!("unexpected discriminant: {discrim}");
666 }
667 Ok(unsafe {
668 #internal::transmute::<#discrim_ty, #name>(discrim)
669 })
670 }
671 }
672 };
673
674 Ok(expanded)
675 }
676}
677
678pub struct LowerExpander;
679
680impl Expander for LowerExpander {
681 fn expand_record(
682 &self,
683 name: &syn::Ident,
684 generics: &syn::Generics,
685 fields: &[&syn::Field],
686 wt: &syn::Path,
687 ) -> Result<TokenStream> {
688 let internal = quote!(#wt::component::__internal);
689
690 let mut lowers = TokenStream::new();
691 let mut stores = TokenStream::new();
692
693 for (i, syn::Field { ident, ty, .. }) in fields.iter().enumerate() {
694 let field_ty = quote!(ty.fields[#i].ty);
695 lowers.extend(quote!(#wt::component::Lower::lower(
696 &self.#ident, cx, #field_ty, #internal::map_maybe_uninit!(dst.#ident)
697 )?;));
698
699 stores.extend(quote!(#wt::component::Lower::store(
700 &self.#ident,
701 cx,
702 #field_ty,
703 <#ty as #wt::component::ComponentType>::ABI.next_field32_size(&mut offset),
704 )?;));
705 }
706
707 let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lower));
708 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
709
710 let extract_ty = quote! {
711 let ty = match ty {
712 #internal::InterfaceType::Record(i) => &cx.types[i],
713 _ => #internal::bad_type_info(),
714 };
715 };
716
717 let expanded = quote! {
718 unsafe impl #impl_generics #wt::component::Lower for #name #ty_generics #where_clause {
719 #[inline]
720 fn lower<T>(
721 &self,
722 cx: &mut #internal::LowerContext<'_, T>,
723 ty: #internal::InterfaceType,
724 dst: &mut core::mem::MaybeUninit<Self::Lower>,
725 ) -> #internal::anyhow::Result<()> {
726 #extract_ty
727 #lowers
728 Ok(())
729 }
730
731 #[inline]
732 fn store<T>(
733 &self,
734 cx: &mut #internal::LowerContext<'_, T>,
735 ty: #internal::InterfaceType,
736 mut offset: usize
737 ) -> #internal::anyhow::Result<()> {
738 debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);
739 #extract_ty
740 #stores
741 Ok(())
742 }
743 }
744 };
745
746 Ok(expanded)
747 }
748
749 fn expand_variant(
750 &self,
751 name: &syn::Ident,
752 generics: &syn::Generics,
753 discriminant_size: DiscriminantSize,
754 cases: &[VariantCase],
755 wt: &syn::Path,
756 ) -> Result<TokenStream> {
757 let internal = quote!(#wt::component::__internal);
758
759 let mut lowers = TokenStream::new();
760 let mut stores = TokenStream::new();
761
762 for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() {
763 let index_u32 = u32::try_from(index).unwrap();
764
765 let index_quoted = quote(discriminant_size, index);
766
767 let discriminant_size = usize::from(discriminant_size);
768
769 let pattern;
770 let lower;
771 let store;
772
773 if ty.is_some() {
774 let ty = quote!(ty.cases[#index].unwrap_or_else(#internal::bad_type_info));
775 pattern = quote!(Self::#ident(value));
776 lower = quote!(value.lower(cx, #ty, dst));
777 store = quote!(value.store(
778 cx,
779 #ty,
780 offset + <Self as #internal::ComponentVariant>::PAYLOAD_OFFSET32,
781 ));
782 } else {
783 pattern = quote!(Self::#ident);
784 lower = quote!(Ok(()));
785 store = quote!(Ok(()));
786 }
787
788 lowers.extend(quote!(#pattern => {
789 #internal::map_maybe_uninit!(dst.tag).write(#wt::ValRaw::u32(#index_u32));
790 unsafe {
791 #internal::lower_payload(
792 #internal::map_maybe_uninit!(dst.payload),
793 |payload| #internal::map_maybe_uninit!(payload.#ident),
794 |dst| #lower,
795 )
796 }
797 }));
798
799 stores.extend(quote!(#pattern => {
800 *cx.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes();
801 #store
802 }));
803 }
804
805 let generics = add_trait_bounds(generics, parse_quote!(#wt::component::Lower));
806 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
807
808 let extract_ty = quote! {
809 let ty = match ty {
810 #internal::InterfaceType::Variant(i) => &cx.types[i],
811 _ => #internal::bad_type_info(),
812 };
813 };
814
815 let expanded = quote! {
816 unsafe impl #impl_generics #wt::component::Lower for #name #ty_generics #where_clause {
817 #[inline]
818 fn lower<T>(
819 &self,
820 cx: &mut #internal::LowerContext<'_, T>,
821 ty: #internal::InterfaceType,
822 dst: &mut core::mem::MaybeUninit<Self::Lower>,
823 ) -> #internal::anyhow::Result<()> {
824 #extract_ty
825 match self {
826 #lowers
827 }
828 }
829
830 #[inline]
831 fn store<T>(
832 &self,
833 cx: &mut #internal::LowerContext<'_, T>,
834 ty: #internal::InterfaceType,
835 mut offset: usize
836 ) -> #internal::anyhow::Result<()> {
837 #extract_ty
838 debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);
839 match self {
840 #stores
841 }
842 }
843 }
844 };
845
846 Ok(expanded)
847 }
848
849 fn expand_enum(
850 &self,
851 name: &syn::Ident,
852 discriminant_size: DiscriminantSize,
853 _cases: &[VariantCase],
854 wt: &syn::Path,
855 ) -> Result<TokenStream> {
856 let internal = quote!(#wt::component::__internal);
857
858 let extract_ty = quote! {
859 let ty = match ty {
860 #internal::InterfaceType::Enum(i) => &cx.types[i],
861 _ => #internal::bad_type_info(),
862 };
863 };
864
865 let (size, ty) = match discriminant_size {
866 DiscriminantSize::Size1 => (1, quote!(u8)),
867 DiscriminantSize::Size2 => (2, quote!(u16)),
868 DiscriminantSize::Size4 => (4, quote!(u32)),
869 };
870 let size = proc_macro2::Literal::usize_unsuffixed(size);
871
872 let expanded = quote! {
873 unsafe impl #wt::component::Lower for #name {
874 #[inline]
875 fn lower<T>(
876 &self,
877 cx: &mut #internal::LowerContext<'_, T>,
878 ty: #internal::InterfaceType,
879 dst: &mut core::mem::MaybeUninit<Self::Lower>,
880 ) -> #internal::anyhow::Result<()> {
881 #extract_ty
882 #internal::map_maybe_uninit!(dst.tag)
883 .write(#wt::ValRaw::u32(*self as u32));
884 Ok(())
885 }
886
887 #[inline]
888 fn store<T>(
889 &self,
890 cx: &mut #internal::LowerContext<'_, T>,
891 ty: #internal::InterfaceType,
892 mut offset: usize
893 ) -> #internal::anyhow::Result<()> {
894 #extract_ty
895 debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);
896 let discrim = *self as #ty;
897 *cx.get::<#size>(offset) = discrim.to_le_bytes();
898 Ok(())
899 }
900 }
901 };
902
903 Ok(expanded)
904 }
905}
906
907pub struct ComponentTypeExpander;
908
909impl Expander for ComponentTypeExpander {
910 fn expand_record(
911 &self,
912 name: &syn::Ident,
913 generics: &syn::Generics,
914 fields: &[&syn::Field],
915 wt: &syn::Path,
916 ) -> Result<TokenStream> {
917 expand_record_for_component_type(
918 name,
919 generics,
920 fields,
921 quote!(typecheck_record),
922 fields
923 .iter()
924 .map(
925 |syn::Field {
926 attrs, ident, ty, ..
927 }| {
928 let name = find_rename(attrs)?.unwrap_or_else(|| {
929 let ident = ident.as_ref().unwrap();
930 syn::LitStr::new(&ident.to_string(), ident.span())
931 });
932
933 Ok(quote!((#name, <#ty as #wt::component::ComponentType>::typecheck),))
934 },
935 )
936 .collect::<Result<_>>()?,
937 wt,
938 )
939 }
940
941 fn expand_variant(
942 &self,
943 name: &syn::Ident,
944 generics: &syn::Generics,
945 _discriminant_size: DiscriminantSize,
946 cases: &[VariantCase],
947 wt: &syn::Path,
948 ) -> Result<TokenStream> {
949 let internal = quote!(#wt::component::__internal);
950
951 let mut case_names_and_checks = TokenStream::new();
952 let mut lower_payload_generic_params = TokenStream::new();
953 let mut lower_payload_generic_args = TokenStream::new();
954 let mut lower_payload_case_declarations = TokenStream::new();
955 let mut lower_generic_args = TokenStream::new();
956 let mut abi_list = TokenStream::new();
957 let mut unique_types = HashSet::new();
958
959 for (index, VariantCase { attrs, ident, ty }) in cases.iter().enumerate() {
960 let rename = find_rename(attrs)?;
961
962 let name = rename.unwrap_or_else(|| syn::LitStr::new(&ident.to_string(), ident.span()));
963
964 if let Some(ty) = ty {
965 abi_list.extend(quote!(Some(<#ty as #wt::component::ComponentType>::ABI),));
966
967 case_names_and_checks.extend(
968 quote!((#name, Some(<#ty as #wt::component::ComponentType>::typecheck)),),
969 );
970
971 let generic = format_ident!("T{}", index);
972
973 lower_payload_generic_params.extend(quote!(#generic: Copy,));
974 lower_payload_generic_args.extend(quote!(#generic,));
975 lower_payload_case_declarations.extend(quote!(#ident: #generic,));
976 lower_generic_args.extend(quote!(<#ty as #wt::component::ComponentType>::Lower,));
977
978 unique_types.insert(ty);
979 } else {
980 abi_list.extend(quote!(None,));
981 case_names_and_checks.extend(quote!((#name, None),));
982 lower_payload_case_declarations.extend(quote!(#ident: [#wt::ValRaw; 0],));
983 }
984 }
985
986 let generics = add_trait_bounds(generics, parse_quote!(#wt::component::ComponentType));
987 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
988 let lower = format_ident!("Lower{}", name);
989 let lower_payload = format_ident!("LowerPayload{}", name);
990
991 let expanded = quote! {
1000 #[doc(hidden)]
1001 #[derive(Clone, Copy)]
1002 #[repr(C)]
1003 pub struct #lower<#lower_payload_generic_params> {
1004 tag: #wt::ValRaw,
1005 payload: #lower_payload<#lower_payload_generic_args>
1006 }
1007
1008 #[doc(hidden)]
1009 #[allow(non_snake_case)]
1010 #[derive(Clone, Copy)]
1011 #[repr(C)]
1012 union #lower_payload<#lower_payload_generic_params> {
1013 #lower_payload_case_declarations
1014 }
1015
1016 unsafe impl #impl_generics #wt::component::ComponentType for #name #ty_generics #where_clause {
1017 type Lower = #lower<#lower_generic_args>;
1018
1019 #[inline]
1020 fn typecheck(
1021 ty: &#internal::InterfaceType,
1022 types: &#internal::InstanceType<'_>,
1023 ) -> #internal::anyhow::Result<()> {
1024 #internal::typecheck_variant(ty, types, &[#case_names_and_checks])
1025 }
1026
1027 const ABI: #internal::CanonicalAbiInfo =
1028 #internal::CanonicalAbiInfo::variant_static(&[#abi_list]);
1029 }
1030
1031 unsafe impl #impl_generics #internal::ComponentVariant for #name #ty_generics #where_clause {
1032 const CASES: &'static [Option<#internal::CanonicalAbiInfo>] = &[#abi_list];
1033 }
1034 };
1035
1036 Ok(quote!(const _: () = { #expanded };))
1037 }
1038
1039 fn expand_enum(
1040 &self,
1041 name: &syn::Ident,
1042 _discriminant_size: DiscriminantSize,
1043 cases: &[VariantCase],
1044 wt: &syn::Path,
1045 ) -> Result<TokenStream> {
1046 let internal = quote!(#wt::component::__internal);
1047
1048 let mut case_names = TokenStream::new();
1049 let mut abi_list = TokenStream::new();
1050
1051 for VariantCase { attrs, ident, ty } in cases.iter() {
1052 let rename = find_rename(attrs)?;
1053
1054 let name = rename.unwrap_or_else(|| syn::LitStr::new(&ident.to_string(), ident.span()));
1055
1056 if ty.is_some() {
1057 return Err(Error::new(
1058 ident.span(),
1059 "payloads are not permitted for `enum` cases",
1060 ));
1061 }
1062 abi_list.extend(quote!(None,));
1063 case_names.extend(quote!(#name,));
1064 }
1065
1066 let lower = format_ident!("Lower{}", name);
1067
1068 let cases_len = cases.len();
1069 let expanded = quote! {
1070 #[doc(hidden)]
1071 #[derive(Clone, Copy)]
1072 #[repr(C)]
1073 pub struct #lower {
1074 tag: #wt::ValRaw,
1075 }
1076
1077 unsafe impl #wt::component::ComponentType for #name {
1078 type Lower = #lower;
1079
1080 #[inline]
1081 fn typecheck(
1082 ty: &#internal::InterfaceType,
1083 types: &#internal::InstanceType<'_>,
1084 ) -> #internal::anyhow::Result<()> {
1085 #internal::typecheck_enum(ty, types, &[#case_names])
1086 }
1087
1088 const ABI: #internal::CanonicalAbiInfo =
1089 #internal::CanonicalAbiInfo::enum_(#cases_len);
1090 }
1091
1092 unsafe impl #internal::ComponentVariant for #name {
1093 const CASES: &'static [Option<#internal::CanonicalAbiInfo>] = &[#abi_list];
1094 }
1095 };
1096
1097 Ok(quote!(const _: () = { #expanded };))
1098 }
1099}
1100
1101#[derive(Debug)]
1102struct Flag {
1103 rename: Option<String>,
1104 name: String,
1105}
1106
1107impl Parse for Flag {
1108 fn parse(input: ParseStream) -> Result<Self> {
1109 let attributes = syn::Attribute::parse_outer(input)?;
1110
1111 let rename = find_rename(&attributes)?.map(|literal| literal.value());
1112
1113 input.parse::<Token![const]>()?;
1114 let name = input.parse::<syn::Ident>()?.to_string();
1115
1116 Ok(Self { rename, name })
1117 }
1118}
1119
1120#[derive(Debug)]
1121pub struct Flags {
1122 name: String,
1123 flags: Vec<Flag>,
1124}
1125
1126impl Parse for Flags {
1127 fn parse(input: ParseStream) -> Result<Self> {
1128 let name = input.parse::<syn::Ident>()?.to_string();
1129
1130 let content;
1131 braced!(content in input);
1132
1133 let flags = content
1134 .parse_terminated(Flag::parse, Token![;])?
1135 .into_iter()
1136 .collect();
1137
1138 Ok(Self { name, flags })
1139 }
1140}
1141
1142pub fn expand_flags(flags: &Flags) -> Result<TokenStream> {
1143 let wt = default_wasmtime_crate();
1144 let size = FlagsSize::from_count(flags.flags.len());
1145
1146 let ty;
1147 let eq;
1148
1149 let count = flags.flags.len();
1150
1151 match size {
1152 FlagsSize::Size0 => {
1153 ty = quote!(());
1154 eq = quote!(true);
1155 }
1156 FlagsSize::Size1 => {
1157 ty = quote!(u8);
1158
1159 eq = if count == 8 {
1160 quote!(self.__inner0.eq(&rhs.__inner0))
1161 } else {
1162 let mask = !(0xFF_u8 << count);
1163
1164 quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask)))
1165 };
1166 }
1167 FlagsSize::Size2 => {
1168 ty = quote!(u16);
1169
1170 eq = if count == 16 {
1171 quote!(self.__inner0.eq(&rhs.__inner0))
1172 } else {
1173 let mask = !(0xFFFF_u16 << count);
1174
1175 quote!((self.__inner0 & #mask).eq(&(rhs.__inner0 & #mask)))
1176 };
1177 }
1178 FlagsSize::Size4Plus(n) => {
1179 ty = quote!(u32);
1180
1181 let comparisons = (0..(n - 1))
1182 .map(|index| {
1183 let field = format_ident!("__inner{}", index);
1184
1185 quote!(self.#field.eq(&rhs.#field) &&)
1186 })
1187 .collect::<TokenStream>();
1188
1189 let field = format_ident!("__inner{}", n - 1);
1190
1191 eq = if count % 32 == 0 {
1192 quote!(#comparisons self.#field.eq(&rhs.#field))
1193 } else {
1194 let mask = !(0xFFFF_FFFF_u32 << (count % 32));
1195
1196 quote!(#comparisons (self.#field & #mask).eq(&(rhs.#field & #mask)))
1197 }
1198 }
1199 }
1200
1201 let count;
1202 let mut as_array;
1203 let mut bitor;
1204 let mut bitor_assign;
1205 let mut bitand;
1206 let mut bitand_assign;
1207 let mut bitxor;
1208 let mut bitxor_assign;
1209 let mut not;
1210
1211 match size {
1212 FlagsSize::Size0 => {
1213 count = 0;
1214 as_array = quote!([]);
1215 bitor = quote!(Self {});
1216 bitor_assign = quote!();
1217 bitand = quote!(Self {});
1218 bitand_assign = quote!();
1219 bitxor = quote!(Self {});
1220 bitxor_assign = quote!();
1221 not = quote!(Self {});
1222 }
1223 FlagsSize::Size1 | FlagsSize::Size2 => {
1224 count = 1;
1225 as_array = quote!([self.__inner0 as u32]);
1226 bitor = quote!(Self {
1227 __inner0: self.__inner0.bitor(rhs.__inner0)
1228 });
1229 bitor_assign = quote!(self.__inner0.bitor_assign(rhs.__inner0));
1230 bitand = quote!(Self {
1231 __inner0: self.__inner0.bitand(rhs.__inner0)
1232 });
1233 bitand_assign = quote!(self.__inner0.bitand_assign(rhs.__inner0));
1234 bitxor = quote!(Self {
1235 __inner0: self.__inner0.bitxor(rhs.__inner0)
1236 });
1237 bitxor_assign = quote!(self.__inner0.bitxor_assign(rhs.__inner0));
1238 not = quote!(Self {
1239 __inner0: self.__inner0.not()
1240 });
1241 }
1242 FlagsSize::Size4Plus(n) => {
1243 count = usize::from(n);
1244 as_array = TokenStream::new();
1245 bitor = TokenStream::new();
1246 bitor_assign = TokenStream::new();
1247 bitand = TokenStream::new();
1248 bitand_assign = TokenStream::new();
1249 bitxor = TokenStream::new();
1250 bitxor_assign = TokenStream::new();
1251 not = TokenStream::new();
1252
1253 for index in 0..n {
1254 let field = format_ident!("__inner{}", index);
1255
1256 as_array.extend(quote!(self.#field,));
1257 bitor.extend(quote!(#field: self.#field.bitor(rhs.#field),));
1258 bitor_assign.extend(quote!(self.#field.bitor_assign(rhs.#field);));
1259 bitand.extend(quote!(#field: self.#field.bitand(rhs.#field),));
1260 bitand_assign.extend(quote!(self.#field.bitand_assign(rhs.#field);));
1261 bitxor.extend(quote!(#field: self.#field.bitxor(rhs.#field),));
1262 bitxor_assign.extend(quote!(self.#field.bitxor_assign(rhs.#field);));
1263 not.extend(quote!(#field: self.#field.not(),));
1264 }
1265
1266 as_array = quote!([#as_array]);
1267 bitor = quote!(Self { #bitor });
1268 bitand = quote!(Self { #bitand });
1269 bitxor = quote!(Self { #bitxor });
1270 not = quote!(Self { #not });
1271 }
1272 };
1273
1274 let name = format_ident!("{}", flags.name);
1275
1276 let mut constants = TokenStream::new();
1277 let mut rust_names = TokenStream::new();
1278 let mut component_names = TokenStream::new();
1279
1280 for (index, Flag { name, rename }) in flags.flags.iter().enumerate() {
1281 rust_names.extend(quote!(#name,));
1282
1283 let component_name = rename.as_ref().unwrap_or(name);
1284 component_names.extend(quote!(#component_name,));
1285
1286 let fields = match size {
1287 FlagsSize::Size0 => quote!(),
1288 FlagsSize::Size1 => {
1289 let init = 1_u8 << index;
1290 quote!(__inner0: #init)
1291 }
1292 FlagsSize::Size2 => {
1293 let init = 1_u16 << index;
1294 quote!(__inner0: #init)
1295 }
1296 FlagsSize::Size4Plus(n) => (0..n)
1297 .map(|i| {
1298 let field = format_ident!("__inner{}", i);
1299
1300 let init = if index / 32 == usize::from(i) {
1301 1_u32 << (index % 32)
1302 } else {
1303 0
1304 };
1305
1306 quote!(#field: #init,)
1307 })
1308 .collect::<TokenStream>(),
1309 };
1310
1311 let name = format_ident!("{}", name);
1312
1313 constants.extend(quote!(pub const #name: Self = Self { #fields };));
1314 }
1315
1316 let generics = syn::Generics {
1317 lt_token: None,
1318 params: Punctuated::new(),
1319 gt_token: None,
1320 where_clause: None,
1321 };
1322
1323 let fields = {
1324 let ty = syn::parse2::<syn::Type>(ty.clone())?;
1325
1326 (0..count)
1327 .map(|index| syn::Field {
1328 attrs: Vec::new(),
1329 vis: syn::Visibility::Inherited,
1330 ident: Some(format_ident!("__inner{}", index)),
1331 colon_token: None,
1332 ty: ty.clone(),
1333 mutability: syn::FieldMutability::None,
1334 })
1335 .collect::<Vec<_>>()
1336 };
1337
1338 let fields = fields.iter().collect::<Vec<_>>();
1339
1340 let component_type_impl = expand_record_for_component_type(
1341 &name,
1342 &generics,
1343 &fields,
1344 quote!(typecheck_flags),
1345 component_names,
1346 &wt,
1347 )?;
1348
1349 let internal = quote!(#wt::component::__internal);
1350
1351 let field_names = fields
1352 .iter()
1353 .map(|syn::Field { ident, .. }| ident)
1354 .collect::<Vec<_>>();
1355
1356 let fields = fields
1357 .iter()
1358 .map(|syn::Field { ident, .. }| quote!(#[doc(hidden)] #ident: #ty,))
1359 .collect::<TokenStream>();
1360
1361 let (field_interface_type, field_size) = match size {
1362 FlagsSize::Size0 => (quote!(NOT USED), 0usize),
1363 FlagsSize::Size1 => (quote!(#internal::InterfaceType::U8), 1),
1364 FlagsSize::Size2 => (quote!(#internal::InterfaceType::U16), 2),
1365 FlagsSize::Size4Plus(_) => (quote!(#internal::InterfaceType::U32), 4),
1366 };
1367
1368 let expanded = quote! {
1369 #[derive(Copy, Clone, Default)]
1370 pub struct #name { #fields }
1371
1372 impl #name {
1373 #constants
1374
1375 pub fn as_array(&self) -> [u32; #count] {
1376 #as_array
1377 }
1378
1379 pub fn empty() -> Self {
1380 Self::default()
1381 }
1382
1383 pub fn all() -> Self {
1384 use core::ops::Not;
1385 Self::default().not()
1386 }
1387
1388 pub fn contains(&self, other: Self) -> bool {
1389 *self & other == other
1390 }
1391
1392 pub fn intersects(&self, other: Self) -> bool {
1393 *self & other != Self::empty()
1394 }
1395 }
1396
1397 impl core::cmp::PartialEq for #name {
1398 fn eq(&self, rhs: &#name) -> bool {
1399 #eq
1400 }
1401 }
1402
1403 impl core::cmp::Eq for #name { }
1404
1405 impl core::fmt::Debug for #name {
1406 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
1407 #internal::format_flags(&self.as_array(), &[#rust_names], f)
1408 }
1409 }
1410
1411 impl core::ops::BitOr for #name {
1412 type Output = #name;
1413
1414 fn bitor(self, rhs: #name) -> #name {
1415 #bitor
1416 }
1417 }
1418
1419 impl core::ops::BitOrAssign for #name {
1420 fn bitor_assign(&mut self, rhs: #name) {
1421 #bitor_assign
1422 }
1423 }
1424
1425 impl core::ops::BitAnd for #name {
1426 type Output = #name;
1427
1428 fn bitand(self, rhs: #name) -> #name {
1429 #bitand
1430 }
1431 }
1432
1433 impl core::ops::BitAndAssign for #name {
1434 fn bitand_assign(&mut self, rhs: #name) {
1435 #bitand_assign
1436 }
1437 }
1438
1439 impl core::ops::BitXor for #name {
1440 type Output = #name;
1441
1442 fn bitxor(self, rhs: #name) -> #name {
1443 #bitxor
1444 }
1445 }
1446
1447 impl core::ops::BitXorAssign for #name {
1448 fn bitxor_assign(&mut self, rhs: #name) {
1449 #bitxor_assign
1450 }
1451 }
1452
1453 impl core::ops::Not for #name {
1454 type Output = #name;
1455
1456 fn not(self) -> #name {
1457 #not
1458 }
1459 }
1460
1461 #component_type_impl
1462
1463 unsafe impl #wt::component::Lower for #name {
1464 fn lower<T>(
1465 &self,
1466 cx: &mut #internal::LowerContext<'_, T>,
1467 _ty: #internal::InterfaceType,
1468 dst: &mut core::mem::MaybeUninit<Self::Lower>,
1469 ) -> #internal::anyhow::Result<()> {
1470 #(
1471 self.#field_names.lower(
1472 cx,
1473 #field_interface_type,
1474 #internal::map_maybe_uninit!(dst.#field_names),
1475 )?;
1476 )*
1477 Ok(())
1478 }
1479
1480 fn store<T>(
1481 &self,
1482 cx: &mut #internal::LowerContext<'_, T>,
1483 _ty: #internal::InterfaceType,
1484 mut offset: usize
1485 ) -> #internal::anyhow::Result<()> {
1486 debug_assert!(offset % (<Self as #wt::component::ComponentType>::ALIGN32 as usize) == 0);
1487 #(
1488 self.#field_names.store(
1489 cx,
1490 #field_interface_type,
1491 offset,
1492 )?;
1493 offset += core::mem::size_of_val(&self.#field_names);
1494 )*
1495 Ok(())
1496 }
1497 }
1498
1499 unsafe impl #wt::component::Lift for #name {
1500 fn lift(
1501 cx: &mut #internal::LiftContext<'_>,
1502 _ty: #internal::InterfaceType,
1503 src: &Self::Lower,
1504 ) -> #internal::anyhow::Result<Self> {
1505 Ok(Self {
1506 #(
1507 #field_names: #wt::component::Lift::lift(
1508 cx,
1509 #field_interface_type,
1510 &src.#field_names,
1511 )?,
1512 )*
1513 })
1514 }
1515
1516 fn load(
1517 cx: &mut #internal::LiftContext<'_>,
1518 _ty: #internal::InterfaceType,
1519 bytes: &[u8],
1520 ) -> #internal::anyhow::Result<Self> {
1521 debug_assert!(
1522 (bytes.as_ptr() as usize)
1523 % (<Self as #wt::component::ComponentType>::ALIGN32 as usize)
1524 == 0
1525 );
1526 #(
1527 let (field, bytes) = bytes.split_at(#field_size);
1528 let #field_names = #wt::component::Lift::load(
1529 cx,
1530 #field_interface_type,
1531 field,
1532 )?;
1533 )*
1534 Ok(Self { #(#field_names,)* })
1535 }
1536 }
1537 };
1538
1539 Ok(expanded)
1540}