1#![cfg_attr(
18    all(feature = "std", not(test)),
19    expect(dead_code, reason = "not used, but typechecked")
20)]
21
22use core::cell::UnsafeCell;
23use core::mem::MaybeUninit;
24use core::ops::{Deref, DerefMut};
25use core::sync::atomic::{AtomicU8, Ordering};
26
27pub struct OnceLock<T> {
28    val: UnsafeCell<MaybeUninit<T>>,
29    state: AtomicU8,
30    mutex: raw::Mutex,
31}
32
33unsafe impl<T: Send> Send for OnceLock<T> {}
34unsafe impl<T: Sync> Sync for OnceLock<T> {}
35
/// `OnceLock::state`: no value stored and no initializer running.
const UNINITIALIZED: u8 = 0;
/// `OnceLock::state`: an initializer is running under the mutex.
const INITIALIZING: u8 = 1;
/// `OnceLock::state`: the value has been written and may be read.
const INITIALIZED: u8 = 2;
39
40impl<T> OnceLock<T> {
41    pub const fn new() -> OnceLock<T> {
42        OnceLock {
43            state: AtomicU8::new(UNINITIALIZED),
44            val: UnsafeCell::new(MaybeUninit::uninit()),
45            mutex: raw::Mutex::new(),
46        }
47    }
48
49    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
50        if let Some(ret) = self.get() {
51            return ret;
52        }
53        self.try_init::<()>(|| Ok(f())).unwrap()
54    }
55
56    pub fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
57        if let Some(ret) = self.get() {
58            return Ok(ret);
59        }
60        self.try_init(f)
61    }
62
63    fn get(&self) -> Option<&T> {
64        if self.state.load(Ordering::Acquire) == INITIALIZED {
65            Some(unsafe { (*self.val.get()).assume_init_ref() })
67        } else {
68            None
69        }
70    }
71
72    #[cold]
73    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
74        let _guard = OnceLockGuard::new(&self.mutex);
75
76        match self.state.load(Ordering::Acquire) {
78            UNINITIALIZED => {
79                self.state.store(INITIALIZING, Ordering::Release);
80                match f() {
81                    Ok(val) => {
82                        let ret = unsafe { &*(*self.val.get()).write(val) };
84                        self.state.store(INITIALIZED, Ordering::Release);
85                        Ok(ret)
86                    }
87                    Err(e) => {
88                        self.state.store(UNINITIALIZED, Ordering::Release);
89                        Err(e)
90                    }
91                }
92            }
93            INITIALIZED => {
94                Ok(unsafe { (*self.val.get()).assume_init_ref() })
96            }
97            _ => panic!("concurrent initialization"),
98        }
99    }
100}
101
102impl<T> Drop for OnceLock<T> {
103    fn drop(&mut self) {
104        if self.state.load(Ordering::Acquire) == INITIALIZED {
105            unsafe { (*self.val.get()).assume_init_drop() };
107        }
108    }
109}
110
111impl<T> Default for OnceLock<T> {
112    fn default() -> OnceLock<T> {
113        OnceLock::new()
114    }
115}
116
117struct OnceLockGuard<'a> {
118    lock: &'a raw::Mutex,
119}
120
121impl<'a> OnceLockGuard<'a> {
122    fn new(lock: &'a raw::Mutex) -> OnceLockGuard<'a> {
123        lock.lock();
124        OnceLockGuard { lock }
125    }
126}
127
128impl Drop for OnceLockGuard<'_> {
129    fn drop(&mut self) {
130        unsafe {
132            self.lock.unlock();
133        }
134    }
135}
136
137#[derive(Debug)]
138pub struct RwLock<T> {
139    val: UnsafeCell<T>,
140    lock: raw::RwLock,
141}
142
143unsafe impl<T: Send> Send for RwLock<T> {}
144unsafe impl<T: Send + Sync> Sync for RwLock<T> {}
145
146impl<T> RwLock<T> {
147    pub const fn new(val: T) -> RwLock<T> {
148        RwLock {
149            val: UnsafeCell::new(val),
150            lock: raw::RwLock::new(),
151        }
152    }
153
154    pub fn read(&self) -> impl Deref<Target = T> + '_ {
155        self.lock.read();
156        RwLockReadGuard { lock: self }
157    }
158
159    pub fn write(&self) -> impl DerefMut<Target = T> + '_ {
160        self.lock.write();
161        RwLockWriteGuard { lock: self }
162    }
163}
164
165impl<T: Default> Default for RwLock<T> {
166    fn default() -> RwLock<T> {
167        RwLock::new(T::default())
168    }
169}
170
171struct RwLockReadGuard<'a, T> {
172    lock: &'a RwLock<T>,
173}
174
175impl<T> Deref for RwLockReadGuard<'_, T> {
176    type Target = T;
177
178    fn deref(&self) -> &T {
179        unsafe { &*self.lock.val.get() }
181    }
182}
183
184impl<T> Drop for RwLockReadGuard<'_, T> {
185    fn drop(&mut self) {
186        unsafe {
189            self.lock.lock.read_unlock();
190        }
191    }
192}
193
194struct RwLockWriteGuard<'a, T> {
195    lock: &'a RwLock<T>,
196}
197
198impl<T> Deref for RwLockWriteGuard<'_, T> {
199    type Target = T;
200
201    fn deref(&self) -> &T {
202        unsafe { &*self.lock.val.get() }
204    }
205}
206
207impl<T> DerefMut for RwLockWriteGuard<'_, T> {
208    fn deref_mut(&mut self) -> &mut T {
209        unsafe { &mut *self.lock.val.get() }
211    }
212}
213
214impl<T> Drop for RwLockWriteGuard<'_, T> {
215    fn drop(&mut self) {
216        unsafe {
219            self.lock.lock.write_unlock();
220        }
221    }
222}
223
#[cfg(not(has_custom_sync))]
use panic_on_contention as raw;
/// Raw lock primitives for builds without embedder-provided synchronization.
///
/// Instead of blocking, every conflicting acquisition panics.
#[cfg(not(has_custom_sync))]
mod panic_on_contention {
    use core::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// A mutex that panics instead of waiting when it is already held.
    #[derive(Debug)]
    pub struct Mutex {
        /// `true` while the lock is held.
        locked: AtomicBool,
    }

