use crate::backend::Id;
use crate::{Backend, Registry};
use anyhow::anyhow;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
use wasmtime::component::{Resource, ResourceTable};

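/// Capture the state necessary for calling into the backend ML libraries: one
/// [`Backend`] per [`GraphEncoding`] plus a [`Registry`] of preloaded graphs.
///
/// A minimal construction sketch; the backend and registry values below are
/// placeholders for whatever the embedding actually enables:
///
/// ```ignore
/// let backends: Vec<Backend> = vec![/* e.g. an OpenVINO- or ONNX-backed `Backend` */];
/// let registry: Registry = /* e.g. an in-memory registry of preloaded graphs */;
/// let ctx = WasiNnCtx::new(backends, registry);
/// ```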
pub struct WasiNnCtx {
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    pub(crate) registry: Registry,
}

impl WasiNnCtx {
    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
        Self { backends, registry }
    }
}

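/// A borrowed view of the wasi-nn state, pairing the [`WasiNnCtx`] with the
/// embedder's [`ResourceTable`]; the generated host traits below are implemented
/// on this view. It is typically constructed on each host call by the closure
/// passed to [`add_to_linker`].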
pub struct WasiNnView<'a> {
    ctx: &'a mut WasiNnCtx,
    table: &'a mut ResourceTable,
}

impl<'a> WasiNnView<'a> {
    pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
        Self { ctx, table }
    }
}

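/// A wasi-nn error, reported to the guest as a component-model `error` resource
/// that carries an [`ErrorCode`] plus the underlying [`anyhow::Error`] detail.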
#[derive(Debug)]
pub struct Error {
    code: ErrorCode,
    data: anyhow::Error,
}

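/// Record an error, push it into the resource table, and return early from the
/// current host function with the resulting `error` resource handle; this is the
/// "errors as resources" pattern used by the host implementations below.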
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        tracing::error!("failure: {e:?}");
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}

impl From<wasmtime::component::ResourceTableError> for Error {
    fn from(error: wasmtime::component::ResourceTableError) -> Self {
        Self {
            code: ErrorCode::Trap,
            data: error.into(),
        }
    }
}

/// The error codes surfaced to the guest; this mirrors the WIT `error-code` enum.
#[derive(Debug)]
pub enum ErrorCode {
    /// Caller module passed an invalid argument.
    InvalidArgument,
    /// Invalid encoding.
    InvalidEncoding,
    /// The operation timed out.
    Timeout,
    /// Runtime error.
    RuntimeError,
    /// Unsupported operation.
    UnsupportedOperation,
    /// Graph is too large.
    TooLarge,
    /// Graph not found.
    NotFound,
    /// An internal error that should trap the guest rather than be returned as
    /// an `error` resource.
    Trap,
}

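/// Generate the host traits and types from the `ml` world in `wit/wasi-nn.wit`.
/// The `with` mappings back each WIT resource with a concrete host type, and
/// `trappable_error_type` routes WIT-level errors through `Error` so the host
/// can either return them to the guest or trap.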
mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        trappable_imports: true,
        with: {
            "wasi:nn/graph/graph": crate::Graph,
            "wasi:nn/tensor/tensor": crate::Tensor,
            "wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors/error": super::Error,
        },
        trappable_error_type: {
            "wasi:nn/errors/error" => super::Error,
        },
    });
}
use generated_::wasi::nn::{self as generated};

/// Re-export the generated types in a single `types` module for convenience.
pub mod types {
    use super::generated;
    pub use generated::errors::Error;
    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
    pub use generated::inference::GraphExecutionContext;
    pub use generated::tensor::{Tensor, TensorType};
}
pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
pub use generated::inference::GraphExecutionContext;
pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
pub use generated_::Ml as ML;

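/// Add the wasi-nn host interfaces to a component [`wasmtime::component::Linker`].
///
/// A wiring sketch, assuming a hypothetical host state type `Ctx` that owns both a
/// `WasiNnCtx` and a `ResourceTable` (the names here are illustrative only):
///
/// ```ignore
/// struct Ctx {
///     wasi_nn: WasiNnCtx,
///     table: ResourceTable,
/// }
///
/// let mut linker = wasmtime::component::Linker::<Ctx>::new(&engine);
/// add_to_linker(&mut linker, |ctx: &mut Ctx| {
///     WasiNnView::new(&mut ctx.table, &mut ctx.wasi_nn)
/// })?;
/// ```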
pub fn add_to_linker<T>(
    l: &mut wasmtime::component::Linker<T>,
    f: impl Fn(&mut T) -> WasiNnView<'_> + Send + Sync + Copy + 'static,
) -> anyhow::Result<()> {
    generated::graph::add_to_linker_get_host(l, f)?;
    generated::tensor::add_to_linker_get_host(l, f)?;
    generated::inference::add_to_linker_get_host(l, f)?;
    generated::errors::add_to_linker_get_host(l, f)?;
    Ok(())
}

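// The top-level `wasi:nn/graph` functions: `load` dispatches to the backend
// registered for the requested encoding, while `load-by-name` looks up a graph
// preloaded into the registry.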
impl generated::graph::Host for WasiNnView<'_> {
    fn load(
        &mut self,
        builders: Vec<GraphBuilder>,
        encoding: GraphEncoding,
        target: ExecutionTarget,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        tracing::debug!("load {encoding:?} {target:?}");
        if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
            let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
            match backend.load(&slices, target.into()) {
                Ok(graph) => {
                    let graph = self.table.push(graph)?;
                    Ok(Ok(graph))
                }
                Err(error) => {
                    bail!(self, ErrorCode::RuntimeError, error);
                }
            }
        } else {
            bail!(
                self,
                ErrorCode::InvalidEncoding,
                anyhow!("unable to find a backend for this encoding")
            );
        }
    }

    fn load_by_name(
        &mut self,
        name: String,
    ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("load by name {name:?}");
        let registry = &self.ctx.registry;
        if let Some(graph) = registry.get(&name) {
            let graph = graph.clone();
            let graph = self.table.push(graph)?;
            Ok(Ok(graph))
        } else {
            bail!(
                self,
                ErrorCode::NotFound,
                anyhow!("failed to find graph with name: {name}")
            );
        }
    }
}

impl generated::graph::HostGraph for WasiNnView<'_> {
    fn init_execution_context(
        &mut self,
        graph: Resource<Graph>,
    ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
        use core::result::Result::*;
        tracing::debug!("initialize execution context");
        let graph = self.table.get(&graph)?;
        match graph.init_execution_context() {
            Ok(exec_context) => {
                let exec_context = self.table.push(exec_context)?;
                Ok(Ok(exec_context))
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
        self.table.delete(graph)?;
        Ok(())
    }
}

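// The `graph-execution-context` resource drives inference: the guest calls
// `set-input` with named tensors, `compute` to run the graph, and `get-output`
// to retrieve result tensors.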
impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
    fn set_input(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
        name: String,
        tensor: Resource<Tensor>,
    ) -> wasmtime::Result<Result<(), Resource<Error>>> {
        let tensor = self.table.get(&tensor)?;
        tracing::debug!("set input {name:?}: {tensor:?}");
        // Clone the tensor so the shared borrow of `self.table` ends before we
        // take the mutable borrow needed to fetch the execution context.
        let tensor = tensor.clone();
        let exec_context = self.table.get_mut(&exec_context)?;
        if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
            bail!(self, ErrorCode::InvalidArgument, error);
        } else {
            Ok(Ok(()))
        }
    }

    fn compute(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
    ) -> wasmtime::Result<Result<(), Resource<Error>>> {
        let exec_context = &mut self.table.get_mut(&exec_context)?;
        tracing::debug!("compute");
        match exec_context.compute() {
            Ok(()) => Ok(Ok(())),
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn get_output(
        &mut self,
        exec_context: Resource<GraphExecutionContext>,
        name: String,
    ) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
        let exec_context = self.table.get_mut(&exec_context)?;
        tracing::debug!("get output {name:?}");
        match exec_context.get_output(Id::Name(name)) {
            Ok(tensor) => {
                let tensor = self.table.push(tensor)?;
                Ok(Ok(tensor))
            }
            Err(error) => {
                bail!(self, ErrorCode::RuntimeError, error);
            }
        }
    }

    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
        self.table.delete(exec_context)?;
        Ok(())
    }
}

impl generated::tensor::HostTensor for WasiNnView<'_> {
    fn new(
        &mut self,
        dimensions: TensorDimensions,
        ty: TensorType,
        data: TensorData,
    ) -> wasmtime::Result<Resource<Tensor>> {
        let tensor = Tensor {
            dimensions,
            ty,
            data,
        };
        let tensor = self.table.push(tensor)?;
        Ok(tensor)
    }

    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.dimensions.clone())
    }

    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.ty)
    }

    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
        let tensor = self.table.get(&tensor)?;
        Ok(tensor.data.clone())
    }

    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
        self.table.delete(tensor)?;
        Ok(())
    }
}

impl generated::errors::HostError for WasiNnView<'_> {
    fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
        let error = self.table.get(&error)?;
        match error.code {
            ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
            ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
            ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
            ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
            ErrorCode::UnsupportedOperation => {
                Ok(generated::errors::ErrorCode::UnsupportedOperation)
            }
            ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
            ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
            ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
        }
    }

    fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
        let error = self.table.get(&error)?;
        Ok(error.data.to_string())
    }

    fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
        self.table.delete(error)?;
        Ok(())
    }
}

impl generated::errors::Host for WasiNnView<'_> {
    fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
        if matches!(err.code, ErrorCode::Trap) {
            Err(err.data)
        } else {
            Ok(err)
        }
    }
}

impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}

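// The generated `GraphEncoding` enum does not derive `Hash`, but `WasiNnCtx`
// keys its backend map by encoding, so hash the `Display` string instead.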
impl Hash for generated::graph::GraphEncoding {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.to_string().hash(state)
    }
}

impl fmt::Display for generated::graph::GraphEncoding {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use generated::graph::GraphEncoding::*;
        match self {
            Openvino => write!(f, "openvino"),
            Onnx => write!(f, "onnx"),
            Pytorch => write!(f, "pytorch"),
            Tensorflow => write!(f, "tensorflow"),
            Tensorflowlite => write!(f, "tensorflowlite"),
            Autodetect => write!(f, "autodetect"),
            Ggml => write!(f, "ggml"),
        }
    }
}

impl FromStr for generated::graph::GraphEncoding {
    type Err = GraphEncodingParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
            "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
            "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
            "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
            "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
            "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
            // Accept "ggml" as well so parsing stays in sync with the `Display` impl above.
            "ggml" => Ok(generated::graph::GraphEncoding::Ggml),
            _ => Err(GraphEncodingParseError(s.into())),
        }
    }
}
#[derive(Debug)]
pub struct GraphEncodingParseError(String);
impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unknown graph encoding: {}", self.0)
    }
}
impl std::error::Error for GraphEncodingParseError {}