1pub mod backend;
2mod registry;
3pub mod wit;
4pub mod witx;
5
6use crate::backend::{BackendError, Id, NamedTensor as BackendNamedTensor};
7use crate::wit::generated_::wasi::nn::tensor::TensorType;
8use anyhow::anyhow;
9use core::fmt;
10pub use registry::{GraphRegistry, InMemoryRegistry};
11use std::path::Path;
12use std::sync::Arc;
13
14pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec<Backend>, Registry)> {
19 let mut backends = backend::list();
20 let mut registry = InMemoryRegistry::new();
21 for (kind, path) in preload_graphs {
22 let kind_ = kind.parse()?;
23 let backend = backends
24 .iter_mut()
25 .find(|b| b.encoding() == kind_)
26 .ok_or(anyhow!("unsupported backend: {}", kind))?
27 .as_dir_loadable()
28 .ok_or(anyhow!("{} does not support directory loading", kind))?;
29 registry.load(backend, Path::new(path))?;
30 }
31 Ok((backends, Registry::from(registry)))
32}
33
34pub struct Backend(Box<dyn backend::BackendInner>);
36impl std::ops::Deref for Backend {
37 type Target = dyn backend::BackendInner;
38 fn deref(&self) -> &Self::Target {
39 self.0.as_ref()
40 }
41}
42impl std::ops::DerefMut for Backend {
43 fn deref_mut(&mut self) -> &mut Self::Target {
44 self.0.as_mut()
45 }
46}
47impl<T: backend::BackendInner + 'static> From<T> for Backend {
48 fn from(value: T) -> Self {
49 Self(Box::new(value))
50 }
51}
52
53#[derive(Clone)]
55pub struct Graph(Arc<dyn backend::BackendGraph>);
56impl From<Box<dyn backend::BackendGraph>> for Graph {
57 fn from(value: Box<dyn backend::BackendGraph>) -> Self {
58 Self(value.into())
59 }
60}
61impl std::ops::Deref for Graph {
62 type Target = dyn backend::BackendGraph;
63 fn deref(&self) -> &Self::Target {
64 self.0.as_ref()
65 }
66}
67
68#[derive(Clone, PartialEq)]
74pub struct Tensor {
75 pub dimensions: Vec<u32>,
76 pub ty: TensorType,
77 pub data: Vec<u8>,
78}
79impl fmt::Debug for Tensor {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 f.debug_struct("Tensor")
82 .field("dimensions", &self.dimensions)
83 .field("ty", &self.ty)
84 .field("data (bytes)", &self.data.len())
85 .finish()
86 }
87}
88
89pub struct ExecutionContext(Box<dyn backend::BackendExecutionContext>);
91impl From<Box<dyn backend::BackendExecutionContext>> for ExecutionContext {
92 fn from(value: Box<dyn backend::BackendExecutionContext>) -> Self {
93 Self(value)
94 }
95}
96impl std::ops::Deref for ExecutionContext {
97 type Target = dyn backend::BackendExecutionContext;
98 fn deref(&self) -> &Self::Target {
99 self.0.as_ref()
100 }
101}
102impl std::ops::DerefMut for ExecutionContext {
103 fn deref_mut(&mut self) -> &mut Self::Target {
104 self.0.as_mut()
105 }
106}
107
108pub struct Registry(Box<dyn GraphRegistry>);
110impl std::ops::Deref for Registry {
111 type Target = dyn GraphRegistry;
112 fn deref(&self) -> &Self::Target {
113 self.0.as_ref()
114 }
115}
116impl std::ops::DerefMut for Registry {
117 fn deref_mut(&mut self) -> &mut Self::Target {
118 self.0.as_mut()
119 }
120}
121impl<T> From<T> for Registry
122where
123 T: GraphRegistry + 'static,
124{
125 fn from(value: T) -> Self {
126 Self(Box::new(value))
127 }
128}
129
130impl ExecutionContext {
131 pub fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
132 self.0.set_input(id, tensor)
133 }
134
135 pub fn compute(&mut self) -> Result<(), BackendError> {
136 self.0.compute(None).map(|_| ())
137 }
138
139 pub fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
140 self.0.get_output(id)
141 }
142
143 pub fn compute_with_io(
144 &mut self,
145 inputs: Vec<BackendNamedTensor>,
146 ) -> Result<Vec<BackendNamedTensor>, BackendError> {
147 match self.0.compute(Some(inputs))? {
148 Some(outputs) => Ok(outputs),
149 None => Ok(Vec::new()),
150 }
151 }
152}
153
154impl Tensor {
155 pub fn new(dimensions: Vec<u32>, ty: TensorType, data: Vec<u8>) -> Self {
156 Self {
157 dimensions,
158 ty,
159 data,
160 }
161 }
162}