wasmtime_component_macro/
bindgen.rs

1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use std::collections::{HashMap, HashSet};
4use std::env;
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
7use syn::parse::{Error, Parse, ParseStream, Result};
8use syn::punctuated::Punctuated;
9use syn::{braced, token, Token};
10use wasmtime_wit_bindgen::{
11    AsyncConfig, CallStyle, Opts, Ownership, TrappableError, TrappableImports,
12};
13use wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
14
15pub struct Config {
16    opts: Opts,
17    resolve: Resolve,
18    world: WorldId,
19    files: Vec<PathBuf>,
20    include_generated_code_from_file: bool,
21}
22
23pub fn expand(input: &Config) -> Result<TokenStream> {
24    if let (CallStyle::Async | CallStyle::Concurrent, false) =
25        (input.opts.call_style(), cfg!(feature = "async"))
26    {
27        return Err(Error::new(
28            Span::call_site(),
29            "cannot enable async bindings unless `async` crate feature is active",
30        ));
31    }
32
33    if input.opts.concurrent_imports && !cfg!(feature = "component-model-async") {
34        return Err(Error::new(
35            Span::call_site(),
36            "cannot enable `concurrent_imports` option unless `component-model-async` crate feature is active",
37        ));
38    }
39
40    let mut src = match input.opts.generate(&input.resolve, input.world) {
41        Ok(s) => s,
42        Err(e) => return Err(Error::new(Span::call_site(), e.to_string())),
43    };
44
45    if input.opts.stringify {
46        return Ok(quote::quote!(#src));
47    }
48
49    // If a magical `WASMTIME_DEBUG_BINDGEN` environment variable is set then
50    // place a formatted version of the expanded code into a file. This file
51    // will then show up in rustc error messages for any codegen issues and can
52    // be inspected manually.
53    if input.include_generated_code_from_file
54        || input.opts.debug
55        || std::env::var("WASMTIME_DEBUG_BINDGEN").is_ok()
56    {
57        static INVOCATION: AtomicUsize = AtomicUsize::new(0);
58        let root = Path::new(env!("DEBUG_OUTPUT_DIR"));
59        let world_name = &input.resolve.worlds[input.world].name;
60        let n = INVOCATION.fetch_add(1, Relaxed);
61        let path = root.join(format!("{world_name}{n}.rs"));
62
63        std::fs::write(&path, &src).unwrap();
64
65        // optimistically format the code but don't require success
66        drop(
67            std::process::Command::new("rustfmt")
68                .arg(&path)
69                .arg("--edition=2021")
70                .output(),
71        );
72
73        src = format!("include!({path:?});");
74    }
75    let mut contents = src.parse::<TokenStream>().unwrap();
76
77    // Include a dummy `include_str!` for any files we read so rustc knows that
78    // we depend on the contents of those files.
79    for file in input.files.iter() {
80        contents.extend(
81            format!("const _: &str = include_str!(r#\"{}\"#);\n", file.display())
82                .parse::<TokenStream>()
83                .unwrap(),
84        );
85    }
86
87    Ok(contents)
88}
89
90impl Parse for Config {
91    fn parse(input: ParseStream<'_>) -> Result<Self> {
92        let call_site = Span::call_site();
93        let mut opts = Opts::default();
94        let mut world = None;
95        let mut inline = None;
96        let mut paths = Vec::new();
97        let mut async_configured = false;
98        let mut include_generated_code_from_file = false;
99
100        if input.peek(token::Brace) {
101            let content;
102            syn::braced!(content in input);
103            let fields = Punctuated::<Opt, Token![,]>::parse_terminated(&content)?;
104            for field in fields.into_pairs() {
105                match field.into_value() {
106                    Opt::Path(p) => {
107                        paths.extend(p.into_iter().map(|p| p.value()));
108                    }
109                    Opt::World(s) => {
110                        if world.is_some() {
111                            return Err(Error::new(s.span(), "cannot specify second world"));
112                        }
113                        world = Some(s.value());
114                    }
115                    Opt::Inline(s) => {
116                        if inline.is_some() {
117                            return Err(Error::new(s.span(), "cannot specify second source"));
118                        }
119                        inline = Some(s.value());
120                    }
121                    Opt::Tracing(val) => opts.tracing = val,
122                    Opt::VerboseTracing(val) => opts.verbose_tracing = val,
123                    Opt::Debug(val) => opts.debug = val,
124                    Opt::Async(val, span) => {
125                        if async_configured {
126                            return Err(Error::new(span, "cannot specify second async config"));
127                        }
128                        async_configured = true;
129                        opts.async_ = val;
130                    }
131                    Opt::ConcurrentImports(val) => opts.concurrent_imports = val,
132                    Opt::ConcurrentExports(val) => opts.concurrent_exports = val,
133                    Opt::TrappableErrorType(val) => opts.trappable_error_type = val,
134                    Opt::TrappableImports(val) => opts.trappable_imports = val,
135                    Opt::Ownership(val) => opts.ownership = val,
136                    Opt::Interfaces(s) => {
137                        if inline.is_some() {
138                            return Err(Error::new(s.span(), "cannot specify a second source"));
139                        }
140                        inline = Some(format!(
141                            "
142                                package wasmtime:component-macro-synthesized;
143
144                                world interfaces {{
145                                    {}
146                                }}
147                            ",
148                            s.value()
149                        ));
150
151                        if world.is_some() {
152                            return Err(Error::new(
153                                s.span(),
154                                "cannot specify a world with `interfaces`",
155                            ));
156                        }
157                        world = Some("wasmtime:component-macro-synthesized/interfaces".to_string());
158
159                        opts.only_interfaces = true;
160                    }
161                    Opt::With(val) => opts.with.extend(val),
162                    Opt::AdditionalDerives(paths) => {
163                        opts.additional_derive_attributes = paths
164                            .into_iter()
165                            .map(|p| p.into_token_stream().to_string())
166                            .collect()
167                    }
168                    Opt::Stringify(val) => opts.stringify = val,
169                    Opt::SkipMutForwardingImpls(val) => opts.skip_mut_forwarding_impls = val,
170                    Opt::RequireStoreDataSend(val) => opts.require_store_data_send = val,
171                    Opt::WasmtimeCrate(f) => {
172                        opts.wasmtime_crate = Some(f.into_token_stream().to_string())
173                    }
174                    Opt::IncludeGeneratedCodeFromFile(i) => include_generated_code_from_file = i,
175                }
176            }
177        } else {
178            world = input.parse::<Option<syn::LitStr>>()?.map(|s| s.value());
179            if input.parse::<Option<syn::token::In>>()?.is_some() {
180                paths.push(input.parse::<syn::LitStr>()?.value());
181            }
182        }
183        let (resolve, pkgs, files) = parse_source(&paths, &inline)
184            .map_err(|err| Error::new(call_site, format!("{err:?}")))?;
185
186        let world = select_world(&resolve, &pkgs, world.as_deref())
187            .map_err(|e| Error::new(call_site, format!("{e:?}")))?;
188        Ok(Config {
189            opts,
190            resolve,
191            world,
192            files,
193            include_generated_code_from_file,
194        })
195    }
196}
197
198fn parse_source(
199    paths: &Vec<String>,
200    inline: &Option<String>,
201) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)> {
202    let mut resolve = Resolve::default();
203    resolve.all_features = true;
204    let mut files = Vec::new();
205    let mut pkgs = Vec::new();
206    let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
207
208    let parse = |resolve: &mut Resolve,
209                 files: &mut Vec<PathBuf>,
210                 pkgs: &mut Vec<PackageId>,
211                 paths: &[String]|
212     -> anyhow::Result<_> {
213        for path in paths {
214            let p = root.join(path);
215            // Try to normalize the path to make the error message more understandable when
216            // the path is not correct. Fallback to the original path if normalization fails
217            // (probably return an error somewhere else).
218            let normalized_path = match std::fs::canonicalize(&p) {
219                Ok(p) => p,
220                Err(_) => p.to_path_buf(),
221            };
222            let (pkg, sources) = resolve.push_path(normalized_path)?;
223            pkgs.push(pkg);
224            files.extend(sources.paths().map(|p| p.to_owned()));
225        }
226        Ok(())
227    };
228
229    if !paths.is_empty() {
230        parse(&mut resolve, &mut files, &mut pkgs, &paths)?;
231    }
232
233    if let Some(inline) = inline {
234        pkgs.push(resolve.push_group(UnresolvedPackageGroup::parse("macro-input", inline)?)?);
235    }
236
237    if pkgs.is_empty() {
238        parse(&mut resolve, &mut files, &mut pkgs, &["wit".into()])?;
239    }
240
241    Ok((resolve, pkgs, files))
242}
243
244fn select_world(
245    resolve: &Resolve,
246    pkgs: &[PackageId],
247    world: Option<&str>,
248) -> anyhow::Result<WorldId> {
249    if pkgs.len() == 1 {
250        resolve.select_world(pkgs[0], world)
251    } else {
252        assert!(!pkgs.is_empty());
253        match world {
254            Some(name) => {
255                if !name.contains(":") {
256                    anyhow::bail!(
257                        "with multiple packages a fully qualified \
258                         world name must be specified"
259                    )
260                }
261
262                // This will ignore the package argument due to the fully
263                // qualified name being used.
264                resolve.select_world(pkgs[0], world)
265            }
266            None => {
267                let worlds = pkgs
268                    .iter()
269                    .filter_map(|p| resolve.select_world(*p, None).ok())
270                    .collect::<Vec<_>>();
271                match &worlds[..] {
272                    [] => anyhow::bail!("no packages have a world"),
273                    [world] => Ok(*world),
274                    _ => anyhow::bail!("multiple packages have a world, must specify which to use"),
275                }
276            }
277        }
278    }
279}
280
281mod kw {
282    syn::custom_keyword!(inline);
283    syn::custom_keyword!(path);
284    syn::custom_keyword!(tracing);
285    syn::custom_keyword!(verbose_tracing);
286    syn::custom_keyword!(trappable_error_type);
287    syn::custom_keyword!(world);
288    syn::custom_keyword!(ownership);
289    syn::custom_keyword!(interfaces);
290    syn::custom_keyword!(with);
291    syn::custom_keyword!(except_imports);
292    syn::custom_keyword!(only_imports);
293    syn::custom_keyword!(trappable_imports);
294    syn::custom_keyword!(additional_derives);
295    syn::custom_keyword!(stringify);
296    syn::custom_keyword!(skip_mut_forwarding_impls);
297    syn::custom_keyword!(require_store_data_send);
298    syn::custom_keyword!(wasmtime_crate);
299    syn::custom_keyword!(include_generated_code_from_file);
300    syn::custom_keyword!(concurrent_imports);
301    syn::custom_keyword!(concurrent_exports);
302    syn::custom_keyword!(debug);
303}
304
305enum Opt {
306    World(syn::LitStr),
307    Path(Vec<syn::LitStr>),
308    Inline(syn::LitStr),
309    Tracing(bool),
310    VerboseTracing(bool),
311    Async(AsyncConfig, Span),
312    TrappableErrorType(Vec<TrappableError>),
313    Ownership(Ownership),
314    Interfaces(syn::LitStr),
315    With(HashMap<String, String>),
316    TrappableImports(TrappableImports),
317    AdditionalDerives(Vec<syn::Path>),
318    Stringify(bool),
319    SkipMutForwardingImpls(bool),
320    RequireStoreDataSend(bool),
321    WasmtimeCrate(syn::Path),
322    IncludeGeneratedCodeFromFile(bool),
323    ConcurrentImports(bool),
324    ConcurrentExports(bool),
325    Debug(bool),
326}
327
328impl Parse for Opt {
329    fn parse(input: ParseStream<'_>) -> Result<Self> {
330        let l = input.lookahead1();
331        if l.peek(kw::debug) {
332            input.parse::<kw::debug>()?;
333            input.parse::<Token![:]>()?;
334            Ok(Opt::Debug(input.parse::<syn::LitBool>()?.value))
335        } else if l.peek(kw::path) {
336            input.parse::<kw::path>()?;
337            input.parse::<Token![:]>()?;
338
339            let mut paths: Vec<syn::LitStr> = vec![];
340
341            let l = input.lookahead1();
342            if l.peek(syn::LitStr) {
343                paths.push(input.parse()?);
344            } else if l.peek(syn::token::Bracket) {
345                let contents;
346                syn::bracketed!(contents in input);
347                let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
348
349                paths.extend(list.into_iter());
350            } else {
351                return Err(l.error());
352            };
353
354            Ok(Opt::Path(paths))
355        } else if l.peek(kw::inline) {
356            input.parse::<kw::inline>()?;
357            input.parse::<Token![:]>()?;
358            Ok(Opt::Inline(input.parse()?))
359        } else if l.peek(kw::world) {
360            input.parse::<kw::world>()?;
361            input.parse::<Token![:]>()?;
362            Ok(Opt::World(input.parse()?))
363        } else if l.peek(kw::tracing) {
364            input.parse::<kw::tracing>()?;
365            input.parse::<Token![:]>()?;
366            Ok(Opt::Tracing(input.parse::<syn::LitBool>()?.value))
367        } else if l.peek(kw::verbose_tracing) {
368            input.parse::<kw::verbose_tracing>()?;
369            input.parse::<Token![:]>()?;
370            Ok(Opt::VerboseTracing(input.parse::<syn::LitBool>()?.value))
371        } else if l.peek(Token![async]) {
372            let span = input.parse::<Token![async]>()?.span;
373            input.parse::<Token![:]>()?;
374            if input.peek(syn::LitBool) {
375                match input.parse::<syn::LitBool>()?.value {
376                    true => Ok(Opt::Async(AsyncConfig::All, span)),
377                    false => Ok(Opt::Async(AsyncConfig::None, span)),
378                }
379            } else {
380                let contents;
381                syn::braced!(contents in input);
382
383                let l = contents.lookahead1();
384                let ctor: fn(HashSet<String>) -> AsyncConfig = if l.peek(kw::except_imports) {
385                    contents.parse::<kw::except_imports>()?;
386                    contents.parse::<Token![:]>()?;
387                    AsyncConfig::AllExceptImports
388                } else if l.peek(kw::only_imports) {
389                    contents.parse::<kw::only_imports>()?;
390                    contents.parse::<Token![:]>()?;
391                    AsyncConfig::OnlyImports
392                } else {
393                    return Err(l.error());
394                };
395
396                let list;
397                syn::bracketed!(list in contents);
398                let fields: Punctuated<syn::LitStr, Token![,]> =
399                    list.parse_terminated(Parse::parse, Token![,])?;
400
401                if contents.peek(Token![,]) {
402                    contents.parse::<Token![,]>()?;
403                }
404                Ok(Opt::Async(
405                    ctor(fields.iter().map(|s| s.value()).collect()),
406                    span,
407                ))
408            }
409        } else if l.peek(kw::concurrent_imports) {
410            input.parse::<kw::concurrent_imports>()?;
411            input.parse::<Token![:]>()?;
412            Ok(Opt::ConcurrentImports(input.parse::<syn::LitBool>()?.value))
413        } else if l.peek(kw::concurrent_exports) {
414            input.parse::<kw::concurrent_exports>()?;
415            input.parse::<Token![:]>()?;
416            Ok(Opt::ConcurrentExports(input.parse::<syn::LitBool>()?.value))
417        } else if l.peek(kw::ownership) {
418            input.parse::<kw::ownership>()?;
419            input.parse::<Token![:]>()?;
420            let ownership = input.parse::<syn::Ident>()?;
421            Ok(Opt::Ownership(match ownership.to_string().as_str() {
422                "Owning" => Ownership::Owning,
423                "Borrowing" => Ownership::Borrowing {
424                    duplicate_if_necessary: {
425                        let contents;
426                        braced!(contents in input);
427                        let field = contents.parse::<syn::Ident>()?;
428                        match field.to_string().as_str() {
429                            "duplicate_if_necessary" => {
430                                contents.parse::<Token![:]>()?;
431                                contents.parse::<syn::LitBool>()?.value
432                            }
433                            name => {
434                                return Err(Error::new(
435                                    field.span(),
436                                    format!(
437                                        "unrecognized `Ownership::Borrowing` field: `{name}`; \
438                                         expected `duplicate_if_necessary`"
439                                    ),
440                                ));
441                            }
442                        }
443                    },
444                },
445                name => {
446                    return Err(Error::new(
447                        ownership.span(),
448                        format!(
449                            "unrecognized ownership: `{name}`; \
450                             expected `Owning` or `Borrowing`"
451                        ),
452                    ));
453                }
454            }))
455        } else if l.peek(kw::trappable_error_type) {
456            input.parse::<kw::trappable_error_type>()?;
457            input.parse::<Token![:]>()?;
458            let contents;
459            let _lbrace = braced!(contents in input);
460            let fields: Punctuated<_, Token![,]> =
461                contents.parse_terminated(trappable_error_field_parse, Token![,])?;
462            Ok(Opt::TrappableErrorType(Vec::from_iter(fields)))
463        } else if l.peek(kw::interfaces) {
464            input.parse::<kw::interfaces>()?;
465            input.parse::<Token![:]>()?;
466            Ok(Opt::Interfaces(input.parse::<syn::LitStr>()?))
467        } else if l.peek(kw::with) {
468            input.parse::<kw::with>()?;
469            input.parse::<Token![:]>()?;
470            let contents;
471            let _lbrace = braced!(contents in input);
472            let fields: Punctuated<(String, String), Token![,]> =
473                contents.parse_terminated(with_field_parse, Token![,])?;
474            Ok(Opt::With(HashMap::from_iter(fields)))
475        } else if l.peek(kw::trappable_imports) {
476            input.parse::<kw::trappable_imports>()?;
477            input.parse::<Token![:]>()?;
478            let config = if input.peek(syn::LitBool) {
479                match input.parse::<syn::LitBool>()?.value {
480                    true => TrappableImports::All,
481                    false => TrappableImports::None,
482                }
483            } else {
484                let contents;
485                syn::bracketed!(contents in input);
486                let fields: Punctuated<syn::LitStr, Token![,]> =
487                    contents.parse_terminated(Parse::parse, Token![,])?;
488                TrappableImports::Only(fields.iter().map(|s| s.value()).collect())
489            };
490            Ok(Opt::TrappableImports(config))
491        } else if l.peek(kw::additional_derives) {
492            input.parse::<kw::additional_derives>()?;
493            input.parse::<Token![:]>()?;
494            let contents;
495            syn::bracketed!(contents in input);
496            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
497            Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
498        } else if l.peek(kw::stringify) {
499            input.parse::<kw::stringify>()?;
500            input.parse::<Token![:]>()?;
501            Ok(Opt::Stringify(input.parse::<syn::LitBool>()?.value))
502        } else if l.peek(kw::skip_mut_forwarding_impls) {
503            input.parse::<kw::skip_mut_forwarding_impls>()?;
504            input.parse::<Token![:]>()?;
505            Ok(Opt::SkipMutForwardingImpls(
506                input.parse::<syn::LitBool>()?.value,
507            ))
508        } else if l.peek(kw::require_store_data_send) {
509            input.parse::<kw::require_store_data_send>()?;
510            input.parse::<Token![:]>()?;
511            Ok(Opt::RequireStoreDataSend(
512                input.parse::<syn::LitBool>()?.value,
513            ))
514        } else if l.peek(kw::wasmtime_crate) {
515            input.parse::<kw::wasmtime_crate>()?;
516            input.parse::<Token![:]>()?;
517            Ok(Opt::WasmtimeCrate(input.parse()?))
518        } else if l.peek(kw::include_generated_code_from_file) {
519            input.parse::<kw::include_generated_code_from_file>()?;
520            input.parse::<Token![:]>()?;
521            Ok(Opt::IncludeGeneratedCodeFromFile(
522                input.parse::<syn::LitBool>()?.value,
523            ))
524        } else {
525            Err(l.error())
526        }
527    }
528}
529
530fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<TrappableError> {
531    let wit_path = input.parse::<syn::LitStr>()?.value();
532    input.parse::<Token![=>]>()?;
533    let rust_type_name = input.parse::<syn::Path>()?.to_token_stream().to_string();
534    Ok(TrappableError {
535        wit_path,
536        rust_type_name,
537    })
538}
539
540fn with_field_parse(input: ParseStream<'_>) -> Result<(String, String)> {
541    let interface = input.parse::<syn::LitStr>()?.value();
542    input.parse::<Token![:]>()?;
543    let start = input.span();
544    let path = input.parse::<syn::Path>()?;
545
546    // It's not possible for the segments of a path to be empty
547    let span = start
548        .join(path.segments.last().unwrap().ident.span())
549        .unwrap_or(start);
550
551    let mut buf = String::new();
552    let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> {
553        if segment.arguments != syn::PathArguments::None {
554            return Err(Error::new(
555                span,
556                "Module path must not contain angles or parens",
557            ));
558        }
559
560        buf.push_str(&segment.ident.to_string());
561
562        Ok(())
563    };
564
565    if path.leading_colon.is_some() {
566        buf.push_str("::");
567    }
568
569    let mut segments = path.segments.into_iter();
570
571    if let Some(segment) = segments.next() {
572        append(&mut buf, segment)?;
573    }
574
575    for segment in segments {
576        buf.push_str("::");
577        append(&mut buf, segment)?;
578    }
579
580    Ok((interface, buf))
581}