1use crate::{Backend, Registry};
19use anyhow::anyhow;
20use std::collections::HashMap;
21use std::hash::Hash;
22use std::{fmt, str::FromStr};
23use wasmtime::component::{HasData, Resource, ResourceTable};
24
/// Host-side wasi-nn state: the available backends plus a registry of graphs.
pub struct WasiNnCtx {
    /// Backends keyed by graph encoding; `WasiNnCtx::new` fills this from each
    /// backend's `encoding()`.
    pub(crate) backends: HashMap<GraphEncoding, Backend>,
    /// Registry that `load_by_name` queries to find a graph by its name.
    pub(crate) registry: Registry,
}
30
31impl WasiNnCtx {
32 pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
34 let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
35 Self { backends, registry }
36 }
37}
38
/// A borrowed view pairing the wasi-nn context with a resource table.
///
/// The generated host traits in this file are implemented on this type.
pub struct WasiNnView<'a> {
    /// The wasi-nn state (backends and registry).
    ctx: &'a mut WasiNnCtx,
    /// Table holding the graph, tensor, execution-context and error resources.
    table: &'a mut ResourceTable,
}
49
50impl<'a> WasiNnView<'a> {
51 pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
53 Self { ctx, table }
54 }
55}
56
/// A wasi-nn error: a classification code together with the underlying cause.
#[derive(Debug)]
pub struct Error {
    /// Category of the failure; `ErrorCode::Trap` marks errors that the host
    /// functions in this file turn into an `Err` rather than a guest-visible code.
    code: ErrorCode,
    /// The underlying error; its `to_string()` is what the `data` accessor returns.
    data: anyhow::Error,
}
64
/// Record a failure as an `Error` resource and return it from the enclosing function.
///
/// Expands to statements that: build an [`Error`] from `$code` and `$data`
/// (converted with `.into()`), log it via `tracing::error!`, push it onto
/// `$self.table`, and `return Ok(Err(resource))`. If the push itself fails, the
/// `?` propagates that failure out of the enclosing function instead.
///
/// Because it expands to a `return`, it can only be used inside a function whose
/// return type accepts `Ok(Err(Resource<Error>))`.
macro_rules! bail {
    ($self:ident, $code:expr, $data:expr) => {
        let e = Error {
            code: $code,
            data: $data.into(),
        };
        tracing::error!("failure: {e:?}");
        // A failed push is propagated with `?` rather than reported as a resource.
        let r = $self.table.push(e)?;
        return Ok(Err(r));
    };
}
83
84impl From<wasmtime::component::ResourceTableError> for Error {
85 fn from(error: wasmtime::component::ResourceTableError) -> Self {
86 Self {
87 code: ErrorCode::Trap,
88 data: error.into(),
89 }
90 }
91}
92
/// Host-side classification of a wasi-nn [`Error`].
///
/// Every variant except `Trap` is translated to the same-named generated
/// `errors::ErrorCode` variant by `HostError::code`.
#[derive(Debug)]
pub enum ErrorCode {
    /// An invalid argument was supplied (not produced by the code in this file).
    InvalidArgument,
    /// Used by `load` when no backend is registered for the requested encoding.
    InvalidEncoding,
    /// A timeout (not produced by the code in this file).
    Timeout,
    /// Used when a backend operation (load, context creation, compute) fails.
    RuntimeError,
    /// An unsupported operation (not produced by the code in this file).
    UnsupportedOperation,
    /// Something was too large (not produced by the code in this file).
    TooLarge,
    /// Used by `load_by_name` when the registry has no graph with that name.
    NotFound,
    /// Marks an error that must not be reported as a code: `HostError::code` and
    /// `convert_error` both return `Err` for it. Resource-table errors convert
    /// to this variant.
    Trap,
}
115
/// Bindings produced by `wasmtime::component::bindgen!` for the `ml` world.
pub(crate) mod generated_ {
    wasmtime::component::bindgen!({
        world: "ml",
        path: "wit/wasi-nn.wit",
        // NOTE(review): presumably makes every import return `wasmtime::Result`,
        // matching the host method signatures in this file; confirm against the
        // bindgen documentation for the wasmtime version in use.
        trappable_imports: true,
        // Map the WIT resources onto this crate's own types instead of
        // generating new ones.
        with: {
            "wasi:nn/graph/graph": crate::Graph,
            "wasi:nn/tensor/tensor": crate::Tensor,
            "wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
            "wasi:nn/errors/error": super::Error,
        },
        // Use this file's `Error` as the error type for the WIT `error`; the
        // `convert_error` host method decides whether it becomes an `Err`.
        trappable_error_type: {
            "wasi:nn/errors/error" => super::Error,
        },
    });
}
// Short alias for the generated `wasi:nn` bindings, used throughout this file.
use generated_::wasi::nn::{self as generated};

