// wasmtime_wasi_nn/witx.rs

1//! Implements the `wasi-nn` API for the WITX ("preview1") ABI.
2//!
3//! `wasi-nn` was never included in the official "preview1" snapshot, but this
4//! module implements the ABI that is compatible with "preview1".
5//!
6//! The only export from this module is [`add_to_linker`]. To implement it, this
7//! module proceeds in steps:
8//! 1. generate all of the Wiggle glue code into a `generated::*` namespace
9//! 2. wire up the `generated::*` glue to the context state, delegating actual
10//!    computation to a `Backend`
11//! 3. wrap up with some conversions, i.e., from `generated::*` types to this crate's
12//!    [`types`].
13//!
14//! [`types`]: crate::wit::types
15
16use crate::backend::BackendError;
17use crate::backend::Id;
18use crate::wit::GraphEncoding;
19use crate::{Backend, ExecutionContext, Graph, Registry};
20use std::collections::HashMap;
21use std::hash::Hash;
22use thiserror::Error;
23use wiggle::{GuestError, GuestMemory, GuestPtr};
24
25pub use generated::wasi_ephemeral_nn::add_to_linker;
26
/// Result alias used by the host-side `wasi-nn` functions in this crate.
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
// Local shorthand so the trait impl below can write plain `Result<T>`.
type Result<T> = WasiNnResult<T>;
// Guest-visible handle for a loaded graph (key into `WasiNnCtx::graphs`).
type GraphId = u32;
// Guest-visible handle for an execution context (key into `WasiNnCtx::executions`).
type GraphExecutionContextId = u32;
31
/// Capture the state necessary for calling into the backend ML libraries.
pub struct WasiNnCtx {
    /// The available backends, keyed by the graph encoding each one handles.
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    /// Graphs addressable by name (used by `load_by_name`).
    pub(crate) registry: Registry,
    /// Graphs loaded at runtime, keyed by their guest-visible handle.
    pub(crate) graphs: Table<GraphId, Graph>,
    /// Live execution contexts, keyed by their guest-visible handle.
    pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
39
40impl WasiNnCtx {
41    /// Make a new context from the default state.
42    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
43        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
44        Self {
45            backends,
46            registry,
47            graphs: Table::default(),
48            executions: Table::default(),
49        }
50    }
51}
52
/// Record handle entries in a table.
///
/// Keys are handed out densely in insertion order, starting from zero; the
/// table itself never removes entries.
pub struct Table<K, V> {
    entries: HashMap<K, V>,
    next_key: u32,
}

impl<K, V> Default for Table<K, V> {
    fn default() -> Self {
        Table {
            entries: HashMap::default(),
            next_key: 0,
        }
    }
}

impl<K, V> Table<K, V>
where
    K: Eq + Hash + From<u32> + Copy,
{
    /// Store `value` under a freshly allocated key and return that key.
    pub fn insert(&mut self, value: V) -> K {
        let handle = self.use_next_key();
        self.entries.insert(handle, value);
        handle
    }

    /// Look up the entry for `key`, if one was ever inserted.
    pub fn get(&self, key: K) -> Option<&V> {
        self.entries.get(&key)
    }

    /// Look up the entry for `key` mutably, if one was ever inserted.
    pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
        self.entries.get_mut(&key)
    }

    /// Allocate the next sequential key.
    fn use_next_key(&mut self) -> K {
        let allocated = self.next_key;
        self.next_key += 1;
        allocated.into()
    }
}
92
/// Generate the traits and types from the `wasi-nn` WITX specification.
mod generated {
    use super::*;
    // Wiggle emits the `wasi_ephemeral_nn` host trait and the `types` module
    // used throughout this file; `nn_errno` results are represented on the
    // host side as `WasiNnError`.
    wiggle::from_witx!({
        witx: ["$WASI_ROOT/wasi-nn.witx"],
        errors: { nn_errno => WasiNnError }
    });

    /// Additionally, we must let Wiggle know which of our error codes
    /// represents a successful operation.
    impl wiggle::GuestErrorType for types::NnErrno {
        fn success() -> Self {
            Self::Success
        }
    }

