wasmtime_wasi_nn/
lib.rs

1pub mod backend;
2mod registry;
3pub mod wit;
4pub mod witx;
5
6use anyhow::anyhow;
7use core::fmt;
8pub use registry::{GraphRegistry, InMemoryRegistry};
9use std::path::Path;
10use std::sync::Arc;
11
12/// Construct an in-memory registry from the available backends and a list of
13/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
14/// from a local directory, which is a safe assumption currently for the current
15/// model types.
16pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
17    let mut backends = backend::list();
18    let mut registry = InMemoryRegistry::new();
19    for (kind, path) in preload_graphs {
20        let kind_ = kind.parse()?;
21        let backend = backends
22            .iter_mut()
23            .find(|b| b.encoding() == kind_)
24            .ok_or(anyhow!("unsupported backend: {}", kind))?
25            .as_dir_loadable()
26            .ok_or(anyhow!("{} does not support directory loading", kind))?;
27        registry.load(backend, Path::new(path))?;
28    }
29    Ok((backends, Registry::from(registry)))
30}
31
32/// A machine learning backend.
33pub struct Backend(Box<dyn backend::BackendInner>);
34impl std::ops::Deref for Backend {
35    type Target = dyn backend::BackendInner;
36    fn deref(&self) -> &Self::Target {
37        self.0.as_ref()
38    }
39}
40impl std::ops::DerefMut for Backend {
41    fn deref_mut(&mut self) -> &mut Self::Target {
42        self.0.as_mut()
43    }
44}
45impl<T: backend::BackendInner + 'static> From<T> for Backend {
46    fn from(value: T) -> Self {
47        Self(Box::new(value))
48    }
49}
50
51/// A backend-defined graph (i.e., ML model).
52#[derive(Clone)]
53pub struct Graph(Arc<dyn backend::BackendGraph>);
54impl From<Box<dyn backend::BackendGraph>> for Graph {
55    fn from(value: Box<dyn backend::BackendGraph>) -> Self {
56        Self(value.into())
57    }
58}
59impl std::ops::Deref for Graph {
60    type Target = dyn backend::BackendGraph;
61    fn deref(&self) -> &Self::Target {
62        self.0.as_ref()
63    }
64}
65
66/// A host-side tensor.
67///
68/// Eventually, this may be defined in each backend as they gain the ability to
69/// hold tensors on various devices (TODO:
70/// https://github.com/WebAssembly/wasi-nn/pull/70).
71#[derive(Clone, PartialEq)]
72pub struct Tensor {
73    dimensions: Vec<u32>,
74    ty: wit::TensorType,
75    data: Vec<u8>,
76}
77impl fmt::Debug for Tensor {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_struct("Tensor")
80            .field("dimensions", &self.dimensions)
81            .field("ty", &self.ty)
82            .field("data (bytes)", &self.data.len())
83            .finish()
84    }
85}
86
87/// A backend-defined execution context.
88pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
89impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
90    fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
91        Self(value)
92    }
93}
94impl std::ops::Deref for ExecutionContext {
95    type Target = dyn backend::BackendExecutionContext;
96    fn deref(&self) -> &Self::Target {
97        self.0.as_ref()
98    }
99}
100impl std::ops::DerefMut for ExecutionContext {
101    fn deref_mut(&mut self) -> &mut Self::Target {
102        self.0.as_mut()
103    }
104}
105
106/// A container for graphs.
107pub struct Registry(Box<dyn GraphRegistry>);
108impl std::ops::Deref for Registry {
109    type Target = dyn GraphRegistry;
110    fn deref(&self) -> &Self::Target {
111        self.0.as_ref()
112    }
113}
114impl std::ops::DerefMut for Registry {
115    fn deref_mut(&mut self) -> &mut Self::Target {
116        self.0.as_mut()
117    }
118}
119impl<T> From<T> for Registry
120where
121    T: GraphRegistry + 'static,
122{
123    fn from(value: T) -> Self {
124        Self(Box::new(value))
125    }
126}