//! Backend implementations for wasmtime-wasi-nn (backend/mod.rs).

#[cfg(feature = "onnx")]
pub mod onnx;
#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
pub mod openvino;
#[cfg(feature = "pytorch")]
pub mod pytorch;
#[cfg(all(feature = "winml", target_os = "windows"))]
pub mod winml;

#[cfg(feature = "onnx")]
use self::onnx::OnnxBackend;
#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
use self::openvino::OpenvinoBackend;
#[cfg(feature = "pytorch")]
use self::pytorch::PytorchBackend;
#[cfg(all(feature = "winml", target_os = "windows"))]
use self::winml::WinMLBackend;

use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};
use crate::{Backend, ExecutionContext, Graph};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use thiserror::Error;
use wiggle::GuestError;

31pub fn list() -> Vec<Backend> {
33 let mut backends = vec![];
34 let _ = &mut backends; #[cfg(all(feature = "openvino", target_pointer_width = "64"))]
36 {
37 backends.push(Backend::from(OpenvinoBackend::default()));
38 }
39 #[cfg(all(feature = "winml", target_os = "windows"))]
40 {
41 backends.push(Backend::from(WinMLBackend::default()));
42 }
43 #[cfg(feature = "onnx")]
44 {
45 backends.push(Backend::from(OnnxBackend::default()));
46 }
47 #[cfg(feature = "pytorch")]
48 {
49 backends.push(Backend::from(PytorchBackend::default()));
50 }
51 backends
52}
53
54pub trait BackendInner: Send + Sync {
56 fn encoding(&self) -> GraphEncoding;
57 fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
58 fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
59}
60
61pub trait BackendFromDir: BackendInner {
65 fn load_from_dir(
66 &mut self,
67 builders: &Path,
68 target: ExecutionTarget,
69 ) -> Result<Graph, BackendError>;
70}
71
72pub trait BackendGraph: Send + Sync {
75 fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
76}
77
78pub trait BackendExecutionContext: Send + Sync {
81 fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
82 fn compute(&mut self) -> Result<(), BackendError>;
83 fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;
84}
85
/// Identifies a graph input or output, either by positional index or by
/// string name.
#[derive(Debug)]
pub enum Id {
    Index(u32),
    Name(String),
}

impl Id {
    /// The positional index, or `None` if this identifier is a name.
    pub fn index(&self) -> Option<u32> {
        match self {
            Id::Index(i) => Some(*i),
            Id::Name(_) => None,
        }
    }
    /// The string name, or `None` if this identifier is an index.
    pub fn name(&self) -> Option<&str> {
        match self {
            Id::Index(_) => None,
            Id::Name(n) => Some(n),
        }
    }
}

107#[derive(Debug, Error)]
110pub enum BackendError {
111 #[error("Failed while accessing backend")]
112 BackendAccess(#[from] anyhow::Error),
113 #[error("Failed while accessing guest module")]
114 GuestAccess(#[from] GuestError),
115 #[error("The backend expects {0} buffers, passed {1}")]
116 InvalidNumberOfBuilders(usize, usize),
117 #[error("Not enough memory to copy tensor data of size: {0}")]
118 NotEnoughMemory(usize),
119 #[error("Unsupported tensor type: {0}")]
120 UnsupportedTensorType(String),
121}
122
123#[allow(dead_code, reason = "not used on all platforms")]
125fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
126 let mut file = File::open(path)?;
127 let mut buffer = vec![];
128 file.read_to_end(&mut buffer)?;
129 Ok(buffer)
130}