wasmtime_wasi_nn/
lib.rs

1pub mod backend;
2mod registry;
3pub mod wit;
4pub mod witx;
5
6use crate::backend::{BackendError, Id, NamedTensor as BackendNamedTensor};
7use crate::wit::generated_::wasi::nn::tensor::TensorType;
8use anyhow::anyhow;
9use core::fmt;
10pub use registry::{GraphRegistry, InMemoryRegistry};
11use std::path::Path;
12use std::sync::Arc;
13
14/// Construct an in-memory registry from the available backends and a list of
15/// `(<backend name>, <graph directory>)`. This assumes graphs can be loaded
16/// from a local directory, which is a safe assumption currently for the current
17/// model types.
18pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
19    let mut backends = backend::list();
20    let mut registry = InMemoryRegistry::new();
21    for (kind, path) in preload_graphs {
22        let kind_ = kind.parse()?;
23        let backend = backends
24            .iter_mut()
25            .find(|b| b.encoding() == kind_)
26            .ok_or(anyhow!("unsupported backend: {}", kind))?
27            .as_dir_loadable()
28            .ok_or(anyhow!("{} does not support directory loading", kind))?;
29        registry.load(backend, Path::new(path))?;
30    }
31    Ok((backends, Registry::from(registry)))
32}
33
34/// A machine learning backend.
35pub struct Backend(Box<dyn backend::BackendInner>);
36impl std::ops::Deref for Backend {
37    type Target = dyn backend::BackendInner;
38    fn deref(&self) -> &Self::Target {
39        self.0.as_ref()
40    }
41}
42impl std::ops::DerefMut for Backend {
43    fn deref_mut(&mut self) -> &mut Self::Target {
44        self.0.as_mut()
45    }
46}
47impl<T: backend::BackendInner + 'static> From<T> for Backend {
48    fn from(value: T) -> Self {
49        Self(Box::new(value))
50    }
51}
52
53/// A backend-defined graph (i.e., ML model).
54#[derive(Clone)]
55pub struct Graph(Arc<dyn backend::BackendGraph>);
56impl From<Box<dyn backend::BackendGraph>> for Graph {
57    fn from(value: Box<dyn backend::BackendGraph>) -> Self {
58        Self(value.into())
59    }
60}
61impl std::ops::Deref for Graph {
62    type Target = dyn backend::BackendGraph;
63    fn deref(&self) -> &Self::Target {
64        self.0.as_ref()
65    }
66}
67
68/// A host-side tensor.
69///
70/// Eventually, this may be defined in each backend as they gain the ability to
71/// hold tensors on various devices (TODO:
72/// https://github.com/WebAssembly/wasi-nn/pull/70).
73#[derive(Clone, PartialEq)]
74pub struct Tensor {
75    pub dimensions: Vec<u32>,
76    pub ty: TensorType,
77    pub data: Vec<u8>,
78}
79impl fmt::Debug for Tensor {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.debug_struct("Tensor")
82            .field("dimensions", &self.dimensions)
83            .field("ty", &self.ty)
84            .field("data (bytes)", &self.data.len())
85            .finish()
86    }
87}
88
89/// A backend-defined execution context.
90pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
91impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
92    fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
93        Self(value)
94    }
95}
96impl std::ops::Deref for ExecutionContext {
97    type Target = dyn backend::BackendExecutionContext;
98    fn deref(&self) -> &Self::Target {
99        self.0.as_ref()
100    }
101}
102impl std::ops::DerefMut for ExecutionContext {
103    fn deref_mut(&mut self) -> &mut Self::Target {
104        self.0.as_mut()
105    }
106}
107
108/// A container for graphs.
109pub struct Registry(Box<dyn GraphRegistry>);
110impl std::ops::Deref for Registry {
111    type Target = dyn GraphRegistry;
112    fn deref(&self) -> &Self::Target {
113        self.0.as_ref()
114    }
115}
116impl std::ops::DerefMut for Registry {
117    fn deref_mut(&mut self) -> &mut Self::Target {
118        self.0.as_mut()
119    }
120}
121impl<T> From<T> for Registry
122where
123    T: GraphRegistry + 'static,
124{
125    fn from(value: T) -> Self {
126        Self(Box::new(value))
127    }
128}
129
130impl ExecutionContext {
131    pub fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
132        self.0.set_input(id, tensor)
133    }
134
135    pub fn compute(&mut self) -> Result<(), BackendError> {
136        self.0.compute(None).map(|_| ())
137    }
138
139    pub fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
140        self.0.get_output(id)
141    }
142
143    pub fn compute_with_io(
144        &mut self,
145        inputs: Vec<BackendNamedTensor>,
146    ) -> Result<Vec<BackendNamedTensor>, BackendError> {
147        match self.0.compute(Some(inputs))? {
148            Some(outputs) => Ok(outputs),
149            None => Ok(Vec::new()),
150        }
151    }
152}
153
154impl Tensor {
155    pub fn new(dimensions: Vec<u32>, ty: TensorType, data: Vec<u8>) -> Self {
156        Self {
157            dimensions,
158            ty,
159            data,
160        }
161    }
162}