1use crate::backend::BackendError;
17use crate::backend::Id;
18use crate::wit::GraphEncoding;
19use crate::{Backend, ExecutionContext, Graph, Registry};
20use std::collections::HashMap;
21use std::hash::Hash;
22use thiserror::Error;
23use wiggle::{GuestError, GuestMemory, GuestPtr};
24
25pub use generated::wasi_ephemeral_nn::add_to_linker;
26
/// Result alias used by all wasi-nn host functions in this module.
pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
// Local shorthand so the trait-method signatures below stay compact.
type Result<T> = WasiNnResult<T>;
// Guest-visible handle identifying a loaded graph.
type GraphId = u32;
// Guest-visible handle identifying an initialized execution context.
type GraphExecutionContextId = u32;
31
/// Host-side state for the wasi-nn API: the available ML backends plus the
/// tables mapping guest handles to loaded graphs and execution contexts.
pub struct WasiNnCtx {
    /// Available backends, keyed by the graph encoding each one accepts.
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    /// Preloaded graphs that `load_by_name` hands out by name.
    pub(crate) registry: Registry,
    /// Graphs loaded by the guest, keyed by guest-visible handle.
    pub(crate) graphs: Table<GraphId, Graph>,
    /// Execution contexts created by the guest, keyed by guest-visible handle.
    pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
}
39
40impl WasiNnCtx {
41 pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
43 let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
44 Self {
45 backends,
46 registry,
47 graphs: Table::default(),
48 executions: Table::default(),
49 }
50 }
51}
52
/// A minimal handle table: each inserted value is assigned the next
/// monotonically increasing `u32`-derived key.
pub struct Table<K, V> {
    entries: HashMap<K, V>,
    // Next key to hand out; bumped on every insert.
    next_key: u32,
}
58
59impl<K, V> Default for Table<K, V> {
60 fn default() -> Self {
61 Self {
62 entries: HashMap::new(),
63 next_key: 0,
64 }
65 }
66}
67
68impl<K, V> Table<K, V>
69where
70 K: Eq + Hash + From<u32> + Copy,
71{
72 pub fn insert(&mut self, value: V) -> K {
73 let key = self.use_next_key();
74 self.entries.insert(key, value);
75 key
76 }
77
78 pub fn get(&self, key: K) -> Option<&V> {
79 self.entries.get(&key)
80 }
81
82 pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
83 self.entries.get_mut(&key)
84 }
85
86 fn use_next_key(&mut self) -> K {
87 let current = self.next_key;
88 self.next_key += 1;
89 K::from(current)
90 }
91}
92
/// Bindings generated by `wiggle` from the wasi-nn WITX specification, plus
/// the glue for converting host-side errors into guest-visible errnos.
mod generated {
    use super::*;
    wiggle::from_witx!({
        witx: ["$WASI_ROOT/wasi-nn.witx"],
        errors: { nn_errno => WasiNnError }
    });

    // Wiggle needs to know which errno value represents success.
    impl wiggle::GuestErrorType for types::NnErrno {
        fn success() -> Self {
            Self::Success
        }
    }

