1use crate::{KeyValuePair, WasiNnGraph};
9use anyhow::{Result, bail};
10use clap::builder::{StringValueParser, TypedValueParser, ValueParserFactory};
11use clap::error::{Error, ErrorKind};
12use serde::de::{self, Visitor};
13use std::time::Duration;
14use std::{fmt, marker};
15
/// Characters stripped from numeric option values before they are parsed, so
/// that e.g. `1_000_000` is accepted and read as `1000000`.
const IGNORED_NUMBER_CHARS: [char; 1] = ['_'];
18
/// Declares an option group (the set of `key[=val]` settings accepted by one
/// comma-separated command-line flag) from a single struct definition.
///
/// The invocation supplies:
/// - a struct whose fields are `pub name: Container<Payload>`, where `Container`
///   must implement `OptionContainer` and `Payload` must implement
///   `WasmtimeOptionValue`; each field's `///` docs become its help text;
/// - optionally one trailing `#[prefixed = "..."]` field of type
///   `Vec<(String, Option<String>)>` that collects arbitrary
///   `<prefix>-<key>[=val]` entries;
/// - the name of an enum, written with a literal `...` body, for the macro to
///   generate.
///
/// The expansion produces the struct (deriving `Default` and `Debug`), the
/// enum (one variant per field), a `WasmtimeOption` impl describing every
/// option, a `Display` impl for the enum, and the inherent methods
/// `configure_with` / `to_options` on the struct.
#[macro_export]
macro_rules! wasmtime_option_group {
    (
        $(#[$attr:meta])*
        pub struct $opts:ident {
            $(
                $(#[doc = $doc:tt])*
                $(#[serde($serde_attr:meta)])*
                pub $opt:ident: $container:ident<$payload:ty>,
            )+

            $(
                #[prefixed = $prefix:tt]
                $(#[serde($serde_attr2:meta)])*
                $(#[doc = $prefixed_doc:tt])*
                pub $prefixed:ident: Vec<(String, Option<String>)>,
            )?
        }
        enum $option:ident {
            ...
        }
    ) => {
        // The settings struct itself. Doc attributes are consumed by the
        // matcher above (they feed `OptionDesc::docs`) and are not re-emitted
        // on the fields; only `serde` attributes are forwarded.
        #[derive(Default, Debug)]
        $(#[$attr])*
        pub struct $opts {
            $(
                $(#[serde($serde_attr)])*
                pub $opt: $container<$payload>,
            )+
            $(
                $(#[serde($serde_attr2)])*
                pub $prefixed: Vec<(String, Option<String>)>,
            )?
        }

        // One variant per option, carrying a single parsed value; variants
        // reuse the snake_case field names, hence the lint expectation.
        #[derive(Clone, PartialEq)]
        #[expect(non_camel_case_types, reason = "macro-generated code")]
        enum $option {
            $(
                $opt($payload),
            )+
            $(
                $prefixed(String, Option<String>),
            )?
        }

        impl $crate::opt::WasmtimeOption for $option {
            const OPTIONS: &'static [$crate::opt::OptionDesc<$option>] = &[
                $(
                    $crate::opt::OptionDesc {
                        name: $crate::opt::OptName::Name(stringify!($opt)),
                        // The key is ignored; `s` is the text after `=`, or
                        // `None` for a bare `key`.
                        parse: |_, s| {
                            Ok($option::$opt(
                                $crate::opt::WasmtimeOptionValue::parse(s)?
                            ))
                        },
                        val_help: <$payload as $crate::opt::WasmtimeOptionValue>::VAL_HELP,
                        docs: concat!($($doc, "\n",)*),
                    },
                )+
                $(
                    $crate::opt::OptionDesc {
                        name: $crate::opt::OptName::Prefix($prefix),
                        // `name` is the part of the key after `<prefix>-`;
                        // the value is kept as raw text.
                        parse: |name, val| {
                            Ok($option::$prefixed(
                                name.to_string(),
                                val.map(|v| v.to_string()),
                            ))
                        },
                        val_help: "[=val]",
                        docs: concat!($($prefixed_doc, "\n",)*),
                    },
                )?
            ];
        }

        impl core::fmt::Display for $option {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                match self {
                    $(
                        // Rendered as `dashed-name=<value>`.
                        $option::$opt(val) => {
                            write!(f, "{}=", stringify!($opt).replace('_', "-"))?;
                            $crate::opt::WasmtimeOptionValue::display(val, f)
                        }
                    )+
                    $(
                        // NOTE(review): this prints the field name rather than
                        // `$prefix`, and does not turn `_` into `-`; the output
                        // only matches the accepted `<prefix>-<key>` syntax when
                        // the two spellings coincide — confirm this is intended.
                        $option::$prefixed(key, val) => {
                            write!(f, "{}-{key}", stringify!($prefixed))?;
                            if let Some(val) = val {
                                write!(f, "={val}")?;
                            }
                            Ok(())
                        }
                    )?
                }
            }
        }

        impl $opts {
            // Applies every parsed option, in order, to the struct. What
            // happens on repeats is decided by each field's container
            // (`OptionContainer::push`).
            fn configure_with(&mut self, opts: &[$crate::opt::CommaSeparated<$option>]) {
                for opt in opts.iter().flat_map(|o| o.0.iter()) {
                    match opt {
                        $(
                            $option::$opt(val) => {
                                $crate::opt::OptionContainer::push(&mut self.$opt, val.clone());
                            }
                        )+
                        $(
                            $option::$prefixed(key, val) => self.$prefixed.push((key.clone(), val.clone())),
                        )?
                    }
                }
            }

            // Converts the struct back into a flat list of enum values, one
            // per stored value, in field declaration order.
            fn to_options(&self) -> Vec<$option> {
                let mut ret = Vec::new();
                $(
                    for item in $crate::opt::OptionContainer::get(&self.$opt) {
                        ret.push($option::$opt(item.clone()));
                    }
                )+
                $(
                    for (key,val) in self.$prefixed.iter() {
                        ret.push($option::$prefixed(key.clone(), val.clone()));
                    }
                )?
                ret
            }
        }
    };
}
150
151#[derive(Clone, Debug, PartialEq)]
153pub struct CommaSeparated<T>(pub Vec<T>);
154
155impl<T> ValueParserFactory for CommaSeparated<T>
156where
157 T: WasmtimeOption,
158{
159 type Parser = CommaSeparatedParser<T>;
160
161 fn value_parser() -> CommaSeparatedParser<T> {
162 CommaSeparatedParser(marker::PhantomData)
163 }
164}
165
166#[derive(Clone)]
167pub struct CommaSeparatedParser<T>(marker::PhantomData<T>);
168
169impl<T> TypedValueParser for CommaSeparatedParser<T>
170where
171 T: WasmtimeOption,
172{
173 type Value = CommaSeparated<T>;
174
175 fn parse_ref(
176 &self,
177 cmd: &clap::Command,
178 arg: Option<&clap::Arg>,
179 value: &std::ffi::OsStr,
180 ) -> Result<Self::Value, Error> {
181 let val = StringValueParser::new().parse_ref(cmd, arg, value)?;
182
183 let options = T::OPTIONS;
184 let arg = arg.expect("should always have an argument");
185 let arg_long = arg.get_long().expect("should have a long name specified");
186 let arg_short = arg.get_short().expect("should have a short name specified");
187
188 if val == "help" {
191 let mut max = 0;
192 for d in options {
193 max = max.max(d.name.display_string().len() + d.val_help.len());
194 }
195 println!("Available {arg_long} options:\n");
196 for d in options {
197 print!(
198 " -{arg_short} {:>1$}",
199 d.name.display_string(),
200 max - d.val_help.len()
201 );
202 print!("{}", d.val_help);
203 print!(" --");
204 if val == "help" {
205 for line in d.docs.lines().map(|s| s.trim()) {
206 if line.is_empty() {
207 break;
208 }
209 print!(" {line}");
210 }
211 println!();
212 } else {
213 println!();
214 for line in d.docs.lines().map(|s| s.trim()) {
215 let line = line.trim();
216 println!(" {line}");
217 }
218 }
219 }
220 println!("\npass `-{arg_short} help-long` to see longer-form explanations");
221 std::process::exit(0);
222 }
223 if val == "help-long" {
224 println!("Available {arg_long} options:\n");
225 for d in options {
226 println!(
227 " -{arg_short} {}{} --",
228 d.name.display_string(),
229 d.val_help
230 );
231 println!();
232 for line in d.docs.lines().map(|s| s.trim()) {
233 let line = line.trim();
234 println!(" {line}");
235 }
236 }
237 std::process::exit(0);
238 }
239
240 let mut result = Vec::new();
241 for val in val.split(',') {
242 let mut iter = val.splitn(2, '=');
244 let key = iter.next().unwrap();
245 let key_val = iter.next();
246
247 let option = options
249 .iter()
250 .filter_map(|d| match d.name {
251 OptName::Name(s) => {
252 let s = s.replace('_', "-");
253 if s == key { Some((d, s)) } else { None }
254 }
255 OptName::Prefix(s) => {
256 let name = key.strip_prefix(s)?.strip_prefix("-")?;
257 Some((d, name.to_string()))
258 }
259 })
260 .next();
261
262 let (desc, key) = match option {
263 Some(pair) => pair,
264 None => {
265 let err = Error::raw(
266 ErrorKind::InvalidValue,
267 format!("unknown -{arg_short} / --{arg_long} option: {key}\n"),
268 );
269 return Err(err.with_cmd(cmd));
270 }
271 };
272
273 result.push((desc.parse)(&key, key_val).map_err(|e| {
274 Error::raw(
275 ErrorKind::InvalidValue,
276 format!("failed to parse -{arg_short} option `{val}`: {e:?}\n"),
277 )
278 .with_cmd(cmd)
279 })?)
280 }
281
282 Ok(CommaSeparated(result))
283 }
284}
285
/// Implemented by the option enums generated by `wasmtime_option_group!`;
/// exposes the table of options accepted by that option group.
pub trait WasmtimeOption: Sized + Send + Sync + Clone + 'static {
    /// Every option in the group. `help` output lists them in this order, and
    /// the command-line parser takes the first entry that matches a key.
    const OPTIONS: &'static [OptionDesc<Self>];
}
291
/// Description of a single option within an option group.
pub struct OptionDesc<T> {
    /// How the option is matched on the command line: an exact name, or any
    /// key of the form `<prefix>-<key>`.
    pub name: OptName,
    /// Help text. `help` shows the first paragraph (up to the first blank
    /// line); `help-long` shows all of it.
    pub docs: &'static str,
    /// Builds the option value from `(key, value)`.
    ///
    /// For named options the key is the dashed option name; for prefixed
    /// options it is the text after `<prefix>-`. The value is `None` for a
    /// bare `key` without `=`.
    pub parse: fn(&str, Option<&str>) -> Result<T>,
    /// Short description of the value syntax, e.g. `=N` or `[=y|n]`.
    pub val_help: &'static str,
}
298
/// How an option in an option group is matched on the command line.
pub enum OptName {
    /// Matches exactly this name; underscores in it are written as dashes on
    /// the command line.
    Name(&'static str),

    /// Matches any key of the form `<prefix>-<KEY>`; the `<KEY>` part is
    /// forwarded to the option's parser.
    Prefix(&'static str),
}

impl OptName {
    /// Text shown for this option in help output, e.g. `foo-bar` or
    /// `prefix-<KEY>`.
    fn display_string(&self) -> String {
        match *self {
            Self::Prefix(prefix) => format!("{prefix}-<KEY>"),
            Self::Name(name) => name.replace('_', "-"),
        }
    }
}
316
/// A type that can appear as the value in a `key=val` option-group entry.
///
/// The same trait is used for values deserialized through `cli_parse_wrapper`.
pub trait WasmtimeOptionValue: Sized {
    /// Short description of the accepted value syntax, shown in help output
    /// right after the option name (e.g. `=N`).
    const VAL_HELP: &'static str;

    /// Parses the value. `val` is `None` when the option was written as a
    /// bare `key` without `=`.
    fn parse(val: Option<&str>) -> Result<Self>;

    /// Writes the value as it appears after `key=`; used by the
    /// macro-generated `Display` impl for option enums.
    fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;
}
329
330impl WasmtimeOptionValue for String {
331 const VAL_HELP: &'static str = "=val";
332 fn parse(val: Option<&str>) -> Result<Self> {
333 match val {
334 Some(val) => Ok(val.to_string()),
335 None => bail!("value must be specified with `key=val` syntax"),
336 }
337 }
338
339 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340 f.write_str(self)
341 }
342}
343
344impl WasmtimeOptionValue for u32 {
345 const VAL_HELP: &'static str = "=N";
346 fn parse(val: Option<&str>) -> Result<Self> {
347 let val = String::parse(val)?.replace(IGNORED_NUMBER_CHARS, "");
348 match val.strip_prefix("0x") {
349 Some(hex) => Ok(u32::from_str_radix(hex, 16)?),
350 None => Ok(val.parse()?),
351 }
352 }
353
354 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 write!(f, "{self}")
356 }
357}
358
359impl WasmtimeOptionValue for u64 {
360 const VAL_HELP: &'static str = "=N";
361 fn parse(val: Option<&str>) -> Result<Self> {
362 let val = String::parse(val)?.replace(IGNORED_NUMBER_CHARS, "");
363 match val.strip_prefix("0x") {
364 Some(hex) => Ok(u64::from_str_radix(hex, 16)?),
365 None => Ok(val.parse()?),
366 }
367 }
368
369 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 write!(f, "{self}")
371 }
372}
373
374impl WasmtimeOptionValue for usize {
375 const VAL_HELP: &'static str = "=N";
376 fn parse(val: Option<&str>) -> Result<Self> {
377 let val = String::parse(val)?.replace(IGNORED_NUMBER_CHARS, "");
378 match val.strip_prefix("0x") {
379 Some(hex) => Ok(usize::from_str_radix(hex, 16)?),
380 None => Ok(val.parse()?),
381 }
382 }
383
384 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
385 write!(f, "{self}")
386 }
387}
388
389impl WasmtimeOptionValue for bool {
390 const VAL_HELP: &'static str = "[=y|n]";
391 fn parse(val: Option<&str>) -> Result<Self> {
392 match val {
393 None | Some("y") | Some("yes") | Some("true") => Ok(true),
394 Some("n") | Some("no") | Some("false") => Ok(false),
395 Some(s) => bail!("unknown boolean flag `{s}`, only yes,no,<nothing> accepted"),
396 }
397 }
398
399 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400 if *self {
401 f.write_str("y")
402 } else {
403 f.write_str("n")
404 }
405 }
406}
407
408impl WasmtimeOptionValue for Duration {
409 const VAL_HELP: &'static str = "=N|Ns|Nms|..";
410 fn parse(val: Option<&str>) -> Result<Duration> {
411 let s = String::parse(val)?;
412 if let Ok(val) = s.parse() {
414 return Ok(Duration::from_secs(val));
415 }
416
417 if let Some(num) = s.strip_suffix("s") {
418 if let Ok(val) = num.parse() {
419 return Ok(Duration::from_secs(val));
420 }
421 }
422 if let Some(num) = s.strip_suffix("ms") {
423 if let Ok(val) = num.parse() {
424 return Ok(Duration::from_millis(val));
425 }
426 }
427 if let Some(num) = s.strip_suffix("us").or(s.strip_suffix("μs")) {
428 if let Ok(val) = num.parse() {
429 return Ok(Duration::from_micros(val));
430 }
431 }
432 if let Some(num) = s.strip_suffix("ns") {
433 if let Ok(val) = num.parse() {
434 return Ok(Duration::from_nanos(val));
435 }
436 }
437
438 bail!("failed to parse duration: {s}")
439 }
440
441 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442 let subsec = self.subsec_nanos();
443 if subsec == 0 {
444 write!(f, "{}s", self.as_secs())
445 } else if subsec % 1_000 == 0 {
446 write!(f, "{}μs", self.as_micros())
447 } else if subsec % 1_000_000 == 0 {
448 write!(f, "{}ms", self.as_millis())
449 } else {
450 write!(f, "{}ns", self.as_nanos())
451 }
452 }
453}
454
455impl WasmtimeOptionValue for wasmtime::OptLevel {
456 const VAL_HELP: &'static str = "=0|1|2|s";
457 fn parse(val: Option<&str>) -> Result<Self> {
458 match String::parse(val)?.as_str() {
459 "0" => Ok(wasmtime::OptLevel::None),
460 "1" => Ok(wasmtime::OptLevel::Speed),
461 "2" => Ok(wasmtime::OptLevel::Speed),
462 "s" => Ok(wasmtime::OptLevel::SpeedAndSize),
463 other => bail!(
464 "unknown optimization level `{}`, only 0,1,2,s accepted",
465 other
466 ),
467 }
468 }
469
470 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
471 match *self {
472 wasmtime::OptLevel::None => f.write_str("0"),
473 wasmtime::OptLevel::Speed => f.write_str("2"),
474 wasmtime::OptLevel::SpeedAndSize => f.write_str("s"),
475 _ => unreachable!(),
476 }
477 }
478}
479
480impl WasmtimeOptionValue for wasmtime::RegallocAlgorithm {
481 const VAL_HELP: &'static str = "=backtracking|single-pass";
482 fn parse(val: Option<&str>) -> Result<Self> {
483 match String::parse(val)?.as_str() {
484 "backtracking" => Ok(wasmtime::RegallocAlgorithm::Backtracking),
485 other => bail!(
486 "unknown regalloc algorithm`{}`, only backtracking,single-pass accepted",
487 other
488 ),
489 }
490 }
491
492 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
493 match *self {
494 wasmtime::RegallocAlgorithm::Backtracking => f.write_str("backtracking"),
495 _ => unreachable!(),
496 }
497 }
498}
499
500impl WasmtimeOptionValue for wasmtime::Strategy {
501 const VAL_HELP: &'static str = "=winch|cranelift";
502 fn parse(val: Option<&str>) -> Result<Self> {
503 match String::parse(val)?.as_str() {
504 "cranelift" => Ok(wasmtime::Strategy::Cranelift),
505 "winch" => Ok(wasmtime::Strategy::Winch),
506 other => bail!("unknown compiler `{other}` only `cranelift` and `winch` accepted",),
507 }
508 }
509
510 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511 match *self {
512 wasmtime::Strategy::Cranelift => f.write_str("cranelift"),
513 wasmtime::Strategy::Winch => f.write_str("winch"),
514 _ => unreachable!(),
515 }
516 }
517}
518
519impl WasmtimeOptionValue for wasmtime::Collector {
520 const VAL_HELP: &'static str = "=drc|null";
521 fn parse(val: Option<&str>) -> Result<Self> {
522 match String::parse(val)?.as_str() {
523 "drc" => Ok(wasmtime::Collector::DeferredReferenceCounting),
524 "null" => Ok(wasmtime::Collector::Null),
525 other => bail!("unknown collector `{other}` only `drc` and `null` accepted",),
526 }
527 }
528
529 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 match *self {
531 wasmtime::Collector::DeferredReferenceCounting => f.write_str("drc"),
532 wasmtime::Collector::Null => f.write_str("null"),
533 _ => unreachable!(),
534 }
535 }
536}
537
538impl WasmtimeOptionValue for wasmtime::MpkEnabled {
539 const VAL_HELP: &'static str = "[=y|n|auto]";
540 fn parse(val: Option<&str>) -> Result<Self> {
541 match val {
542 None | Some("y") | Some("yes") | Some("true") => Ok(wasmtime::MpkEnabled::Enable),
543 Some("n") | Some("no") | Some("false") => Ok(wasmtime::MpkEnabled::Disable),
544 Some("auto") => Ok(wasmtime::MpkEnabled::Auto),
545 Some(s) => bail!("unknown mpk flag `{s}`, only yes,no,auto,<nothing> accepted"),
546 }
547 }
548
549 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
550 match *self {
551 wasmtime::MpkEnabled::Enable => f.write_str("y"),
552 wasmtime::MpkEnabled::Disable => f.write_str("n"),
553 wasmtime::MpkEnabled::Auto => f.write_str("auto"),
554 }
555 }
556}
557
558impl WasmtimeOptionValue for WasiNnGraph {
559 const VAL_HELP: &'static str = "=<format>::<dir>";
560 fn parse(val: Option<&str>) -> Result<Self> {
561 let val = String::parse(val)?;
562 let mut parts = val.splitn(2, "::");
563 Ok(WasiNnGraph {
564 format: parts.next().unwrap().to_string(),
565 dir: match parts.next() {
566 Some(part) => part.into(),
567 None => bail!("graph does not contain `::` separator for directory"),
568 },
569 })
570 }
571
572 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573 write!(f, "{}::{}", self.format, self.dir)
574 }
575}
576
577impl WasmtimeOptionValue for KeyValuePair {
578 const VAL_HELP: &'static str = "=<name>=<val>";
579 fn parse(val: Option<&str>) -> Result<Self> {
580 let val = String::parse(val)?;
581 let mut parts = val.splitn(2, "=");
582 Ok(KeyValuePair {
583 key: parts.next().unwrap().to_string(),
584 value: match parts.next() {
585 Some(part) => part.into(),
586 None => "".to_string(),
587 },
588 })
589 }
590
591 fn display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
592 f.write_str(&self.key)?;
593 if !self.value.is_empty() {
594 f.write_str("=")?;
595 f.write_str(&self.value)?;
596 }
597 Ok(())
598 }
599}
600
/// Storage for the values of one option inside an option-group struct.
///
/// `Option<T>` keeps only the most recently pushed value; `Vec<T>` keeps every
/// pushed value in push order.
pub trait OptionContainer<T> {
    /// Records `val` as a (possibly additional) value of the option.
    fn push(&mut self, val: T);
    /// Iterates over the currently stored values.
    fn get<'a>(&'a self) -> impl Iterator<Item = &'a T>
    where
        T: 'a;
}

