wasmtime_wasi_nn/
wit.rs

1//! Implements the `wasi-nn` API for the WIT ("preview2") ABI.
2//!
3//! Note that `wasi-nn` is not yet included in an official "preview2" world
4//! (though it could be) so by "preview2" here we mean that this can be called
5//! with the component model's canonical ABI.
6//!
7//! This module exports its [`types`] for use throughout the crate and the
8//! [`ML`] object, which exposes [`ML::add_to_linker`]. To implement all of
9//! this, this module proceeds in steps:
10//! 1. generate all of the WIT glue code into a `generated::*` namespace
11//! 2. wire up the `generated::*` glue to the context state, delegating actual
12//!    computation to a [`Backend`]
13//! 3. convert some types
14//!
15//! [`Backend`]: crate::Backend
16//! [`types`]: crate::wit::types
17
18use crate::backend::Id;
19use crate::{Backend, Registry};
20use anyhow::anyhow;
21use std::collections::HashMap;
22use std::hash::Hash;
23use std::{fmt, str::FromStr};
24use wasmtime::component::{Resource, ResourceTable};
25
/// Capture the state necessary for calling into the backend ML libraries.
///
/// Holds the available backends, indexed by encoding, and the registry of
/// graphs that can be loaded by name.
pub struct WasiNnCtx {
    /// Available backends, keyed by the [`GraphEncoding`] each backend reports
    /// through `Backend::encoding`; `load` selects its backend from this map.
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    /// Graphs looked up by name in `load_by_name`; the guest receives a clone
    /// of the matching entry.
    pub(crate) registry: Registry,
}
31
32impl WasiNnCtx {
33    /// Make a new context from the default state.
34    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
35        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
36        Self { backends, registry }
37    }
38}
39
/// A wrapper capturing the needed internal wasi-nn state.
///
/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
/// this wrapper is not a `trait` but rather holds the references directly. This
/// removes one layer of abstraction for simplicity only, and could be added back
/// in the future if embedders need more control here.
pub struct WasiNnView<'a> {
    /// The wasi-nn context: the available backends and the named-graph registry.
    ctx: &'a mut WasiNnCtx,
    /// The table holding the graph, tensor, execution-context, and error
    /// resources handed out to the guest.
    table: &'a mut ResourceTable,
}
50
51impl<'a> WasiNnView<'a> {
52    /// Create a new view into the wasi-nn state.
53    pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
54        Self { ctx, table }
55    }
56}
57
/// A wasi-nn error; this appears on the Wasm side as a component model
/// resource.
///
/// The derived `Debug` output is what `bail!` logs when a failure occurs.
#[derive(Debug)]
pub struct Error {
    /// The category reported to the guest through `error.code`. Errors with
    /// [`ErrorCode::Trap`] are raised as host errors instead of being reported.
    code: ErrorCode,
    /// The underlying cause; its `Display` text is what the guest receives
    /// from `error.data`.
    data: anyhow::Error,
}
65
/// Construct an [`Error`] resource and immediately return it.
///
/// The WIT specification currently relies on "errors as resources"; this helper
/// macro hides some of that complexity. If [#75] is adopted ("errors as
/// records"), this macro is no longer necessary.
///
/// The expansion is a sequence of statements ending in `return`, so the macro
/// may only be used inside a host function returning
/// `wasmtime::Result<Result<_, Resource<Error>>>`. `$self` must have a `table`
/// field supporting `push`, and `$data` must convert into `anyhow::Error`.
///
/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        // Record the full error (code and cause) on the host side before it
        // is handed to the guest.
        tracing::error!("failure: {e:?}");
        // If the error cannot be stored in the table, that failure propagates
        // through `?` as a host error rather than a guest-visible one.
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}
84
85impl From<wasmtime::component::ResourceTableError> for Error {
86    fn from(error: wasmtime::component::ResourceTableError) -> Self {
87        Self {
88            code: ErrorCode::Trap,
89            data: error.into(),
90        }
91    }
92}
93
/// The list of error codes available to the `wasi-nn` API; this should match
/// what is specified in WIT.
///
/// This is a fieldless enum, so it derives the common value traits. Callers can
/// copy, compare, and hash codes directly (e.g. `err.code == ErrorCode::Trap`)
/// instead of relying on `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ErrorCode {
    /// Caller module passed an invalid argument.
    InvalidArgument,
    /// Invalid encoding.
    InvalidEncoding,
    /// The operation timed out.
    Timeout,
    /// Runtime error.
    RuntimeError,
    /// Unsupported operation.
    UnsupportedOperation,
    /// Graph is too large.
    TooLarge,
    /// Graph not found.
    NotFound,
    /// A runtime error that Wasmtime should trap on; this will not appear in
    /// the WIT specification.
    Trap,
}
116
/// Generate the traits and types from the `wasi-nn` WIT specification.
mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        // Every host import returns `wasmtime::Result<_>`, so any host
        // function can trap in addition to returning its WIT-level result.
        trappable_imports: true,
        with: {
            // Map all wasi-nn WIT resources to types defined in this crate so
            // they can be stored with the `ResourceTable` helper methods.
            "wasi:nn/graph/graph": crate::Graph,
            "wasi:nn/tensor/tensor": crate::Tensor,
            "wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors/error": super::Error,
        },
        // Presumably this mapping is what requires `errors::Host::convert_error`,
        // which decides whether an `Error` reaches the guest or becomes a trap.
        // NOTE(review): confirm against the bindgen documentation.
        trappable_error_type: {
            "wasi:nn/errors/error" => super::Error,
        },
    });
}
136use generated_::wasi::nn::{self as generated}; // Shortcut to the module containing the types we need.
137
138// Export the `types` used in this crate as well as `ML::add_to_linker`.
139pub mod types {
140    use super::generated;
141    pub use generated::errors::Error;
142    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
143    pub use generated::inference::GraphExecutionContext;
144    pub use generated::tensor::{Tensor, TensorType};
145}
146pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
147pub use generated::inference::GraphExecutionContext;
148pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
149pub use generated_::Ml as ML;
150
151/// Add the WIT-based version of the `wasi-nn` API to a
152/// [`wasmtime::component::Linker`].
153pub fn add_to_linker<T>(
154    l: &mut wasmtime::component::Linker<T>,
155    f: impl Fn(&mut T) -> WasiNnView<'_> + Send + Sync + Copy + 'static,
156) -> anyhow::Result<()> {
157    generated::graph::add_to_linker_get_host(l, f)?;
158    generated::tensor::add_to_linker_get_host(l, f)?;
159    generated::inference::add_to_linker_get_host(l, f)?;
160    generated::errors::add_to_linker_get_host(l, f)?;
161    Ok(())
162}
163
164impl generated::graph::Host for WasiNnView<'_> {
165    fn load(
166        &mut self,
167        builders: Vec<GraphBuilder>,
168        encoding: GraphEncoding,
169        target: ExecutionTarget,
170    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
171        tracing::debug!("load {encoding:?} {target:?}");
172        if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
173            let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
174            match backend.load(&slices, target.into()) {
175                Ok(graph) => {
176                    let graph = self.table.push(graph)?;
177                    Ok(Ok(graph))
178                }
179                Err(error) => {
180                    bail!(self, ErrorCode::RuntimeError, error);
181                }
182            }
183        } else {
184            bail!(
185                self,
186                ErrorCode::InvalidEncoding,
187                anyhow!("unable to find a backend for this encoding")
188            );
189        }
190    }
191
192    fn load_by_name(
193        &mut self,
194        name: String,
195    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
196        use core::result::Result::*;
197        tracing::debug!("load by name {name:?}");
198        let registry = &self.ctx.registry;
199        if let Some(graph) = registry.get(&name) {
200            let graph = graph.clone();
201            let graph = self.table.push(graph)?;
202            Ok(Ok(graph))
203        } else {
204            bail!(
205                self,
206                ErrorCode::NotFound,
207                anyhow!("failed to find graph with name: {name}")
208            );
209        }
210    }
211}
212
213impl generated::graph::HostGraph for WasiNnView<'_> {
214    fn init_execution_context(
215        &mut self,
216        graph: Resource<Graph>,
217    ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
218        use core::result::Result::*;
219        tracing::debug!("initialize execution context");
220        let graph = self.table.get(&graph)?;
221        match graph.init_execution_context() {
222            Ok(exec_context) => {
223                let exec_context = self.table.push(exec_context)?;
224                Ok(Ok(exec_context))
225            }
226            Err(error) => {
227                bail!(self, ErrorCode::RuntimeError, error);
228            }
229        }
230    }
231
232    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
233        self.table.delete(graph)?;
234        Ok(())
235    }
236}
237
238impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
239    fn set_input(
240        &mut self,
241        exec_context: Resource<GraphExecutionContext>,
242        name: String,
243        tensor: Resource<Tensor>,
244    ) -> wasmtime::Result<Result<(), Resource<Error>>> {
245        let tensor = self.table.get(&tensor)?;
246        tracing::debug!("set input {name:?}: {tensor:?}");
247        let tensor = tensor.clone(); // TODO: avoid copying the tensor
248        let exec_context = self.table.get_mut(&exec_context)?;
249        if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
250            bail!(self, ErrorCode::InvalidArgument, error);
251        } else {
252            Ok(Ok(()))
253        }
254    }
255
256    fn compute(
257        &mut self,
258        exec_context: Resource<GraphExecutionContext>,
259    ) -> wasmtime::Result<Result<(), Resource<Error>>> {
260        let exec_context = &mut self.table.get_mut(&exec_context)?;
261        tracing::debug!("compute");
262        match exec_context.compute() {
263            Ok(()) => Ok(Ok(())),
264            Err(error) => {
265                bail!(self, ErrorCode::RuntimeError, error);
266            }
267        }
268    }
269
270    fn get_output(
271        &mut self,
272        exec_context: Resource<GraphExecutionContext>,
273        name: String,
274    ) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
275        let exec_context = self.table.get_mut(&exec_context)?;
276        tracing::debug!("get output {name:?}");
277        match exec_context.get_output(Id::Name(name)) {
278            Ok(tensor) => {
279                let tensor = self.table.push(tensor)?;
280                Ok(Ok(tensor))
281            }
282            Err(error) => {
283                bail!(self, ErrorCode::RuntimeError, error);
284            }
285        }
286    }
287
288    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
289        self.table.delete(exec_context)?;
290        Ok(())
291    }
292}
293
294impl generated::tensor::HostTensor for WasiNnView<'_> {
295    fn new(
296        &mut self,
297        dimensions: TensorDimensions,
298        ty: TensorType,
299        data: TensorData,
300    ) -> wasmtime::Result<Resource<Tensor>> {
301        let tensor = Tensor {
302            dimensions,
303            ty,
304            data,
305        };
306        let tensor = self.table.push(tensor)?;
307        Ok(tensor)
308    }
309
310    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
311        let tensor = self.table.get(&tensor)?;
312        Ok(tensor.dimensions.clone())
313    }
314
315    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
316        let tensor = self.table.get(&tensor)?;
317        Ok(tensor.ty)
318    }
319
320    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
321        let tensor = self.table.get(&tensor)?;
322        Ok(tensor.data.clone())
323    }
324
325    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
326        self.table.delete(tensor)?;
327        Ok(())
328    }
329}
330
331impl generated::errors::HostError for WasiNnView<'_> {
332    fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
333        let error = self.table.get(&error)?;
334        match error.code {
335            ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
336            ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
337            ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
338            ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
339            ErrorCode::UnsupportedOperation => {
340                Ok(generated::errors::ErrorCode::UnsupportedOperation)
341            }
342            ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
343            ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
344            ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
345        }
346    }
347
348    fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
349        let error = self.table.get(&error)?;
350        Ok(error.data.to_string())
351    }
352
353    fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
354        self.table.delete(error)?;
355        Ok(())
356    }
357}
358
359impl generated::errors::Host for WasiNnView<'_> {
360    fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
361        if matches!(err.code, ErrorCode::Trap) {
362            Err(err.data)
363        } else {
364            Ok(err)
365        }
366    }
367}
368
// The `tensor` and `inference` interfaces need nothing beyond their resource
// methods (implemented in the `HostTensor` and `HostGraphExecutionContext`
// impls), so these `Host` impls are empty.
impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}
371
372impl Hash for generated::graph::GraphEncoding {
373    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
374        self.to_string().hash(state)
375    }
376}
377
378impl fmt::Display for generated::graph::GraphEncoding {
379    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380        use generated::graph::GraphEncoding::*;
381        match self {
382            Openvino => write!(f, "openvino"),
383            Onnx => write!(f, "onnx"),
384            Pytorch => write!(f, "pytorch"),
385            Tensorflow => write!(f, "tensorflow"),
386            Tensorflowlite => write!(f, "tensorflowlite"),
387            Autodetect => write!(f, "autodetect"),
388            Ggml => write!(f, "ggml"),
389        }
390    }
391}
392
393impl FromStr for generated::graph::GraphEncoding {
394    type Err = GraphEncodingParseError;
395    fn from_str(s: &str) -> Result<Self, Self::Err> {
396        match s.to_lowercase().as_str() {
397            "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
398            "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
399            "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
400            "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
401            "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
402            "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
403            _ => Err(GraphEncodingParseError(s.into())),
404        }
405    }
406}
/// The error returned when a string does not name a known graph encoding.
///
/// It holds the rejected input verbatim.
#[derive(Debug)]
pub struct GraphEncodingParseError(String);

impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(input) = self;
        write!(f, "unknown graph encoding: {}", input)
    }
}

impl std::error::Error for GraphEncodingParseError {}