//! OpenVINO backend implementation (`wasmtime_wasi_nn/backend/openvino.rs`).
use super::{
4 BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
5 NamedTensor, read,
6};
7use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
8use crate::{ExecutionContext, Graph};
9use openvino::{DeviceType, ElementType, InferenceError, SetupError, Shape, Tensor as OvTensor};
10use std::path::Path;
11use std::sync::{Arc, Mutex};
12
/// The OpenVINO backend. The inner `Core` is created lazily on first `load`
/// (see `BackendInner::load` below), since construction can fail.
#[derive(Default)]
pub struct OpenvinoBackend(Option<openvino::Core>);
// SAFETY(review): `openvino::Core` wraps raw pointers into the OpenVINO C
// runtime; these impls assume the runtime tolerates cross-thread use of a
// `Core` — TODO confirm against the `openvino` crate's thread-safety docs.
unsafe impl Send for OpenvinoBackend {}
unsafe impl Sync for OpenvinoBackend {}
17
18impl BackendInner for OpenvinoBackend {
19 fn encoding(&self) -> GraphEncoding {
20 GraphEncoding::Openvino
21 }
22
23 fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
24 if builders.len() != 2 {
25 return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()));
26 }
27 if self.0.is_none() {
32 self.0.replace(openvino::Core::new()?);
33 }
34 let xml = builders[0];
36 let weights = builders[1];
37
38 let shape = Shape::new(&[1, weights.len() as i64])?;
40 let mut weights_tensor = OvTensor::new(ElementType::U8, &shape)?;
41 let buffer = weights_tensor.get_raw_data_mut()?;
42 buffer.copy_from_slice(&weights);
43
44 let core = self
47 .0
48 .as_mut()
49 .expect("openvino::Core was previously constructed");
50 let model = core.read_model_from_buffer(&xml, Some(&weights_tensor))?;
51 let compiled_model = core.compile_model(&model, target.into())?;
52 let box_: Box<dyn BackendGraph> =
53 Box::new(OpenvinoGraph(Arc::new(Mutex::new(compiled_model))));
54 Ok(box_.into())
55 }
56
57 fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
58 Some(self)
59 }
60}
61
62impl BackendFromDir for OpenvinoBackend {
63 fn load_from_dir(
64 &mut self,
65 path: &Path,
66 target: ExecutionTarget,
67 ) -> Result<Graph, BackendError> {
68 let model = read(&path.join("model.xml"))?;
69 let weights = read(&path.join("model.bin"))?;
70 self.load(&[&model, &weights], target)
71 }
72}
73
/// A loaded OpenVINO graph: the compiled model, shared behind a mutex so
/// multiple execution contexts can be created from it.
struct OpenvinoGraph(Arc<Mutex<openvino::CompiledModel>>);