impl<T> OptionContainer<T> for Option<T> {
    fn push(&mut self, val: T) {
        // Last one wins: any previously stored value is dropped.
        let _ = self.replace(val);
    }

    fn get<'a>(&'a self) -> impl Iterator<Item = &'a T>
    where
        T: 'a,
    {
        self.as_ref().into_iter()
    }
}

impl<T> OptionContainer<T> for Vec<T> {
    fn push(&mut self, val: T) {
        // Every value is kept, in push order.
        Vec::push(self, val);
    }

    fn get<'a>(&'a self) -> impl Iterator<Item = &'a T>
    where
        T: 'a,
    {
        self.as_slice().iter()
    }
}
631
632struct ToStringVisitor {}
637
638impl<'de> Visitor<'de> for ToStringVisitor {
639 type Value = String;
640
641 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
642 write!(formatter, "&str, u64, or i64")
643 }
644
645 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
646 where
647 E: de::Error,
648 {
649 Ok(s.to_owned())
650 }
651
652 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
653 where
654 E: de::Error,
655 {
656 Ok(v.to_string())
657 }
658
659 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
660 where
661 E: de::Error,
662 {
663 Ok(v.to_string())
664 }
665}
666
667pub(crate) fn cli_parse_wrapper<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
669where
670 T: WasmtimeOptionValue,
671 D: serde::Deserializer<'de>,
672{
673 let to_string_visitor = ToStringVisitor {};
674 let str = deserializer.deserialize_any(to_string_visitor)?;
675
676 T::parse(Some(&str))
677 .map(Some)
678 .map_err(serde::de::Error::custom)
679}
680
#[cfg(test)]
mod tests {
    use super::WasmtimeOptionValue;

    /// Plain digits and `_`-separated digits must parse to the same number.
    #[test]
    fn numbers_with_underscores() {
        for input in ["123", "1_2_3"] {
            let parsed = <u32 as WasmtimeOptionValue>::parse(Some(input));
            assert!(parsed.is_ok_and(|v| v == 123), "failed to parse {input:?}");
        }
    }
}