// wasmtime_wasi_nn/backend/openvino.rs
use super::{
4 read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
5};
6use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
7use crate::{ExecutionContext, Graph};
8use openvino::{DeviceType, ElementType, InferenceError, SetupError, Shape, Tensor as OvTensor};
9use std::path::Path;
10use std::sync::{Arc, Mutex};
11
/// The OpenVINO backend; the inner [`openvino::Core`] is created lazily on
/// the first call to `load` (see `BackendInner::load` below).
#[derive(Default)]
pub struct OpenvinoBackend(Option<openvino::Core>);
// SAFETY: these assert that the wrapped `openvino::Core` handle may be moved
// to and shared between threads. NOTE(review): the openvino crate does not
// derive these automatically — confirm its documented thread-safety
// guarantees for `Core` before relying on cross-thread use.
unsafe impl Send for OpenvinoBackend {}
unsafe impl Sync for OpenvinoBackend {}
16
17impl BackendInner for OpenvinoBackend {
18 fn encoding(&self) -> GraphEncoding {
19 GraphEncoding::Openvino
20 }
21
22 fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
23 if builders.len() != 2 {
24 return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()).into());
25 }
26 if self.0.is_none() {
31 self.0.replace(openvino::Core::new()?);
32 }
33 let xml = builders[0];
35 let weights = builders[1];
36
37 let shape = Shape::new(&[1, weights.len() as i64])?;
39 let mut weights_tensor = OvTensor::new(ElementType::U8, &shape)?;
40 let buffer = weights_tensor.get_raw_data_mut()?;
41 buffer.copy_from_slice(&weights);
42
43 let core = self
46 .0
47 .as_mut()
48 .expect("openvino::Core was previously constructed");
49 let model = core.read_model_from_buffer(&xml, Some(&weights_tensor))?;
50 let compiled_model = core.compile_model(&model, target.into())?;
51 let box_: Box<dyn BackendGraph> =
52 Box::new(OpenvinoGraph(Arc::new(Mutex::new(compiled_model))));
53 Ok(box_.into())
54 }
55
56 fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
57 Some(self)
58 }
59}
60
61impl BackendFromDir for OpenvinoBackend {
62 fn load_from_dir(
63 &mut self,
64 path: &Path,
65 target: ExecutionTarget,
66 ) -> Result<Graph, BackendError> {
67 let model = read(&path.join("model.xml"))?;
68 let weights = read(&path.join("model.bin"))?;
69 self.load(&[&model, &weights], target)
70 }
71}
72
// A compiled OpenVINO model. The `Arc<Mutex<..>>` lets the graph hand the
// compiled model to each execution context it creates while serializing
// access to it.
struct OpenvinoGraph(Arc<Mutex<openvino::CompiledModel>>);

// SAFETY: these assert that the wrapped `CompiledModel` handle may be moved
// to and shared between threads; all mutation goes through the `Mutex`.
// NOTE(review): confirm the openvino crate's thread-safety guarantees for
// `CompiledModel` — it does not derive `Send`/`Sync` itself.
unsafe impl Send for OpenvinoGraph {}
unsafe impl Sync for OpenvinoGraph {}
77
78impl BackendGraph for OpenvinoGraph {
79 fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
80 let mut compiled_model = self.0.lock().unwrap();
81 let infer_request = compiled_model.create_infer_request()?;
82 let box_: Box<dyn BackendExecutionContext> =
83 Box::new(OpenvinoExecutionContext(infer_request));
84 Ok(box_.into())
85 }
86}
87
88struct OpenvinoExecutionContext(openvino::InferRequest);
89
90impl BackendExecutionContext for OpenvinoExecutionContext {
91 fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
92 let precision = tensor.ty.into();
94 let dimensions = tensor
95 .dimensions
96 .iter()
97 .map(|&d| d as i64)
98 .collect::<Vec<_>>();
99 let shape = Shape::new(&dimensions)?;
100 let mut new_tensor = OvTensor::new(precision, &shape)?;
101 let buffer = new_tensor.get_raw_data_mut()?;
102 buffer.copy_from_slice(&tensor.data);
103 match id {
105 Id::Index(i) => self.0.set_input_tensor_by_index(i as usize, &new_tensor)?,
106 Id::Name(name) => self.0.set_tensor(&name, &new_tensor)?,
107 };
108 Ok(())
109 }
110
111 fn compute(&mut self) -> Result<(), BackendError> {
112 self.0.infer()?;
113 Ok(())
114 }
115
116 fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
117 let output_name = match id {
118 Id::Index(i) => self.0.get_output_tensor_by_index(i as usize)?,
119 Id::Name(name) => self.0.get_tensor(&name)?,
120 };
121 let dimensions = output_name
122 .get_shape()?
123 .get_dimensions()
124 .iter()
125 .map(|&dim| dim as u32)
126 .collect::<Vec<u32>>();
127 let ty = output_name.get_element_type()?.try_into()?;
128 let data = output_name.get_raw_data()?.to_vec();
129 Ok(Tensor {
130 dimensions,
131 ty,
132 data,
133 })
134 }
135}
136
137impl From<InferenceError> for BackendError {
138 fn from(e: InferenceError) -> Self {
139 BackendError::BackendAccess(anyhow::Error::new(e))
140 }
141}
142
143impl From<SetupError> for BackendError {
144 fn from(e: SetupError) -> Self {
145 BackendError::BackendAccess(anyhow::Error::new(e))
146 }
147}
148
149impl From<ExecutionTarget> for DeviceType<'static> {
152 fn from(target: ExecutionTarget) -> Self {
153 match target {
154 ExecutionTarget::Cpu => DeviceType::CPU,
155 ExecutionTarget::Gpu => DeviceType::GPU,
156 ExecutionTarget::Tpu => {
157 unimplemented!("OpenVINO does not support TPU execution targets")
158 }
159 }
160 }
161}
162
163impl From<TensorType> for ElementType {
166 fn from(tensor_type: TensorType) -> Self {
167 match tensor_type {
168 TensorType::Fp16 => ElementType::F16,
169 TensorType::Fp32 => ElementType::F32,
170 TensorType::Fp64 => ElementType::F64,
171 TensorType::U8 => ElementType::U8,
172 TensorType::I32 => ElementType::I32,
173 TensorType::I64 => ElementType::I64,
174 TensorType::Bf16 => ElementType::Bf16,
175 }
176 }
177}
178
179impl TryFrom<ElementType> for TensorType {
181 type Error = BackendError;
182 fn try_from(element_type: ElementType) -> Result<Self, Self::Error> {
183 match element_type {
184 ElementType::F16 => Ok(TensorType::Fp16),
185 ElementType::F32 => Ok(TensorType::Fp32),
186 ElementType::F64 => Ok(TensorType::Fp64),
187 ElementType::U8 => Ok(TensorType::U8),
188 ElementType::I32 => Ok(TensorType::I32),
189 ElementType::I64 => Ok(TensorType::I64),
190 ElementType::Bf16 => Ok(TensorType::Bf16),
191 _ => Err(BackendError::UnsupportedTensorType(
192 element_type.to_string(),
193 )),
194 }
195 }
196}