wasmtime_wasi_nn/backend/openvino.rs

//! Implements a `wasi-nn` [`BackendInner`] using OpenVINO.
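//!
//! A graph is provided to this backend either as two in-memory builders (the
//! OpenVINO IR XML description followed by the model weights) or, via
//! [`BackendFromDir`], as a directory containing `model.xml` and `model.bin`.
//!
//! A minimal sketch of the intended call sequence (the variable names are
//! illustrative only, not part of this module):
//!
//! ```ignore
//! let mut backend = OpenvinoBackend::default();
//! // `xml` and `weights` are byte buffers supplied by the caller.
//! let graph = backend.load(&[&xml, &weights], ExecutionTarget::Cpu)?;
//! let mut context = graph.init_execution_context()?;
//! context.set_input(Id::Index(0), &input_tensor)?;
//! context.compute()?;
//! let output = context.get_output(Id::Index(0))?;
//! ```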

use super::{
    read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id,
};
use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
use crate::{ExecutionContext, Graph};
use openvino::{DeviceType, ElementType, InferenceError, SetupError, Shape, Tensor as OvTensor};
use std::path::Path;
use std::sync::{Arc, Mutex};

#[derive(Default)]
pub struct OpenvinoBackend(Option<openvino::Core>);
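// Assumption: `openvino::Core` wraps raw pointers into the OpenVINO runtime
// and is therefore not automatically `Send`/`Sync`; these impls assert that
// the host only uses the backend in a thread-safe manner.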
unsafe impl Send for OpenvinoBackend {}
unsafe impl Sync for OpenvinoBackend {}

impl BackendInner for OpenvinoBackend {
    fn encoding(&self) -> GraphEncoding {
        GraphEncoding::Openvino
    }

    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
        if builders.len() != 2 {
            return Err(BackendError::InvalidNumberOfBuilders(2, builders.len()).into());
        }
        // Construct the context if none is present; this is done lazily (i.e.,
        // upon actually loading a model) because it may fail to find and load
        // the OpenVINO libraries. The laziness limits the error's impact to
        // wasi-nn users only, rather than all WASI users.
        if self.0.is_none() {
            self.0.replace(openvino::Core::new()?);
        }
        // Read the guest arrays: the first builder holds the IR XML
        // description and the second holds the model weights.
        let xml = builders[0];
        let weights = builders[1];

        // Construct a new tensor for the model weights.
        let shape = Shape::new(&[1, weights.len() as i64])?;
        let mut weights_tensor = OvTensor::new(ElementType::U8, &shape)?;
        let buffer = weights_tensor.get_raw_data_mut()?;
        buffer.copy_from_slice(weights);

        // Construct OpenVINO graph structures: `model` contains the graph
        // structure, `compiled_model` can perform inference.
        let core = self
            .0
            .as_mut()
            .expect("openvino::Core was previously constructed");
        let model = core.read_model_from_buffer(xml, Some(&weights_tensor))?;
        let compiled_model = core.compile_model(&model, target.into())?;
        let box_: Box<dyn BackendGraph> =
            Box::new(OpenvinoGraph(Arc::new(Mutex::new(compiled_model))));
        Ok(box_.into())
    }

    fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> {
        Some(self)
    }
}

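/// Load an OpenVINO graph from a directory, which is expected to contain the
/// IR pair `model.xml` (the graph description) and `model.bin` (the weights).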
impl BackendFromDir for OpenvinoBackend {
    fn load_from_dir(
        &mut self,
        path: &Path,
        target: ExecutionTarget,
    ) -> Result<Graph, BackendError> {
        let model = read(&path.join("model.xml"))?;
        let weights = read(&path.join("model.bin"))?;
        self.load(&[&model, &weights], target)
    }
}

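/// A compiled OpenVINO model; the `Arc<Mutex<...>>` allows multiple execution
/// contexts to be created from the same graph.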
struct OpenvinoGraph(Arc<Mutex<openvino::CompiledModel>>);

// Assumption: as with `OpenvinoBackend` above, the wrapped OpenVINO type is
// not automatically `Send`/`Sync`; the `Mutex` serializes all access to the
// compiled model.
unsafe impl Send for OpenvinoGraph {}
unsafe impl Sync for OpenvinoGraph {}

impl BackendGraph for OpenvinoGraph {
    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
        let mut compiled_model = self.0.lock().unwrap();
        let infer_request = compiled_model.create_infer_request()?;
        let box_: Box<dyn BackendExecutionContext> =
            Box::new(OpenvinoExecutionContext(infer_request));
        Ok(box_.into())
    }
}

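/// A single OpenVINO inference request: inputs are bound with `set_input`,
/// inference runs via `compute`, and results are read back with `get_output`.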
struct OpenvinoExecutionContext(openvino::InferRequest);

impl BackendExecutionContext for OpenvinoExecutionContext {
    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
        // Construct an OpenVINO tensor with the guest tensor's type, shape,
        // and contents.
        let precision = tensor.ty.into();
        let dimensions = tensor
            .dimensions
            .iter()
            .map(|&d| d as i64)
            .collect::<Vec<_>>();
        let shape = Shape::new(&dimensions)?;
        let mut new_tensor = OvTensor::new(precision, &shape)?;
        let buffer = new_tensor.get_raw_data_mut()?;
        buffer.copy_from_slice(&tensor.data);
        // Assign the tensor to the request.
        match id {
            Id::Index(i) => self.0.set_input_tensor_by_index(i as usize, &new_tensor)?,
            Id::Name(name) => self.0.set_tensor(&name, &new_tensor)?,
        };
        Ok(())
    }

    fn compute(&mut self) -> Result<(), BackendError> {
        self.0.infer()?;
        Ok(())
    }

    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError> {
        let output_tensor = match id {
            Id::Index(i) => self.0.get_output_tensor_by_index(i as usize)?,
            Id::Name(name) => self.0.get_tensor(&name)?,
        };
        let dimensions = output_tensor
            .get_shape()?
            .get_dimensions()
            .iter()
            .map(|&dim| dim as u32)
            .collect::<Vec<u32>>();
        let ty = output_tensor.get_element_type()?.try_into()?;
        let data = output_tensor.get_raw_data()?.to_vec();
        Ok(Tensor {
            dimensions,
            ty,
            data,
        })
    }
}

impl From<InferenceError> for BackendError {
    fn from(e: InferenceError) -> Self {
        BackendError::BackendAccess(anyhow::Error::new(e))
    }
}

impl From<SetupError> for BackendError {
    fn from(e: SetupError) -> Self {
        BackendError::BackendAccess(anyhow::Error::new(e))
    }
}

/// Return the execution target device type expected by OpenVINO for the
/// `ExecutionTarget` enum provided by wasi-nn.
impl From<ExecutionTarget> for DeviceType<'static> {
    fn from(target: ExecutionTarget) -> Self {
        match target {
            ExecutionTarget::Cpu => DeviceType::CPU,
            ExecutionTarget::Gpu => DeviceType::GPU,
            ExecutionTarget::Tpu => {
                unimplemented!("OpenVINO does not support TPU execution targets")
            }
        }
    }
}

/// Return OpenVINO's precision type for the `TensorType` enum provided by
/// wasi-nn.
impl From<TensorType> for ElementType {
    fn from(tensor_type: TensorType) -> Self {
        match tensor_type {
            TensorType::Fp16 => ElementType::F16,
            TensorType::Fp32 => ElementType::F32,
            TensorType::Fp64 => ElementType::F64,
            TensorType::U8 => ElementType::U8,
            TensorType::I32 => ElementType::I32,
            TensorType::I64 => ElementType::I64,
            TensorType::Bf16 => ElementType::Bf16,
        }
    }
}

/// Return the `TensorType` enum provided by wasi-nn for OpenVINO's precision
/// type, failing on any element type wasi-nn cannot represent.
impl TryFrom<ElementType> for TensorType {
    type Error = BackendError;
    fn try_from(element_type: ElementType) -> Result<Self, Self::Error> {
        match element_type {
            ElementType::F16 => Ok(TensorType::Fp16),
            ElementType::F32 => Ok(TensorType::Fp32),
            ElementType::F64 => Ok(TensorType::Fp64),
            ElementType::U8 => Ok(TensorType::U8),
            ElementType::I32 => Ok(TensorType::I32),
            ElementType::I64 => Ok(TensorType::I64),
            ElementType::Bf16 => Ok(TensorType::Bf16),
            _ => Err(BackendError::UnsupportedTensorType(
                element_type.to_string(),
            )),
        }
    }
}