    impl Mutex {
        pub const fn new() -> Mutex {
            Mutex {
                locked: AtomicBool::new(false),
            }
        }

        /// Acquires the lock.
        ///
        /// # Panics
        ///
        /// Panics if the lock is already held.
        pub fn lock(&self) {
            if self.locked.swap(true, Ordering::Acquire) {
                panic!(
                    "concurrent lock request, must use `std` or `custom-sync-primitives` features to avoid panicking"
                );
            }
        }

        /// Releases the lock.
        ///
        /// # Safety
        ///
        /// The caller must currently hold the lock acquired with [`Mutex::lock`].
        pub unsafe fn unlock(&self) {
            self.locked.store(false, Ordering::Release);
        }
    }

    /// `RwLock::state` value meaning "locked for writing".
    const WRITE_LOCKED: u32 = u32::MAX;
    /// Upper bound on the reader count, kept far below `WRITE_LOCKED` so the
    /// count can never be confused with the write sentinel.
    const READER_LIMIT: u32 = u32::MAX / 2;

    /// A reader-writer lock that panics instead of waiting on conflicting access.
    ///
    /// `state` is `0` when unlocked, the number of readers while read-locked,
    /// and `WRITE_LOCKED` while write-locked.
    #[derive(Debug)]
    pub struct RwLock {
        state: AtomicU32,
    }

    impl RwLock {
        pub const fn new() -> RwLock {
            RwLock {
                state: AtomicU32::new(0),
            }
        }

