wasmtime_wasi_nn/backend/
mod.rs

1//! Define the Rust interface a backend must implement in order to be used by
2//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
3//! implementations to maintain backend-specific state between calls.
4
5#[cfg(feature = "onnx")]
6pub mod onnx;
7#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
8pub mod openvino;
9#[cfg(feature = "pytorch")]
10pub mod pytorch;
11#[cfg(all(feature = "winml", target_os = "windows"))]
12pub mod winml;
13
14#[cfg(feature = "onnx")]
15use self::onnx::OnnxBackend;
16#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
17use self::openvino::OpenvinoBackend;
18#[cfg(feature = "pytorch")]
19use self::pytorch::PytorchBackend;
20#[cfg(all(feature = "winml", target_os = "windows"))]
21use self::winml::WinMLBackend;
22
23use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};
24use crate::{Backend, ExecutionContext, Graph};
25use std::fs::File;
26use std::io::Read;
27use std::path::Path;
28use thiserror::Error;
29use wiggle::GuestError;
30
31/// Return a list of all available backend frameworks.
32pub fn list() -> Vec<Backend> {
33    let mut backends = vec![];
34    let _ = &mut backends; // silence warnings if none are enabled
35    #[cfg(all(feature = "openvino", target_pointer_width = "64"))]
36    {
37        backends.push(Backend::from(OpenvinoBackend::default()));
38    }
39    #[cfg(all(feature = "winml", target_os = "windows"))]
40    {
41        backends.push(Backend::from(WinMLBackend::default()));
42    }
43    #[cfg(feature = "onnx")]
44    {
45        backends.push(Backend::from(OnnxBackend::default()));
46    }
47    #[cfg(feature = "pytorch")]
48    {
49        backends.push(Backend::from(PytorchBackend::default()));
50    }
51    backends
52}
53
54/// A [Backend] contains the necessary state to load [Graph]s.
55pub trait BackendInner: Send + Sync {
56    fn encoding(&self) -> GraphEncoding;
57    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
58    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
59}
60
61/// Some [Backend]s support loading a [Graph] from a directory on the
62/// filesystem; this is not a general requirement for backends but is useful for
63/// the Wasmtime CLI.
64pub trait BackendFromDir: BackendInner {
65    fn load_from_dir(
66        &mut self,
67        builders: &Path,
68        target: ExecutionTarget,
69    ) -> Result<Graph, BackendError>;
70}
71
72/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
73/// implementation for the user-facing graph.
74pub trait BackendGraph: Send + Sync {
75    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
76}
77
78/// A [BackendExecutionContext] performs the actual inference; this is the
79/// backing implementation for a user-facing execution context.
80pub trait BackendExecutionContext: Send + Sync {
81    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
82    fn compute(&mut self) -> Result<(), BackendError>;
83    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;
84}
85
/// An identifier for a tensor in a [Graph].
///
/// A tensor is referred to either by a numeric index or by a name. The
/// [`Id::index`] and [`Id::name`] accessors return the payload of the matching
/// variant and `None` for the other one.
#[derive(Debug)]
pub enum Id {
    /// Refer to the tensor by numeric index.
    Index(u32),
    /// Refer to the tensor by name.
    Name(String),
}
impl Id {
    /// The numeric index, if this is an [`Id::Index`]; `None` for a name.
    pub fn index(&self) -> Option<u32> {
        if let Self::Index(position) = *self {
            Some(position)
        } else {
            None
        }
    }

    /// The tensor name, if this is an [`Id::Name`]; `None` for an index.
    pub fn name(&self) -> Option<&str> {
        if let Self::Name(label) = self {
            Some(label.as_str())
        } else {
            None
        }
    }
}
106
107/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
108/// for failures interacting with the ML library.
109#[derive(Debug, Error)]
110pub enum BackendError {
111    #[error("Failed while accessing backend")]
112    BackendAccess(#[from] anyhow::Error),
113    #[error("Failed while accessing guest module")]
114    GuestAccess(#[from] GuestError),
115    #[error("The backend expects {0} buffers, passed {1}")]
116    InvalidNumberOfBuilders(usize, usize),
117    #[error("Not enough memory to copy tensor data of size: {0}")]
118    NotEnoughMemory(usize),
119    #[error("Unsupported tensor type: {0}")]
120    UnsupportedTensorType(String),
121}
122
123/// Read a file into a byte vector.
124#[allow(dead_code, reason = "not used on all platforms")]
125fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
126    let mut file = File::open(path)?;
127    let mut buffer = vec![];
128    file.read_to_end(&mut buffer)?;
129    Ok(buffer)
130}