wasmtime/
sync_nostd.rs

1//! Synchronization primitives for Wasmtime for `no_std`.
2//!
//! These primitives are intended for use in `no_std` contexts and are not as
//! full-featured as their `std` brethren. Namely these panic and/or return an
5//! error on contention. This serves to continue to be correct in the face of
6//! actual multiple threads, but if a system actually has multiple threads then
7//! something will need to change in the Wasmtime crate to enable the external
8//! system to perform necessary synchronization.
9//!
10//! In the future if these primitives are not suitable we can switch to putting
11//! relevant functions in the `capi.rs` module where we basically require
12//! embedders to implement them instead of doing it ourselves here. It's unclear
13//! if this will be necessary, so this is the chosen starting point.
14//!
15//! See a brief overview of this module in `sync_std.rs` as well.
16
17use core::cell::UnsafeCell;
18use core::mem::MaybeUninit;
19use core::ops::{Deref, DerefMut};
20use core::sync::atomic::{AtomicU32, AtomicU8, Ordering};
21
/// A `no_std` stand-in for `std::sync::OnceLock`: a cell that is written at
/// most once and afterwards hands out shared references to its value.
///
/// Unlike the `std` version this never blocks; an initialization attempt that
/// races with another one panics instead.
pub struct OnceLock<T> {
    /// Storage for the value; only initialized once `state` is `INITIALIZED`.
    val: UnsafeCell<MaybeUninit<T>>,
    /// One of `UNINITIALIZED`, `INITIALIZING`, or `INITIALIZED`.
    state: AtomicU8,
}

// SAFETY: the lock owns its `T`, so moving the lock to another thread moves
// the `T` along with it.
unsafe impl<T: Send> Send for OnceLock<T> {}

// SAFETY: sharing `&OnceLock<T>` across threads hands out `&T` on every thread
// (requires `T: Sync`). It also lets any thread create the `T` that the owning
// thread later drops, which is a cross-thread move (requires `T: Send`). These
// are the same bounds `std::sync::OnceLock` uses.
unsafe impl<T: Sync + Send> Sync for OnceLock<T> {}

impl<T> Drop for OnceLock<T> {
    fn drop(&mut self) {
        // `MaybeUninit` never drops its contents by itself, so release the
        // value here if initialization completed. `&mut self` guarantees no
        // other references exist, so non-atomic access to the state is fine.
        if *self.state.get_mut() == INITIALIZED {
            // SAFETY: `INITIALIZED` is only stored after `val` has been
            // written, and nothing can observe `val` after this drop.
            unsafe { self.val.get_mut().assume_init_drop() }
        }
    }
}

const UNINITIALIZED: u8 = 0;
const INITIALIZING: u8 = 1;
const INITIALIZED: u8 = 2;
33
34impl<T> OnceLock<T> {
35    pub const fn new() -> OnceLock<T> {
36        OnceLock {
37            state: AtomicU8::new(UNINITIALIZED),
38            val: UnsafeCell::new(MaybeUninit::uninit()),
39        }
40    }
41
42    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
43        if let Some(ret) = self.get() {
44            return ret;
45        }
46        self.try_init::<()>(|| Ok(f())).unwrap()
47    }
48
49    pub fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
50        if let Some(ret) = self.get() {
51            return Ok(ret);
52        }
53        self.try_init(f)
54    }
55
56    fn get(&self) -> Option<&T> {
57        if self.state.load(Ordering::Acquire) == INITIALIZED {
58            Some(unsafe { (*self.val.get()).assume_init_ref() })
59        } else {
60            None
61        }
62    }
63
64    #[cold]
65    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
66        match self.state.compare_exchange(
67            UNINITIALIZED,
68            INITIALIZING,
69            Ordering::Acquire,
70            Ordering::Acquire,
71        ) {
72            Ok(UNINITIALIZED) => match f() {
73                Ok(val) => {
74                    let ret = unsafe { &*(*self.val.get()).write(val) };
75                    let prev = self.state.swap(INITIALIZED, Ordering::Release);
76                    assert_eq!(prev, INITIALIZING);
77                    Ok(ret)
78                }
79                Err(e) => match self.state.swap(UNINITIALIZED, Ordering::Release) {
80                    INITIALIZING => Err(e),
81                    _ => unreachable!(),
82                },
83            },
84            Err(INITIALIZING) => panic!("concurrent initialization only allowed with `std`"),
85            Err(INITIALIZED) => Ok(self.get().unwrap()),
86            _ => unreachable!(),
87        }
88    }
89}
90
91impl<T> Default for OnceLock<T> {
92    fn default() -> OnceLock<T> {
93        OnceLock::new()
94    }
95}
96
/// A `no_std` reader-writer lock that panics on contention instead of blocking.
///
/// `state` is `0` while unlocked, `u32::MAX` while held by a writer, and
/// otherwise the number of active readers.
#[derive(Default, Debug)]
pub struct RwLock<T> {
    val: UnsafeCell<T>,
    state: AtomicU32,
}