        /// Registers one more reader.
        ///
        /// # Panics
        ///
        /// Panics if the lock is held for writing, or if the reader count would
        /// reach `READER_LIMIT`.
        pub fn read(&self) {
            let acquired =
                self.state
                    .fetch_update(Ordering::Acquire, Ordering::Acquire, |readers| {
                        if readers == WRITE_LOCKED {
                            return None;
                        }
                        // `readers < WRITE_LOCKED` here, so `+ 1` cannot overflow.
                        let next = readers + 1;
                        (next < READER_LIMIT).then_some(next)
                    });
            match acquired {
                Ok(_) => {}
                // `fetch_update` reports the last observed value on failure, which
                // tells a writer conflict apart from reader-count exhaustion.
                Err(WRITE_LOCKED) => panic!(
                    "concurrent read request while locked for writing, must use `std` or `custom-sync-primitives` features to avoid panic"
                ),
                Err(_) => panic!("too many concurrent readers of `RwLock`"),
            }
        }

        /// Unregisters one reader.
        ///
        /// # Safety
        ///
        /// The caller must own a read lock previously acquired with [`RwLock::read`].
        pub unsafe fn read_unlock(&self) {
            self.state.fetch_sub(1, Ordering::Release);
        }

        /// Acquires the lock for writing.
        ///
        /// # Panics
        ///
        /// Panics if the lock is held by any reader or writer.
        pub fn write(&self) {
            let unlocked = self
                .state
                .compare_exchange(0, WRITE_LOCKED, Ordering::Acquire, Ordering::Relaxed)
                .is_ok();
            if !unlocked {
                panic!(
                    "concurrent write request, must use `std` or `custom-sync-primitives` features to avoid panicking"
                );
            }
        }

        /// Releases the write lock.
        ///
        /// # Safety
        ///
        /// The caller must own the write lock acquired with [`RwLock::write`].
        pub unsafe fn write_unlock(&self) {
            match self.state.swap(0, Ordering::Release) {
                WRITE_LOCKED => {}
                _ => unreachable!(),
            }
        }
    }
}
313
#[cfg(has_custom_sync)]
use custom_capi as raw;
/// Raw lock primitives that forward to the embedder-provided
/// `capi::wasmtime_sync_*` functions.
///
/// Each lock owns one `usize` of storage, zero-initialized in `new`. Its
/// address is passed to every C call. NOTE(review): the meaning of that word is
/// defined by the embedder's implementation; confirm against the C API docs.
#[cfg(has_custom_sync)]
mod custom_capi {
    use crate::runtime::vm::capi;
    use core::cell::UnsafeCell;

    /// Mutex whose state lives in embedder-managed storage.
    #[derive(Debug)]
    pub struct Mutex {
        storage: UnsafeCell<usize>,
    }

    impl Mutex {
        pub const fn new() -> Mutex {
            Mutex {
                storage: UnsafeCell::new(0),
            }
        }

        /// Address of this mutex's storage word, as handed to the C API.
        fn ptr(&self) -> *mut usize {
            self.storage.get()
        }

        /// Acquires the lock via `wasmtime_sync_lock_acquire`.
        pub fn lock(&self) {
            let ptr = self.ptr();
            // SAFETY: `ptr` points at this mutex's own storage, which stays
            // valid while `self` is borrowed.
            unsafe {
                capi::wasmtime_sync_lock_acquire(ptr);
            }
        }

        /// Releases the lock via `wasmtime_sync_lock_release`.
        ///
        /// # Safety
        ///
        /// The caller must currently hold the lock acquired with [`Mutex::lock`].
        pub unsafe fn unlock(&self) {
            let ptr = self.ptr();
            // SAFETY: upheld by this function's caller; `ptr` is this mutex's storage.
            unsafe {
                capi::wasmtime_sync_lock_release(ptr);
            }
        }
    }

    impl Drop for Mutex {
        fn drop(&mut self) {
            let ptr = self.ptr();
            // SAFETY: `&mut self` guarantees no outstanding borrow is using the
            // lock, so its storage can be handed back to the embedder.
            unsafe {
                capi::wasmtime_sync_lock_free(ptr);
            }
        }
    }

    /// Reader-writer lock whose state lives in embedder-managed storage.
    #[derive(Debug)]
    pub struct RwLock {
        storage: UnsafeCell<usize>,
    }

    impl RwLock {
        pub const fn new() -> RwLock {
            RwLock {
                storage: UnsafeCell::new(0),
            }
        }

