component_macro_test/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::parse::{Parse, ParseStream};
4use syn::{parse_macro_input, Error, Result, Token};
5
6#[proc_macro_attribute]
7pub fn add_variants(
8    attr: proc_macro::TokenStream,
9    item: proc_macro::TokenStream,
10) -> proc_macro::TokenStream {
11    expand_variants(
12        &parse_macro_input!(attr as syn::LitInt),
13        parse_macro_input!(item as syn::ItemEnum),
14    )
15    .unwrap_or_else(syn::Error::into_compile_error)
16    .into()
17}
18
19fn expand_variants(count: &syn::LitInt, mut ty: syn::ItemEnum) -> syn::Result<TokenStream> {
20    let count = count
21        .base10_digits()
22        .parse::<usize>()
23        .map_err(|_| syn::Error::new(count.span(), "expected unsigned integer"))?;
24
25    ty.variants = (0..count)
26        .map(|index| syn::Variant {
27            attrs: Vec::new(),
28            ident: syn::Ident::new(&format!("V{}", index), Span::call_site()),
29            fields: syn::Fields::Unit,
30            discriminant: None,
31        })
32        .collect();
33
34    Ok(quote!(#ty))
35}
36
37#[derive(Debug)]
38struct FlagsTest {
39    name: String,
40    flag_count: usize,
41}
42
43impl Parse for FlagsTest {
44    fn parse(input: ParseStream) -> Result<Self> {
45        let name = input.parse::<syn::Ident>()?.to_string();
46        input.parse::<Token![,]>()?;
47        let flag_count = input.parse::<syn::LitInt>()?.base10_parse()?;
48
49        Ok(Self { name, flag_count })
50    }
51}
52
53#[proc_macro]
54pub fn flags_test(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
55    expand_flags_test(&parse_macro_input!(input as FlagsTest))
56        .unwrap_or_else(Error::into_compile_error)
57        .into()
58}
59
60fn expand_flags_test(test: &FlagsTest) -> Result<TokenStream> {
61    let name = format_ident!("{}", test.name);
62    let flags = (0..test.flag_count)
63        .map(|index| {
64            let name = format_ident!("F{}", index);
65            quote!(const #name;)
66        })
67        .collect::<TokenStream>();
68
69    let expanded = quote! {
70        wasmtime::component::flags! {
71            #name {
72                #flags
73            }
74        }
75    };
76
77    Ok(expanded)
78}