    /// Convert the host errors to their WITX-generated type.
    impl types::UserErrorConversion for WasiNnCtx {
        fn nn_errno_from_wasi_nn_error(
            &mut self,
            e: WasiNnError,
        ) -> anyhow::Result<types::NnErrno> {
            tracing::debug!("host error: {:?}", e);
            match e {
                WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError),
                // NOTE(review): guest-memory errors currently panic instead of
                // being mapped to an errno — confirm this is intentional.
                WasiNnError::GuestError(_) => unimplemented!("guest error conversion"),
                WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation),
                WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge),
            }
        }
    }
}
125
126/// Wire up the WITX-generated trait to the `wasi-nn` host state.
127impl generated::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
128    fn load(
129        &mut self,
130        memory: &mut GuestMemory<'_>,
131        builders: generated::types::GraphBuilderArray,
132        encoding: generated::types::GraphEncoding,
133        target: generated::types::ExecutionTarget,
134    ) -> Result<generated::types::Graph> {
135        let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) {
136            // Retrieve all of the "builder lists" from the Wasm memory (see
137            // $graph_builder_array) as slices for a backend to operate on.
138            let mut slices = vec![];
139            for builder in builders.iter() {
140                let builder = memory.read(builder?)?;
141                let slice = memory.as_slice(builder)?.expect(
142                    "cannot use with shared memories; \
143                     see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
144                );
145                slices.push(slice);
146            }
147            let slice_refs = slices.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
148            backend.load(&slice_refs, target.into())?
149        } else {
150            return Err(UsageError::InvalidEncoding(encoding.into()).into());
151        };
152        let graph_id = self.graphs.insert(graph);
153        Ok(graph_id.into())
154    }
155
156    fn load_by_name(
157        &mut self,
158        memory: &mut GuestMemory<'_>,
159        name: wiggle::GuestPtr<str>,
160    ) -> Result<generated::types::Graph> {
161        let name = memory.as_str(name)?.unwrap();
162        if let Some(graph) = self.registry.get_mut(&name) {
163            let graph_id = self.graphs.insert(graph.clone().into());
164            Ok(graph_id.into())
165        } else {
166            return Err(UsageError::NotFound(name.to_string()).into());
167        }
168    }
169
170    fn init_execution_context(
171        &mut self,
172        _memory: &mut GuestMemory<'_>,
173        graph_id: generated::types::Graph,
174    ) -> Result<generated::types::GraphExecutionContext> {
175        let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id.into()) {
176            graph.init_execution_context()?
177        } else {
178            return Err(UsageError::InvalidGraphHandle.into());
179        };
180
181        let exec_context_id = self.executions.insert(exec_context);
182        Ok(exec_context_id.into())
183    }
184
185    fn set_input(
186        &mut self,
187        memory: &mut GuestMemory<'_>,
188        exec_context_id: generated::types::GraphExecutionContext,
189        index: u32,
190        tensor: &generated::types::Tensor,
191    ) -> Result<()> {
192        if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
193            let tensor = crate::wit::types::Tensor {
194                dimensions: memory.to_vec(tensor.dimensions)?,
195                ty: tensor.type_.into(),
196                data: memory.to_vec(tensor.data)?,
197            };
198            Ok(exec_context.set_input(Id::Index(index), &tensor)?)
199        } else {
200            Err(UsageError::InvalidGraphHandle.into())
201        }
202    }
203
204    fn compute(
205        &mut self,
206        _memory: &mut GuestMemory<'_>,
207        exec_context_id: generated::types::GraphExecutionContext,
208    ) -> Result<()> {
209        if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
210            Ok(exec_context.compute()?)
211        } else {
212            Err(UsageError::InvalidExecutionContextHandle.into())
213        }
214    }
215
216    fn get_output(
217        &mut self,
218        memory: &mut GuestMemory<'_>,
219        exec_context_id: generated::types::GraphExecutionContext,
220        index: u32,
221        out_buffer: GuestPtr<u8>,
222        out_buffer_max_size: u32,
223    ) -> Result<u32> {
224        if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
225            let tensor = exec_context.get_output(Id::Index(index))?;
226            let destination = memory
227                .as_slice_mut(out_buffer.as_array(out_buffer_max_size))?
228                .expect(
229                    "cannot use with shared memories; \
230                     see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
231                );
232            if tensor.data.len() > destination.len() {
233                Err(WasiNnError::NotEnoughMemory(tensor.data.len()))
234            } else {
235                destination[..tensor.data.len()].copy_from_slice(&tensor.data);
236                Ok(tensor.data.len() as u32)
237            }
238        } else {
239            Err(UsageError::InvalidGraphHandle.into())
240        }
241    }
242}
243
244// Implement some conversion from `witx::types::*` to this crate's version.
245
246impl From<generated::types::ExecutionTarget> for crate::wit::types::ExecutionTarget {
247    fn from(value: generated::types::ExecutionTarget) -> Self {
248        match value {
249            generated::types::ExecutionTarget::Cpu => crate::wit::types::ExecutionTarget::Cpu,
250            generated::types::ExecutionTarget::Gpu => crate::wit::types::ExecutionTarget::Gpu,
251            generated::types::ExecutionTarget::Tpu => crate::wit::types::ExecutionTarget::Tpu,
252        }
253    }
254}
255impl From<generated::types::GraphEncoding> for crate::wit::types::GraphEncoding {
256    fn from(value: generated::types::GraphEncoding) -> Self {
257        match value {
258            generated::types::GraphEncoding::Openvino => crate::wit::types::GraphEncoding::Openvino,
259            generated::types::GraphEncoding::Onnx => crate::wit::types::GraphEncoding::Onnx,
260            generated::types::GraphEncoding::Tensorflow => {
261                crate::wit::types::GraphEncoding::Tensorflow
262            }
263            generated::types::GraphEncoding::Pytorch => crate::wit::types::GraphEncoding::Pytorch,
264            generated::types::GraphEncoding::Tensorflowlite => {
265                crate::wit::types::GraphEncoding::Tensorflowlite
266            }
267            generated::types::GraphEncoding::Autodetect => {
268                crate::wit::types::GraphEncoding::Autodetect
269            }
270        }
271    }
272}
273impl From<generated::types::TensorType> for crate::wit::types::TensorType {
274    fn from(value: generated::types::TensorType) -> Self {
275        match value {
276            generated::types::TensorType::F16 => crate::wit::types::TensorType::Fp16,
277            generated::types::TensorType::F32 => crate::wit::types::TensorType::Fp32,
278            generated::types::TensorType::U8 => crate::wit::types::TensorType::U8,
279            generated::types::TensorType::I32 => crate::wit::types::TensorType::I32,
280            generated::types::TensorType::I64 => crate::wit::types::TensorType::I64,
281            generated::types::TensorType::F64 => crate::wit::types::TensorType::Fp64,
282        }
283    }
284}
285
/// Possible errors while interacting with [WasiNnCtx].
#[derive(Debug, Error)]
pub enum WasiNnError {
    /// A failure reported by the underlying ML backend.
    #[error("backend error")]
    BackendError(#[from] BackendError),
    /// A failure reading from or writing to guest memory via Wiggle.
    #[error("guest error")]
    GuestError(#[from] GuestError),
    /// The guest used the API incorrectly (bad handle, unknown name, ...).
    #[error("usage error")]
    UsageError(#[from] UsageError),
    /// The guest-provided output buffer was too small; the payload is the
    /// number of bytes that were actually needed.
    #[error("not enough memory: requested {0} bytes")]
    NotEnoughMemory(usize),
}
298
299#[derive(Debug, Error)]
300pub enum UsageError {
301    #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
302    InvalidEncoding(GraphEncoding),
303    #[error("Invalid graph handle; has it been loaded?")]
304    InvalidGraphHandle,
305    #[error("Invalid execution context handle; has it been initialized?")]
306    InvalidExecutionContextHandle,
307    #[error("No graph found with name: {0}")]
308    NotFound(String),
309}