    // Flatten the host's rich `WasiNnError` into the errno the guest sees.
    impl types::UserErrorConversion for WasiNnCtx {
        fn nn_errno_from_wasi_nn_error(
            &mut self,
            e: WasiNnError,
        ) -> anyhow::Result<types::NnErrno> {
            // Log the full error host-side; the guest only receives an errno.
            tracing::debug!("host error: {:?}", e);
            match e {
                WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError),
                // NOTE(review): guest (memory/ABI) errors currently panic via
                // `unimplemented!` instead of mapping to an errno — confirm
                // whether trapping the guest here is intentional.
                WasiNnError::GuestError(_) => unimplemented!("guest error conversion"),
                WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation),
                WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge),
            }
        }
    }
}
125
126impl generated::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
128 fn load(
129 &mut self,
130 memory: &mut GuestMemory<'_>,
131 builders: generated::types::GraphBuilderArray,
132 encoding: generated::types::GraphEncoding,
133 target: generated::types::ExecutionTarget,
134 ) -> Result<generated::types::Graph> {
135 let graph = if let Some(backend) = self.backends.get_mut(&encoding.into()) {
136 let mut slices = vec![];
139 for builder in builders.iter() {
140 let builder = memory.read(builder?)?;
141 let slice = memory.as_slice(builder)?.expect(
142 "cannot use with shared memories; \
143 see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
144 );
145 slices.push(slice);
146 }
147 let slice_refs = slices.iter().map(|s| s.as_ref()).collect::<Vec<_>>();
148 backend.load(&slice_refs, target.into())?
149 } else {
150 return Err(UsageError::InvalidEncoding(encoding.into()).into());
151 };
152 let graph_id = self.graphs.insert(graph);
153 Ok(graph_id.into())
154 }
155
156 fn load_by_name(
157 &mut self,
158 memory: &mut GuestMemory<'_>,
159 name: wiggle::GuestPtr<str>,
160 ) -> Result<generated::types::Graph> {
161 let name = memory.as_str(name)?.unwrap();
162 if let Some(graph) = self.registry.get_mut(&name) {
163 let graph_id = self.graphs.insert(graph.clone().into());
164 Ok(graph_id.into())
165 } else {
166 return Err(UsageError::NotFound(name.to_string()).into());
167 }
168 }
169
170 fn init_execution_context(
171 &mut self,
172 _memory: &mut GuestMemory<'_>,
173 graph_id: generated::types::Graph,
174 ) -> Result<generated::types::GraphExecutionContext> {
175 let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id.into()) {
176 graph.init_execution_context()?
177 } else {
178 return Err(UsageError::InvalidGraphHandle.into());
179 };
180
181 let exec_context_id = self.executions.insert(exec_context);
182 Ok(exec_context_id.into())
183 }
184
185 fn set_input(
186 &mut self,
187 memory: &mut GuestMemory<'_>,
188 exec_context_id: generated::types::GraphExecutionContext,
189 index: u32,
190 tensor: &generated::types::Tensor,
191 ) -> Result<()> {
192 if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
193 let tensor = crate::wit::types::Tensor {
194 dimensions: memory.to_vec(tensor.dimensions)?,
195 ty: tensor.type_.into(),
196 data: memory.to_vec(tensor.data)?,
197 };
198 Ok(exec_context.set_input(Id::Index(index), &tensor)?)
199 } else {
200 Err(UsageError::InvalidGraphHandle.into())
201 }
202 }
203
204 fn compute(
205 &mut self,
206 _memory: &mut GuestMemory<'_>,
207 exec_context_id: generated::types::GraphExecutionContext,
208 ) -> Result<()> {
209 if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
210 Ok(exec_context.compute()?)
211 } else {
212 Err(UsageError::InvalidExecutionContextHandle.into())
213 }
214 }
215
216 fn get_output(
217 &mut self,
218 memory: &mut GuestMemory<'_>,
219 exec_context_id: generated::types::GraphExecutionContext,
220 index: u32,
221 out_buffer: GuestPtr<u8>,
222 out_buffer_max_size: u32,
223 ) -> Result<u32> {
224 if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
225 let tensor = exec_context.get_output(Id::Index(index))?;
226 let destination = memory
227 .as_slice_mut(out_buffer.as_array(out_buffer_max_size))?
228 .expect(
229 "cannot use with shared memories; \
230 see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
231 );
232 if tensor.data.len() > destination.len() {
233 Err(WasiNnError::NotEnoughMemory(tensor.data.len()))
234 } else {
235 destination[..tensor.data.len()].copy_from_slice(&tensor.data);
236 Ok(tensor.data.len() as u32)
237 }
238 } else {
239 Err(UsageError::InvalidGraphHandle.into())
240 }
241 }
242}
243
244impl From<generated::types::ExecutionTarget> for crate::wit::types::ExecutionTarget {
247 fn from(value: generated::types::ExecutionTarget) -> Self {
248 match value {
249 generated::types::ExecutionTarget::Cpu => crate::wit::types::ExecutionTarget::Cpu,
250 generated::types::ExecutionTarget::Gpu => crate::wit::types::ExecutionTarget::Gpu,
251 generated::types::ExecutionTarget::Tpu => crate::wit::types::ExecutionTarget::Tpu,
252 }
253 }
254}
255impl From<generated::types::GraphEncoding> for crate::wit::types::GraphEncoding {
256 fn from(value: generated::types::GraphEncoding) -> Self {
257 match value {
258 generated::types::GraphEncoding::Openvino => crate::wit::types::GraphEncoding::Openvino,
259 generated::types::GraphEncoding::Onnx => crate::wit::types::GraphEncoding::Onnx,
260 generated::types::GraphEncoding::Tensorflow => {
261 crate::wit::types::GraphEncoding::Tensorflow
262 }
263 generated::types::GraphEncoding::Pytorch => crate::wit::types::GraphEncoding::Pytorch,
264 generated::types::GraphEncoding::Tensorflowlite => {
265 crate::wit::types::GraphEncoding::Tensorflowlite
266 }
267 generated::types::GraphEncoding::Autodetect => {
268 crate::wit::types::GraphEncoding::Autodetect
269 }
270 }
271 }
272}
273impl From<generated::types::TensorType> for crate::wit::types::TensorType {
274 fn from(value: generated::types::TensorType) -> Self {
275 match value {
276 generated::types::TensorType::F16 => crate::wit::types::TensorType::Fp16,
277 generated::types::TensorType::F32 => crate::wit::types::TensorType::Fp32,
278 generated::types::TensorType::U8 => crate::wit::types::TensorType::U8,
279 generated::types::TensorType::I32 => crate::wit::types::TensorType::I32,
280 generated::types::TensorType::I64 => crate::wit::types::TensorType::I64,
281 generated::types::TensorType::F64 => crate::wit::types::TensorType::Fp64,
282 }
283 }
284}
285
/// Errors that can occur while servicing wasi-nn host calls; converted to a
/// guest-visible errno by `UserErrorConversion` in the `generated` module.
#[derive(Debug, Error)]
pub enum WasiNnError {
    /// The underlying ML backend reported a failure.
    #[error("backend error")]
    BackendError(#[from] BackendError),
    /// Guest memory access or ABI translation failed.
    #[error("guest error")]
    GuestError(#[from] GuestError),
    /// The guest misused the API (bad handle, unsupported encoding, ...).
    #[error("usage error")]
    UsageError(#[from] UsageError),
    /// An output did not fit in the guest-provided buffer; the payload is the
    /// number of bytes that would be required.
    #[error("not enough memory: requested {0} bytes")]
    NotEnoughMemory(usize),
}
298
299#[derive(Debug, Error)]
300pub enum UsageError {
301 #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
302 InvalidEncoding(GraphEncoding),
303 #[error("Invalid graph handle; has it been loaded?")]
304 InvalidGraphHandle,
305 #[error("Invalid execution context handle; has it been initialized?")]
306 InvalidExecutionContextHandle,
307 #[error("No graph found with name: {0}")]
308 NotFound(String),
309}