// SAFETY(review): `CompiledModel` wraps raw pointers into the OpenVINO C
// runtime; these impls assume the mutex-guarded model may be moved/shared
// across threads — TODO confirm against the `openvino` crate's docs.
unsafe impl Send for OpenvinoGraph {}
unsafe impl Sync for OpenvinoGraph {}
78
79impl BackendGraph for OpenvinoGraph {
80 fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
81 let mut compiled_model = self.0.lock().unwrap();
82 let infer_request = compiled_model.create_infer_request()?;
83 let box_: Box<dyn BackendExecutionContext> =
84 Box::new(OpenvinoExecutionContext(infer_request, self.0.clone()));
85 Ok(box_.into())
86 }
87}
88
89struct OpenvinoExecutionContext(openvino::InferRequest, Arc<Mutex<openvino::CompiledModel>>);
90
91impl BackendExecutionContext for OpenvinoExecutionContext {
92 fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
93 let precision = tensor.ty.into();
95 let dimensions = tensor
96 .dimensions
97 .iter()
98 .map(|&d| d as i64)
99 .collect::<Vec<_>>();
100 let shape = Shape::new(&dimensions)?;
101 let mut new_tensor = OvTensor::new(precision, &shape)?;
102 let buffer = new_tensor.get_raw_data_mut()?;
103 buffer.copy_from_slice(&tensor.data);
104 match id {
106 Id::Index(i) => self.0.set_input_tensor_by_index(i as usize, &new_tensor)?,
107 Id::Name(name) => self.0.set_tensor(&name, &new_tensor)?,
108 };
109 Ok(())
110 }
111
112 fn compute(
113 &mut self,
114 inputs: Option<Vec<NamedTensor>>,
115 ) -> Result<Option<Vec<NamedTensor>>, BackendError> {
116 match inputs {
117 Some(inputs) => {
119 for input in &inputs {
121 let precision = input.tensor.ty.into();
122 let dimensions = input
123 .tensor
124 .dimensions
125 .iter()
126 .map(|&d| d as i64)
127 .collect::<Vec<_>>();
128 let shape = Shape::new(&dimensions)?;
129 let mut new_tensor = OvTensor::new(precision, &shape)?;
130 let buffer = new_tensor.get_raw_data_mut()?;
131 buffer.copy_from_slice(&input.tensor.data);
132
133 self.0.set_tensor(&input.name, &new_tensor)?;
134 }
135
136 self.0.infer()?;
138
139 let compiled_model = self.1.lock().unwrap();
141 let output_count = compiled_model.get_output_size()?;
142
143 let mut output_tensors = Vec::new();
144 for i in 0..output_count {
145 let output_tensor = self.0.get_output_tensor_by_index(i)?;
146
147 let dimensions = output_tensor
148 .get_shape()?
149 .get_dimensions()
150 .iter()
151 .map(|&dim| dim as u32)
152 .collect::<Vec<u32>>();
153
154 let ty = output_tensor.get_element_type()?.try_into()?;
155 let data = output_tensor.get_raw_data()?.to_vec();
156
157 output_tensors.push(NamedTensor {
159 name: format!("{i}"),
160 tensor: Tensor {
161 dimensions,
162 ty,
163 data,
164 },
165 });
166 }
167 Ok(Some(output_tensors))
168 }
169
170 None => {
172 self.0.infer()?;
173 Ok(None)
174 }
175 }
176 }
177
178 fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
179 let output_name = match id {
180 Id::Index(i) => self.0.get_output_tensor_by_index(i as usize)?,
181 Id::Name(name) => self.0.get_tensor(&name)?,
182 };
183 let dimensions = output_name
184 .get_shape()?
185 .get_dimensions()
186 .iter()
187 .map(|&dim| dim as u32)
188 .collect::<Vec<u32>>();
189 let ty = output_name.get_element_type()?.try_into()?;
190 let data = output_name.get_raw_data()?.to_vec();
191 Ok(Tensor {
192 dimensions,
193 ty,
194 data,
195 })
196 }
197}
198
199impl From<InferenceError> for BackendError {
200 fn from(e: InferenceError) -> Self {
201 BackendError::BackendAccess(anyhow::Error::new(e))
202 }
203}
204
205impl From<SetupError> for BackendError {
206 fn from(e: SetupError) -> Self {
207 BackendError::BackendAccess(anyhow::Error::new(e))
208 }
209}
210
211impl From<ExecutionTarget> for DeviceType<'static> {
214 fn from(target: ExecutionTarget) -> Self {
215 match target {
216 ExecutionTarget::Cpu => DeviceType::CPU,
217 ExecutionTarget::Gpu => DeviceType::GPU,
218 ExecutionTarget::Tpu => {
219 unimplemented!("OpenVINO does not support TPU execution targets")
220 }
221 }
222 }
223}
224
225impl From<TensorType> for ElementType {
228 fn from(tensor_type: TensorType) -> Self {
229 match tensor_type {
230 TensorType::Fp16 => ElementType::F16,
231 TensorType::Fp32 => ElementType::F32,
232 TensorType::Fp64 => ElementType::F64,
233 TensorType::U8 => ElementType::U8,
234 TensorType::I32 => ElementType::I32,
235 TensorType::I64 => ElementType::I64,
236 TensorType::Bf16 => ElementType::Bf16,
237 }
238 }
239}
240
241impl TryFrom<ElementType> for TensorType {
243 type Error = BackendError;
244 fn try_from(element_type: ElementType) -> Result<Self, Self::Error> {
245 match element_type {
246 ElementType::F16 => Ok(TensorType::Fp16),
247 ElementType::F32 => Ok(TensorType::Fp32),
248 ElementType::F64 => Ok(TensorType::Fp64),
249 ElementType::U8 => Ok(TensorType::U8),
250 ElementType::I32 => Ok(TensorType::I32),
251 ElementType::I64 => Ok(TensorType::I64),
252 ElementType::Bf16 => Ok(TensorType::Bf16),
253 _ => Err(BackendError::UnsupportedTensorType(
254 element_type.to_string(),
255 )),
256 }
257 }
258}