/// Re-exports of the generated wasi-nn types, gathered in one module.
pub mod types {
    use super::generated;
    pub use generated::errors::Error;
    pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
    pub use generated::inference::GraphExecutionContext;
    pub use generated::tensor::{Tensor, TensorType};
}
145pub use generated::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};
146pub use generated::inference::{GraphExecutionContext, NamedTensor};
147pub use generated::tensor::{Tensor, TensorData, TensorDimensions, TensorType};
148pub use generated_::Ml as ML;
149
150pub fn add_to_linker<T: 'static>(
153 l: &mut wasmtime::component::Linker<T>,
154 f: fn(&mut T) -> WasiNnView<'_>,
155) -> anyhow::Result<()> {
156 generated::graph::add_to_linker::<_, HasWasiNnView>(l, f)?;
157 generated::tensor::add_to_linker::<_, HasWasiNnView>(l, f)?;
158 generated::inference::add_to_linker::<_, HasWasiNnView>(l, f)?;
159 generated::errors::add_to_linker::<_, HasWasiNnView>(l, f)?;
160 Ok(())
161}
162
/// Marker type passed to the generated `add_to_linker` functions to name the
/// host data type they operate on.
struct HasWasiNnView;

impl HasData for HasWasiNnView {
    // For any borrow lifetime `'a`, the host data is a `WasiNnView<'a>`.
    type Data<'a> = WasiNnView<'a>;
}
168
169impl generated::graph::Host for WasiNnView<'_> {
170 fn load(
171 &mut self,
172 builders: Vec<GraphBuilder>,
173 encoding: GraphEncoding,
174 target: ExecutionTarget,
175 ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
176 tracing::debug!("load {encoding:?} {target:?}");
177 if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
178 let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
179 match backend.load(&slices, target) {
180 Ok(graph) => {
181 let graph = self.table.push(graph)?;
182 Ok(Ok(graph))
183 }
184 Err(error) => {
185 bail!(self, ErrorCode::RuntimeError, error);
186 }
187 }
188 } else {
189 bail!(
190 self,
191 ErrorCode::InvalidEncoding,
192 anyhow!("unable to find a backend for this encoding")
193 );
194 }
195 }
196
197 fn load_by_name(
198 &mut self,
199 name: String,
200 ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
201 use core::result::Result::*;
202 tracing::debug!("load by name {name:?}");
203 let registry = &self.ctx.registry;
204 if let Some(graph) = registry.get(&name) {
205 let graph = graph.clone();
206 let graph = self.table.push(graph)?;
207 Ok(Ok(graph))
208 } else {
209 bail!(
210 self,
211 ErrorCode::NotFound,
212 anyhow!("failed to find graph with name: {name}")
213 );
214 }
215 }
216}
217
218impl generated::graph::HostGraph for WasiNnView<'_> {
219 fn init_execution_context(
220 &mut self,
221 graph: Resource<Graph>,
222 ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
223 use core::result::Result::*;
224 tracing::debug!("initialize execution context");
225 let graph = self.table.get(&graph)?;
226 match graph.init_execution_context() {
227 Ok(exec_context) => {
228 let exec_context = self.table.push(exec_context)?;
229 Ok(Ok(exec_context))
230 }
231 Err(error) => {
232 bail!(self, ErrorCode::RuntimeError, error);
233 }
234 }
235 }
236
237 fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
238 self.table.delete(graph)?;
239 Ok(())
240 }
241}
242
243impl generated::inference::HostGraphExecutionContext for WasiNnView<'_> {
244 fn compute(
245 &mut self,
246 exec_context: Resource<GraphExecutionContext>,
247 inputs: Vec<NamedTensor>,
248 ) -> wasmtime::Result<Result<Vec<NamedTensor>, Resource<Error>>> {
249 tracing::debug!("compute with {} inputs", inputs.len());
250
251 let mut named_tensors = Vec::new();
252 for (name, tensor_resopurce) in inputs.iter() {
253 let tensor = self.table.get(&tensor_resopurce)?;
254 named_tensors.push(crate::backend::NamedTensor {
255 name: name.clone(),
256 tensor: tensor.clone(),
257 });
258 }
259
260 let exec_context = &mut self.table.get_mut(&exec_context)?;
261
262 match exec_context.compute_with_io(named_tensors) {
263 Ok(named_tensors) => {
264 let result = named_tensors
265 .into_iter()
266 .map(|crate::backend::NamedTensor { name, tensor }| {
267 self.table.push(tensor).map(|resource| (name, resource))
268 })
269 .collect();
270
271 match result {
272 Ok(tuples) => Ok(Ok(tuples)),
273 Err(error) => {
274 bail!(self, ErrorCode::RuntimeError, error);
275 }
276 }
277 }
278 Err(error) => {
279 bail!(self, ErrorCode::RuntimeError, error);
280 }
281 }
282 }
283
284 fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
285 self.table.delete(exec_context)?;
286 Ok(())
287 }
288}
289
290impl generated::tensor::HostTensor for WasiNnView<'_> {
291 fn new(
292 &mut self,
293 dimensions: TensorDimensions,
294 ty: TensorType,
295 data: TensorData,
296 ) -> wasmtime::Result<Resource<Tensor>> {
297 let tensor = Tensor {
298 dimensions,
299 ty,
300 data,
301 };
302 let tensor = self.table.push(tensor)?;
303 Ok(tensor)
304 }
305
306 fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
307 let tensor = self.table.get(&tensor)?;
308 Ok(tensor.dimensions.clone())
309 }
310
311 fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
312 let tensor = self.table.get(&tensor)?;
313 Ok(tensor.ty)
314 }
315
316 fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
317 let tensor = self.table.get(&tensor)?;
318 Ok(tensor.data.clone())
319 }
320
321 fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
322 self.table.delete(tensor)?;
323 Ok(())
324 }
325}
326
327impl generated::errors::HostError for WasiNnView<'_> {
328 fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<generated::errors::ErrorCode> {
329 let error = self.table.get(&error)?;
330 match error.code {
331 ErrorCode::InvalidArgument => Ok(generated::errors::ErrorCode::InvalidArgument),
332 ErrorCode::InvalidEncoding => Ok(generated::errors::ErrorCode::InvalidEncoding),
333 ErrorCode::Timeout => Ok(generated::errors::ErrorCode::Timeout),
334 ErrorCode::RuntimeError => Ok(generated::errors::ErrorCode::RuntimeError),
335 ErrorCode::UnsupportedOperation => {
336 Ok(generated::errors::ErrorCode::UnsupportedOperation)
337 }
338 ErrorCode::TooLarge => Ok(generated::errors::ErrorCode::TooLarge),
339 ErrorCode::NotFound => Ok(generated::errors::ErrorCode::NotFound),
340 ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
341 }
342 }
343
344 fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
345 let error = self.table.get(&error)?;
346 Ok(error.data.to_string())
347 }
348
349 fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
350 self.table.delete(error)?;
351 Ok(())
352 }
353}
354
355impl generated::errors::Host for WasiNnView<'_> {
356 fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
357 if matches!(err.code, ErrorCode::Trap) {
358 Err(err.data)
359 } else {
360 Ok(err)
361 }
362 }
363}
364
// The `tensor` and `inference` `Host` traits are implemented with empty bodies:
// no methods are provided for them here.
impl generated::tensor::Host for WasiNnView<'_> {}
impl generated::inference::Host for WasiNnView<'_> {}
367
368impl Hash for generated::graph::GraphEncoding {
369 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
370 self.to_string().hash(state)
371 }
372}
373
374impl fmt::Display for generated::graph::GraphEncoding {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 use generated::graph::GraphEncoding::*;
377 match self {
378 Openvino => write!(f, "openvino"),
379 Onnx => write!(f, "onnx"),
380 Pytorch => write!(f, "pytorch"),
381 Tensorflow => write!(f, "tensorflow"),
382 Tensorflowlite => write!(f, "tensorflowlite"),
383 Autodetect => write!(f, "autodetect"),
384 Ggml => write!(f, "ggml"),
385 }
386 }
387}
388
389impl FromStr for generated::graph::GraphEncoding {
390 type Err = GraphEncodingParseError;
391 fn from_str(s: &str) -> Result<Self, Self::Err> {
392 match s.to_lowercase().as_str() {
393 "openvino" => Ok(generated::graph::GraphEncoding::Openvino),
394 "onnx" => Ok(generated::graph::GraphEncoding::Onnx),
395 "pytorch" => Ok(generated::graph::GraphEncoding::Pytorch),
396 "tensorflow" => Ok(generated::graph::GraphEncoding::Tensorflow),
397 "tensorflowlite" => Ok(generated::graph::GraphEncoding::Tensorflowlite),
398 "autodetect" => Ok(generated::graph::GraphEncoding::Autodetect),
399 _ => Err(GraphEncodingParseError(s.into())),
400 }
401 }
402}
/// Error returned when a string does not name a known graph encoding.
///
/// Holds the input exactly as it was given.
#[derive(Debug)]
pub struct GraphEncodingParseError(String);

impl fmt::Display for GraphEncodingParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(input) = self;
        write!(f, "unknown graph encoding: {input}")
    }
}

impl std::error::Error for GraphEncodingParseError {}