wasmtime_wasi_nn/wit.rs

//! Implements the `wasi-nn` API for the WIT ("preview2") ABI.
//!
//! Note that `wasi-nn` is not yet included in an official "preview2" world
//! (though it could be) so by "preview2" here we mean that this can be called
//! with the component model's canonical ABI.
//!
//! This module exports its [`types`] for use throughout the crate and the
//! [`ML`] object, which exposes [`ML::add_to_linker`]. To implement all of
//! this, this module proceeds in steps:
//! 1. generate all of the WIT glue code into a `generated::*` namespace
//! 2. wire up the `generated::*` glue to the context state, delegating actual
//!    computation to a [`Backend`]
//! 3. convert some types
//!
//! [`Backend`]: crate::Backend
//! [`types`]: crate::wit::types
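//!
//! A minimal sketch of the embedder side, assuming a hypothetical host `Ctx`
//! type (not part of this crate) that owns the wasi-nn state alongside a
//! [`ResourceTable`](wasmtime::component::ResourceTable):
//!
//! ```ignore
//! use wasmtime::component::ResourceTable;
//! use wasmtime_wasi_nn::wit::{WasiNnCtx, WasiNnView};
//!
//! struct Ctx {
//!     wasi_nn: WasiNnCtx,
//!     table: ResourceTable,
//! }
//!
//! impl Ctx {
//!     /// Borrow the wasi-nn state for the duration of a host call.
//!     fn wasi_nn_view(&mut self) -> WasiNnView<'_> {
//!         WasiNnView::new(&mut self.table, &mut self.wasi_nn)
//!     }
//! }
//! ```
//!
//! See [`add_to_linker`] for wiring this into a
//! [`Linker`](wasmtime::component::Linker).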

use crate::{Backend, Registry};
use anyhow::anyhow;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{HasData, Resource, ResourceTable};

/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    pub(crate) registry: Registry,
}

impl WasiNnCtx {
    /// Create a new context from the given backends and registry.
    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
        Self { backends, registry }
    }
}

/// A wrapper capturing the needed internal wasi-nn state.
///
/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
/// this wrapper is not a `trait` but rather holds the references directly. This
/// removes one layer of abstraction for simplicity only; it could be added back
/// in the future if embedders need more control here.
pub struct WasiNnView<'a> {
    ctx: &'a mut WasiNnCtx,
    table: &'a mut ResourceTable,
}

impl<'a> WasiNnView<'a> {
    /// Create a new view into the wasi-nn state.
    pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
        Self { ctx, table }
    }
}

/// A wasi-nn error; this appears on the Wasm side as a component model
/// resource.
#[derive(Debug)]
pub struct Error {
    code: ErrorCode,
    data: anyhow::Error,
}

/// Construct an [`Error`] resource and immediately return it.
///
/// The WIT specification currently relies on "errors as resources"; this
/// helper macro hides some of that complexity. If [#75] is adopted ("errors as
/// records"), this macro will no longer be necessary.
///
/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        tracing::error!("failure: {e:?}");
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}

impl From<wasmtime::component::ResourceTableError> for Error {
    fn from(error: wasmtime::component::ResourceTableError) -> Self {
        Self {
            code: ErrorCode::Trap,
            data: error.into(),
        }
    }
}

/// The list of error codes available to the `wasi-nn` API; this should match
/// what is specified in WIT.
#[derive(Debug)]
pub enum ErrorCode {
    /// Caller module passed an invalid argument.
    InvalidArgument,
    /// Invalid encoding.
    InvalidEncoding,
    /// The operation timed out.
    Timeout,
    /// Runtime error.
    RuntimeError,
    /// Unsupported operation.
    UnsupportedOperation,
    /// Graph is too large.
    TooLarge,
    /// Graph not found.
    NotFound,
    /// A runtime error that Wasmtime should trap on; this will not appear in
    /// the WIT specification.
    Trap,
}

/// Generate the traits and types from the `wasi-nn` WIT specification.
pub(crate) mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        trappable_imports: true,
        with: {
            // Configure all WIT resources to be types defined in this crate so
            // that they can use the `ResourceTable` helper methods.
            "wasi:nn/graph/graph": crate::Graph,
            "wasi:nn/tensor/tensor": crate::Tensor,
            "wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors/error": super::Error,
        },
        trappable_error_type: {
            "wasi:nn/errors/error" => super::Error,
        },
    });
}
use generated_::wasi::nn::{self as generated}; // Shortcut to the module containing the types we need.

// Export the `types` used in this crate as well as `ML::add_to_linker`.
pub mod types {
    use super::generated;
    pub use generated::errors::Error;
    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
    pub use generated::inference::GraphExecutionContext;
    pub use generated::tensor::{Tensor, TensorType};
}
pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use generated::inference::{GraphExecutionContext, NamedTensor};
pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use generated_::Ml as ML;

/// Add the WIT-based version of the `wasi-nn` API to a
/// [`wasmtime::component::Linker`].
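///
/// A sketch of typical usage, assuming a hypothetical host `Ctx` type like the
/// one in the module-level docs (owning a `WasiNnCtx` and a `ResourceTable`)
/// and an already-built [`Engine`](wasmtime::Engine):
///
/// ```ignore
/// let mut linker = wasmtime::component::Linker::<Ctx>::new(&engine);
/// wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |ctx: &mut Ctx| {
///     WasiNnView::new(&mut ctx.table, &mut ctx.wasi_nn)
/// })?;
/// ```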
pub fn add_to_linker<T: 'static>(
    l: &mut wasmtime::component::Linker<T>,
    f: fn(&mut T) -> WasiNnView<'_>,
) -> anyhow::Result<()> {
    generated::graph::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::tensor::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::inference::add_to_linker::<_, HasWasiNnView>(l, f)?;
    generated::errors::add_to_linker::<_, HasWasiNnView>(l, f)?;
    Ok(())
}

struct HasWasiNnView;

impl HasData for HasWasiNnView {
    type Data<'a> = WasiNnView<'a>;
}

impl generated::graph::Host for WasiNnView<'_> {
    fn load(
        &mut self,
        builders: Vec<GraphBuilder>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        tracing::debug!("load {encoding:?} {target:?}");
        if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
            let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
            match backend.load(&slices, target) {
                Ok(graph) => {
                    let graph = self.table.push(graph)?;
                    Ok(Ok(graph))
                }
                Err(error) => {
                    bail!(self, ErrorCode::RuntimeError, error);
                }
            }
        } else {
            bail!(
                self,
                ErrorCode::InvalidEncoding,
                anyhow!("unable to find a backend for this encoding")
            );
        }
    }

    fn load_by_name(
        &mut self,
        name: String,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("load by name {name:?}");
        let registry = &self.ctx.registry;
        if let Some(graph) = registry.get(&name) {
            let graph = graph.clone();
            let graph = self.table.push(graph)?;
            Ok(Ok(graph))
        } else {
            bail!(
                self,
                ErrorCode::NotFound,
                anyhow!("failed to find graph with name: {name}")
            );
        }
    }
}

impl generated::graph::HostGraph for WasiNnView<'_> {
    fn init_execution_context(
        &mut self,
        graph: Resource<Graph>,
    ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("initialize execution context");
        let graph = self.table.get(&graph)?;
        match graph.init_execution_context() {
            Ok(exec_context) => {
                let exec_context = self.table.push(exec_context)?;
                Ok(Ok(exec_context))
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
        self.table.delete(graph)?;
        Ok(())
    }
}

impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
    fn compute(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
        inputs: Vec<NamedTensor>,
    ) -> wasmtime::Result<Result<Vec<NamedTensor>, Resource<Error>>> {
        tracing::debug!("compute with {} inputs", inputs.len());

        // Resolve each guest-provided tensor resource into the form the
        // backend expects.
        let mut named_tensors = Vec::new();
        for (name, tensor_resource) in inputs.iter() {
            let tensor = self.table.get(tensor_resource)?;
            named_tensors.push(crate::backend::NamedTensor {
                name: name.clone(),
                tensor: tensor.clone(),
            });
        }

        let exec_context = self.table.get_mut(&exec_context)?;

        match exec_context.compute_with_io(named_tensors) {
            Ok(named_tensors) => {
                // Move each output tensor into the resource table, handing the
                // guest back `(name, resource)` pairs.
                let result = named_tensors
                    .into_iter()
                    .map(|crate::backend::NamedTensor { name, tensor }| {
                        self.table.push(tensor).map(|resource| (name, resource))
                    })
                    .collect();

                match result {
                    Ok(tuples) => Ok(Ok(tuples)),
                    Err(error) => {
                        bail!(self, ErrorCode::RuntimeError, error);
                    }
                }
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
        self.table.delete(exec_context)?;
        Ok(())
    }
}

impl generated::tensor::HostTensor for WasiNnView<'_> {
    fn new(
        &mut self,
        dimensions: TensorDimensions,
        ty: TensorType,
        data: TensorData,
    ) -> wasmtime::Result<Resource<Tensor>> {
        let tensor = Tensor {
            dimensions,
            ty,
            data,
        };
        let tensor = self.table.push(tensor)?;
        Ok(tensor)
    }

    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.dimensions.clone())
    }

    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.ty)
    }

    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.data.clone())
    }

    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
        self.table.delete(tensor)?;
        Ok(())
    }
}

impl generated::errors::HostError for WasiNnView<'_> {
    fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
        let error = self.table.get(&error)?;
        match error.code {
            ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
            ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
            ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
            ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
            ErrorCode::UnsupportedOperation => {
                Ok(generated::errors::ErrorCode::UnsupportedOperation)
            }
            ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
            ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
            ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
        }
    }

    fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
        let error = self.table.get(&error)?;
        Ok(error.data.to_string())
    }

    fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
        self.table.delete(error)?;
        Ok(())
    }
}

impl generated::errors::Host for WasiNnView<'_> {
    fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
        if matches!(err.code, ErrorCode::Trap) {
            Err(err.data)
        } else {
            Ok(err)
        }
    }
}

impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}

impl Hash for generated::graph::GraphEncoding {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.to_string().hash(state)
    }
}

impl fmt::Display for generated::graph::GraphEncoding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use generated::graph::GraphEncoding::*;
        match self {
            Openvino => write!(f, "openvino"),
            Onnx => write!(f, "onnx"),
            Pytorch => write!(f, "pytorch"),
            Tensorflow => write!(f, "tensorflow"),
            Tensorflowlite => write!(f, "tensorflowlite"),
            Autodetect => write!(f, "autodetect"),
            Ggml => write!(f, "ggml"),
        }
    }
}

impl FromStr for generated::graph::GraphEncoding {
    type Err = GraphEncodingParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
            "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
            "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
            "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
            "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
            "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
            _ => Err(GraphEncodingParseError(s.into())),
        }
    }
}

#[derive(Debug)]
pub struct GraphEncodingParseError(String);

impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unknown graph encoding: {}", self.0)
    }
}

impl std::error::Error for GraphEncodingParseError {}
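
// A small sanity check of the `Display`/`FromStr` conversions above; the
// string literals simply mirror the match arms in those impls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn graph_encoding_string_conversions() {
        // Parsing is case-insensitive.
        assert!(matches!(
            "OpenVINO".parse::<GraphEncoding>(),
            Ok(GraphEncoding::Openvino)
        ));
        // `Display` produces the lowercase names accepted by `FromStr`.
        assert_eq!(GraphEncoding::Onnx.to_string(), "onnx");
        // Unknown encodings are rejected with a descriptive error.
        assert!("not-a-backend".parse::<GraphEncoding>().is_err());
    }
}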