wasmtime_wasi_nn/backend/mod.rs

//! Defines the Rust interface a backend must implement in order to be used by
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
//! implementations to maintain backend-specific state between calls.
4
5#[cfg(feature = "onnx")]
6pub mod onnx;
7#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
8pub mod openvino;
9#[cfg(feature = "pytorch")]
10pub mod pytorch;
11#[cfg(all(feature = "winml", target_os = "windows"))]
12pub mod winml;
13
14#[cfg(feature = "onnx")]
15use self::onnx::OnnxBackend;
16#[cfg(all(feature = "openvino", target_pointer_width = "64"))]
17use self::openvino::OpenvinoBackend;
18#[cfg(feature = "pytorch")]
19use self::pytorch::PytorchBackend;
20#[cfg(all(feature = "winml", target_os = "windows"))]
21use self::winml::WinMLBackend;
22
23use crate::wit::{ExecutionTarget, GraphEncoding, Tensor};
24use crate::{Backend, ExecutionContext, Graph};
25use std::fs::File;
26use std::io::Read;
27use std::path::Path;
28use thiserror::Error;
29use wiggle::GuestError;
30
31/// Return a list of all available backend frameworks.
32pub fn list() -> Vec<Backend> {
33    let mut backends = vec![];
34    let _ = &mut backends; // silence warnings if none are enabled
35    #[cfg(all(feature = "openvino", target_pointer_width = "64"))]
36    {
37        backends.push(Backend::from(OpenvinoBackend::default()));
38    }
39    #[cfg(all(feature = "winml", target_os = "windows"))]
40    {
41        backends.push(Backend::from(WinMLBackend::default()));
42    }
43    #[cfg(feature = "onnx")]
44    {
45        backends.push(Backend::from(OnnxBackend::default()));
46    }
47    #[cfg(feature = "pytorch")]
48    {
49        backends.push(Backend::from(PytorchBackend::default()));
50    }
51    backends
52}
53
54/// A [Backend] contains the necessary state to load [Graph]s.
55pub trait BackendInner: Send + Sync {
56    fn encoding(&self) -> GraphEncoding;
57    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError>;
58    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>;
59}
60
61/// Some [Backend]s support loading a [Graph] from a directory on the
62/// filesystem; this is not a general requirement for backends but is useful for
63/// the Wasmtime CLI.
64pub trait BackendFromDir: BackendInner {
65    fn load_from_dir(
66        &mut self,
67        builders: &Path,
68        target: ExecutionTarget,
69    ) -> Result<Graph, BackendError>;
70}
71
/// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing
/// implementation for the user-facing graph.
pub trait BackendGraph: Send + Sync {
    /// Creates a new [ExecutionContext] for this graph.
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError>;
}
77
/// A [BackendExecutionContext] performs the actual inference; this is the
/// backing implementation for a user-facing execution context.
pub trait BackendExecutionContext: Send + Sync {
    // WITX functions
    /// Sets the input tensor identified by `id` to `tensor`.
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
    /// Retrieves the output tensor identified by `id`.
    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;

    // Functions which work for both WIT and WITX
    /// Runs inference, optionally taking named input tensors and optionally
    /// returning named output tensors.
    ///
    /// NOTE(review): presumably WITX callers pass `None` (inputs were already
    /// supplied through `set_input`) and receive `None` (outputs are fetched
    /// through `get_output`), while WIT callers pass and receive named tensors —
    /// confirm against the callers.
    fn compute(
        &mut self,
        inputs: Option<Vec<NamedTensor>>,
    ) -> Result<Option<Vec<NamedTensor>>, BackendError>;
}
91
/// An identifier for a tensor in a [Graph], given either by position or by name.
#[derive(Debug)]
pub enum Id {
    /// Refers to a tensor by its numeric index.
    Index(u32),
    /// Refers to a tensor by its name.
    Name(String),
}
impl Id {
    /// Returns the numeric index when this is an [Id::Index], else `None`.
    pub fn index(&self) -> Option<u32> {
        if let Id::Index(position) = *self {
            Some(position)
        } else {
            None
        }
    }

    /// Returns the tensor name when this is an [Id::Name], else `None`.
    pub fn name(&self) -> Option<&str> {
        match *self {
            Id::Name(ref label) => Some(label.as_str()),
            Id::Index(_) => None,
        }
    }
}
112
/// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
/// for failures interacting with the ML library.
#[derive(Debug, Error)]
pub enum BackendError {
    /// Catch-all wrapping any [anyhow::Error] raised while using the backend;
    /// the wrapped error is kept as the source.
    #[error("Failed while accessing backend")]
    BackendAccess(#[from] anyhow::Error),
    /// Wraps a [GuestError] raised while accessing the guest module.
    #[error("Failed while accessing guest module")]
    GuestAccess(#[from] GuestError),
    /// The wrong number of builder buffers was supplied, as
    /// `(expected, passed)`.
    #[error("The backend expects {0} buffers, passed {1}")]
    InvalidNumberOfBuilders(usize, usize),
    /// There was not enough memory to copy tensor data of the given size.
    #[error("Not enough memory to copy tensor data of size: {0}")]
    NotEnoughMemory(usize),
    /// The tensor type, described by the contained string, is not supported.
    #[error("Unsupported tensor type: {0}")]
    UnsupportedTensorType(String),
}
128
129/// Read a file into a byte vector.
130#[allow(dead_code, reason = "not used on all platforms")]
131fn read(path: &Path) -> anyhow::Result<Vec<u8>> {
132    let mut file = File::open(path)?;
133    let mut buffer = vec![];
134    file.read_to_end(&mut buffer)?;
135    Ok(buffer)
136}
137
/// A [Tensor] paired with a name; used for the inputs and outputs of
/// [BackendExecutionContext::compute].
pub struct NamedTensor {
    /// The name identifying the tensor.
    pub name: String,
    /// The tensor itself.
    pub tensor: Tensor,
}