        /// Address of this lock's storage word, as handed to the C API.
        fn ptr(&self) -> *mut usize {
            self.storage.get()
        }

        /// Acquires shared access via `wasmtime_sync_rwlock_read`.
        pub fn read(&self) {
            let ptr = self.ptr();
            // SAFETY: `ptr` points at this lock's own storage, valid while `self` is borrowed.
            unsafe {
                capi::wasmtime_sync_rwlock_read(ptr);
            }
        }

        /// Releases shared access via `wasmtime_sync_rwlock_read_release`.
        ///
        /// # Safety
        ///
        /// The caller must hold a read lock acquired with [`RwLock::read`].
        pub unsafe fn read_unlock(&self) {
            let ptr = self.ptr();
            // SAFETY: upheld by this function's caller; `ptr` is this lock's storage.
            unsafe {
                capi::wasmtime_sync_rwlock_read_release(ptr);
            }
        }

        /// Acquires exclusive access via `wasmtime_sync_rwlock_write`.
        pub fn write(&self) {
            let ptr = self.ptr();
            // SAFETY: `ptr` points at this lock's own storage, valid while `self` is borrowed.
            unsafe {
                capi::wasmtime_sync_rwlock_write(ptr);
            }
        }

        /// Releases exclusive access via `wasmtime_sync_rwlock_write_release`.
        ///
        /// # Safety
        ///
        /// The caller must hold the write lock acquired with [`RwLock::write`].
        pub unsafe fn write_unlock(&self) {
            let ptr = self.ptr();
            // SAFETY: upheld by this function's caller; `ptr` is this lock's storage.
            unsafe {
                capi::wasmtime_sync_rwlock_write_release(ptr);
            }
        }
    }

    impl Drop for RwLock {
        fn drop(&mut self) {
            let ptr = self.ptr();
            // SAFETY: `&mut self` guarantees no outstanding borrow is using the
            // lock, so its storage can be handed back to the embedder.
            unsafe {
                capi::wasmtime_sync_rwlock_free(ptr);
            }
        }
    }
}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke_once_lock() {
        // The first initializer wins; later ones are ignored.
        let cell = OnceLock::new();
        assert!(cell.get().is_none());
        for candidate in [1, 2, 3] {
            assert_eq!(*cell.get_or_init(|| candidate), 1);
        }
        assert_eq!(cell.get_or_try_init::<()>(|| Ok(3)), Ok(&1));

        // A fallible initializer can seed the cell as well.
        let cell = OnceLock::new();
        assert_eq!(cell.get_or_try_init::<()>(|| Ok(3)), Ok(&3));
        assert_eq!(*cell.get_or_init(|| 1), 3);

        // A failing initializer leaves the cell empty for a later attempt.
        let cell = OnceLock::new();
        assert_eq!(cell.get_or_try_init(|| Err(())), Err(()));
        assert_eq!(*cell.get_or_init(|| 1), 1);
    }

    /// Holds two read guards at once and checks both observe `expected`.
    fn assert_two_readers_see(lock: &RwLock<i32>, expected: i32) {
        let first = lock.read();
        let second = lock.read();
        assert_eq!(*first, expected);
        assert_eq!(*second, expected);
        drop((first, second));
    }

    #[test]
    fn smoke_rwlock() {
        let lock = RwLock::new(1);
        assert_eq!(*lock.read(), 1);
        assert_two_readers_see(&lock, 1);

        assert_eq!(*lock.write(), 1);

        *lock.write() = 4;
        assert_eq!(*lock.read(), 4);
        assert_eq!(*lock.write(), 4);
        assert_two_readers_see(&lock, 4);
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_read_then_write() {
        let lock = RwLock::new(1);
        let _reader = lock.read();
        let _writer = lock.write();
    }

    #[test]
    #[should_panic(expected = "concurrent read request")]
    fn rwlock_panic_write_then_read() {
        let lock = RwLock::new(1);
        let _writer = lock.write();
        let _reader = lock.read();
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_write_then_write() {
        let lock = RwLock::new(1);
        let _first_writer = lock.write();
        let _second_writer = lock.write();
    }
}