wasmtime/sync_nostd.rs

//! Synchronization primitives for Wasmtime for `no_std`.
//!
//! These primitives are intended for use in `no_std` contexts and are not as
//! full-featured as their `std` brethren. Namely, they panic and/or return an
//! error on contention. This keeps them correct in the face of actual multiple
//! threads, but if a system does have multiple threads then something will
//! need to change in the Wasmtime crate to enable the external system to
//! perform the necessary synchronization.
//!
//! In the future, if these primitives are not suitable, we can switch to
//! putting the relevant functions in the `capi.rs` module, where we would
//! require embedders to implement them instead of doing it ourselves here.
//! It's unclear whether this will be necessary, so this is the chosen starting
//! point.
//!
//! See `sync_std.rs` as well for a brief overview of this module.
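//!
//! As a rough sketch of the intended single-threaded usage (this mirrors the
//! unit tests at the bottom of this file; the snippet is marked `ignore`
//! because these types are crate-internal and it is not compiled as a
//! doctest):
//!
//! ```ignore
//! let cell = OnceLock::new();
//! assert_eq!(*cell.get_or_init(|| 1), 1);
//! // Later initializers are ignored once a value has been stored.
//! assert_eq!(*cell.get_or_init(|| 2), 1);
//!
//! let lock = RwLock::new(0);
//! // Guards must be dropped before the next lock attempt; overlapping lock
//! // attempts panic rather than block in this `no_std` implementation.
//! *lock.write() = 4;
//! assert_eq!(*lock.read(), 4);
//! ```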

#![cfg_attr(
    all(feature = "std", not(test)),
    expect(dead_code, reason = "not used, but typechecked")
)]

use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicU8, AtomicU32, Ordering};

pub struct OnceLock<T> {
    val: UnsafeCell<MaybeUninit<T>>,
    state: AtomicU8,
}

unsafe impl<T: Send> Send for OnceLock<T> {}
unsafe impl<T: Sync> Sync for OnceLock<T> {}

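// `OnceLock::state` holds one of the values below: initialization moves it
// from `UNINITIALIZED` to `INITIALIZING` while the closure runs and then to
// `INITIALIZED` on success, or back to `UNINITIALIZED` if the closure fails.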
const UNINITIALIZED: u8 = 0;
const INITIALIZING: u8 = 1;
const INITIALIZED: u8 = 2;

impl<T> OnceLock<T> {
    pub const fn new() -> OnceLock<T> {
        OnceLock {
            state: AtomicU8::new(UNINITIALIZED),
            val: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

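    /// Returns a reference to the stored value, initializing it with `f` if
    /// the cell is empty.
    ///
    /// Without `std` there is no way to block, so an initialization that races
    /// with another in-progress initialization panics instead.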
    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
        if let Some(ret) = self.get() {
            return ret;
        }
        self.try_init(|| Ok::<_, ()>(f())).unwrap()
    }

    /// Like `get_or_init`, but the initialization closure is fallible; an
    /// `Err` leaves the cell empty and is returned to the caller.
    pub fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        if let Some(ret) = self.get() {
            return Ok(ret);
        }
        self.try_init(f)
    }

    fn get(&self) -> Option<&T> {
        if self.state.load(Ordering::Acquire) == INITIALIZED {
            Some(unsafe { (*self.val.get()).assume_init_ref() })
        } else {
            None
        }
    }

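    // Slow path: claim the `INITIALIZING` state, run the closure, and publish
    // the result. An initialization already in flight panics rather than
    // blocking, while an already-initialized cell just returns its value.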
    #[cold]
    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        match self.state.compare_exchange(
            UNINITIALIZED,
            INITIALIZING,
            Ordering::Acquire,
            Ordering::Acquire,
        ) {
            Ok(UNINITIALIZED) => match f() {
                Ok(val) => {
                    let ret = unsafe { &*(*self.val.get()).write(val) };
                    let prev = self.state.swap(INITIALIZED, Ordering::Release);
                    assert_eq!(prev, INITIALIZING);
                    Ok(ret)
                }
                Err(e) => match self.state.swap(UNINITIALIZED, Ordering::Release) {
                    INITIALIZING => Err(e),
                    _ => unreachable!(),
                },
            },
            Err(INITIALIZING) => panic!("concurrent initialization only allowed with `std`"),
            Err(INITIALIZED) => Ok(self.get().unwrap()),
            _ => unreachable!(),
        }
    }
}

impl<T> Default for OnceLock<T> {
    fn default() -> OnceLock<T> {
        OnceLock::new()
    }
}

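/// A reader-writer lock for `no_std` which panics on contention instead of
/// blocking.
///
/// `state` is 0 when unlocked, `u32::MAX` when locked for writing, and
/// otherwise counts the number of active readers.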
#[derive(Debug, Default)]
pub struct RwLock<T> {
    val: UnsafeCell<T>,
    state: AtomicU32,
}

unsafe impl<T: Send> Send for RwLock<T> {}
unsafe impl<T: Send + Sync> Sync for RwLock<T> {}

impl<T> RwLock<T> {
    pub const fn new(val: T) -> RwLock<T> {
        RwLock {
            val: UnsafeCell::new(val),
            state: AtomicU32::new(0),
        }
    }

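    /// Acquires a shared read lock by incrementing the reader count.
    ///
    /// Panics if the lock is currently held for writing, since blocking isn't
    /// possible without `std`.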
    pub fn read(&self) -> impl Deref<Target = T> + '_ {
        const READER_LIMIT: u32 = u32::MAX / 2;
        match self
            .state
            .fetch_update(Ordering::Acquire, Ordering::Acquire, |x| match x {
                u32::MAX => None,
                n => {
                    let next = n + 1;
                    if next < READER_LIMIT {
                        Some(next)
                    } else {
                        None
                    }
                }
            }) {
            Ok(_) => RwLockReadGuard { lock: self },
            Err(_) => panic!(
                "concurrent read request while locked for writing, must use `std` to avoid panic"
            ),
        }
    }

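    /// Acquires the exclusive write lock by moving the state from 0 (unlocked)
    /// to `u32::MAX`.
    ///
    /// Panics if any readers or another writer are active.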
    pub fn write(&self) -> impl DerefMut<Target = T> + '_ {
        match self
            .state
            .compare_exchange(0, u32::MAX, Ordering::Acquire, Ordering::Relaxed)
        {
            Ok(0) => RwLockWriteGuard { lock: self },
            _ => panic!("concurrent write request, must use `std` to avoid panicking"),
        }
    }
}

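// RAII guards returned by `RwLock::read` and `RwLock::write`; dropping a guard
// releases the lock by updating `RwLock::state`.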
struct RwLockReadGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<T> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        unsafe { &*self.lock.val.get() }
    }
}

impl<T> Drop for RwLockReadGuard<'_, T> {
    fn drop(&mut self) {
        self.lock.state.fetch_sub(1, Ordering::Release);
    }
}

struct RwLockWriteGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<T> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        unsafe { &*self.lock.val.get() }
    }
}

impl<T> DerefMut for RwLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.lock.val.get() }
    }
}

impl<T> Drop for RwLockWriteGuard<'_, T> {
    fn drop(&mut self) {
        match self.lock.state.swap(0, Ordering::Release) {
            u32::MAX => {}
            _ => unreachable!(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke_once_lock() {
        let lock = OnceLock::new();
        assert!(lock.get().is_none());
        assert_eq!(*lock.get_or_init(|| 1), 1);
        assert_eq!(*lock.get_or_init(|| 2), 1);
        assert_eq!(*lock.get_or_init(|| 3), 1);
        assert_eq!(lock.get_or_try_init(|| Ok::<_, ()>(3)), Ok(&1));

        let lock = OnceLock::new();
        assert_eq!(lock.get_or_try_init(|| Ok::<_, ()>(3)), Ok(&3));
        assert_eq!(*lock.get_or_init(|| 1), 3);

        let lock = OnceLock::new();
        assert_eq!(lock.get_or_try_init(|| Err(())), Err(()));
        assert_eq!(*lock.get_or_init(|| 1), 1);
    }

    #[test]
    fn smoke_rwlock() {
        let lock = RwLock::new(1);
        assert_eq!(*lock.read(), 1);

        let a = lock.read();
        let b = lock.read();
        assert_eq!(*a, 1);
        assert_eq!(*b, 1);
        drop((a, b));

        assert_eq!(*lock.write(), 1);

        *lock.write() = 4;
        assert_eq!(*lock.read(), 4);
        assert_eq!(*lock.write(), 4);

        let a = lock.read();
        let b = lock.read();
        assert_eq!(*a, 4);
        assert_eq!(*b, 4);
        drop((a, b));
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_read_then_write() {
        let lock = RwLock::new(1);
        let _a = lock.read();
        let _b = lock.write();
    }

    #[test]
    #[should_panic(expected = "concurrent read request")]
    fn rwlock_panic_write_then_read() {
        let lock = RwLock::new(1);
        let _a = lock.write();
        let _b = lock.read();
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_write_then_write() {
        let lock = RwLock::new(1);
        let _a = lock.write();
        let _b = lock.write();
    }
}