// SAFETY: the lock owns its `T`, so sending the lock sends the `T`.
unsafe impl<T> Send for RwLock<T> where T: Send {}

// SAFETY: a shared `&RwLock<T>` can yield `&T` on many threads (needs `Sync`)
// and `&mut T` on any thread through a write guard (needs `Send`).
unsafe impl<T> Sync for RwLock<T> where T: Send + Sync {}
105
106impl<T> RwLock<T> {
107    pub const fn new(val: T) -> RwLock<T> {
108        RwLock {
109            val: UnsafeCell::new(val),
110            state: AtomicU32::new(0),
111        }
112    }
113
114    pub fn read(&self) -> impl Deref<Target = T> + '_ {
115        const READER_LIMIT: u32 = u32::MAX / 2;
116        match self
117            .state
118            .fetch_update(Ordering::Acquire, Ordering::Acquire, |x| match x {
119                u32::MAX => None,
120                n => {
121                    let next = n + 1;
122                    if next < READER_LIMIT {
123                        Some(next)
124                    } else {
125                        None
126                    }
127                }
128            }) {
129            Ok(_) => RwLockReadGuard { lock: self },
130            Err(_) => panic!(
131                "concurrent read request while locked for writing, must use `std` to avoid panic"
132            ),
133        }
134    }
135
136    pub fn write(&self) -> impl DerefMut<Target = T> + '_ {
137        match self
138            .state
139            .compare_exchange(0, u32::MAX, Ordering::Acquire, Ordering::Relaxed)
140        {
141            Ok(0) => RwLockWriteGuard { lock: self },
142            _ => panic!("concurrent write request, must use `std` to avoid panicking"),
143        }
144    }
145}
146
147struct RwLockReadGuard<'a, T> {
148    lock: &'a RwLock<T>,
149}
150
151impl<T> Deref for RwLockReadGuard<'_, T> {
152    type Target = T;
153
154    fn deref(&self) -> &T {
155        unsafe { &*self.lock.val.get() }
156    }
157}
158
159impl<T> Drop for RwLockReadGuard<'_, T> {
160    fn drop(&mut self) {
161        self.lock.state.fetch_sub(1, Ordering::Release);
162    }
163}
164
165struct RwLockWriteGuard<'a, T> {
166    lock: &'a RwLock<T>,
167}
168
169impl<T> Deref for RwLockWriteGuard<'_, T> {
170    type Target = T;
171
172    fn deref(&self) -> &T {
173        unsafe { &*self.lock.val.get() }
174    }
175}
176
177impl<T> DerefMut for RwLockWriteGuard<'_, T> {
178    fn deref_mut(&mut self) -> &mut T {
179        unsafe { &mut *self.lock.val.get() }
180    }
181}
182
183impl<T> Drop for RwLockWriteGuard<'_, T> {
184    fn drop(&mut self) {
185        match self.lock.state.swap(0, Ordering::Release) {
186            u32::MAX => {}
187            _ => unreachable!(),
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke_once_lock() {
        // The first initializer wins; later ones are ignored.
        let cell = OnceLock::new();
        assert!(cell.get().is_none());
        for candidate in 1..=3 {
            assert_eq!(*cell.get_or_init(|| candidate), 1);
        }
        assert_eq!(cell.get_or_try_init::<()>(|| Ok(3)), Ok(&1));

        // A fallible initializer can perform the first initialization.
        let cell = OnceLock::new();
        assert_eq!(cell.get_or_try_init::<()>(|| Ok(3)), Ok(&3));
        assert_eq!(*cell.get_or_init(|| 1), 3);

        // A failed initializer leaves the cell empty for the next attempt.
        let cell = OnceLock::new();
        assert_eq!(cell.get_or_try_init(|| Err(())), Err(()));
        assert_eq!(*cell.get_or_init(|| 1), 1);
    }

    #[test]
    fn smoke_rwlock() {
        let lock = RwLock::new(1);
        assert_eq!(*lock.read(), 1);

        {
            let (first, second) = (lock.read(), lock.read());
            assert_eq!((*first, *second), (1, 1));
        }

        assert_eq!(*lock.write(), 1);

        *lock.write() = 4;
        assert_eq!(*lock.read(), 4);
        assert_eq!(*lock.write(), 4);

        {
            let (first, second) = (lock.read(), lock.read());
            assert_eq!((*first, *second), (4, 4));
        }
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_read_then_write() {
        let lock = RwLock::new(1);
        let _reader = lock.read();
        let _writer = lock.write();
    }

    #[test]
    #[should_panic(expected = "concurrent read request")]
    fn rwlock_panic_write_then_read() {
        let lock = RwLock::new(1);
        let _writer = lock.write();
        let _reader = lock.read();
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_write_then_write() {
        let lock = RwLock::new(1);
        let _first = lock.write();
        let _second = lock.write();
    }
}