1#![cfg_attr(
18 all(feature = "std", not(test)),
19 expect(dead_code, reason = "not used, but typechecked")
20)]
21
22use core::cell::UnsafeCell;
23use core::mem::MaybeUninit;
24use core::ops::{Deref, DerefMut};
25use core::sync::atomic::{AtomicU8, AtomicU32, Ordering};
26
/// A write-once cell modelled on `std::sync::OnceLock`, for builds without
/// `std`.
///
/// The cell never blocks: if an initialization is already in progress when
/// another one is requested, the second request panics instead of waiting.
///
/// `state` is one of [`UNINITIALIZED`], [`INITIALIZING`] or [`INITIALIZED`];
/// `val` holds a live `T` only while `state` is `INITIALIZED`.
pub struct OnceLock<T> {
    val: UnsafeCell<MaybeUninit<T>>,
    state: AtomicU8,
}

// SAFETY: moving the cell to another thread moves the (possibly present) `T`
// with it, which is fine exactly when `T: Send`.
unsafe impl<T: Send> Send for OnceLock<T> {}
// SAFETY: sharing the cell hands out `&T` to several threads (`T: Sync`), and
// a value written by one thread may be dropped by whichever thread owns the
// cell (`T: Send`). These are the same bounds `std::sync::OnceLock` uses.
unsafe impl<T: Send + Sync> Sync for OnceLock<T> {}

/// No value has been stored and nobody is currently initializing.
const UNINITIALIZED: u8 = 0;
/// An initializer closure is currently running.
const INITIALIZING: u8 = 1;
/// `val` contains a fully written `T`.
const INITIALIZED: u8 = 2;

impl<T> OnceLock<T> {
    /// Creates an empty cell.
    pub const fn new() -> OnceLock<T> {
        OnceLock {
            state: AtomicU8::new(UNINITIALIZED),
            val: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// Returns the stored value, running `f` to create it if the cell is
    /// still empty.
    ///
    /// # Panics
    ///
    /// Panics if another initialization of this cell is already in progress
    /// (including a re-entrant call made from inside an initializer). If `f`
    /// itself panics, the panic propagates and the cell stays empty.
    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
        if let Some(ret) = self.get() {
            return ret;
        }
        self.try_init::<()>(|| Ok(f())).unwrap()
    }

    /// Fallible variant of [`OnceLock::get_or_init`].
    ///
    /// When `f` returns `Err`, the error is handed back to the caller and the
    /// cell is left empty so that a later call may try again.
    ///
    /// # Panics
    ///
    /// Panics if another initialization of this cell is already in progress.
    pub fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        if let Some(ret) = self.get() {
            return Ok(ret);
        }
        self.try_init(f)
    }

    /// Returns the stored value, or `None` if the cell has not been
    /// initialized yet.
    fn get(&self) -> Option<&T> {
        if self.state.load(Ordering::Acquire) == INITIALIZED {
            // SAFETY: `INITIALIZED` is only stored (with `Release`) after the
            // value has been written, and the value is never written again.
            Some(unsafe { (*self.val.get()).assume_init_ref() })
        } else {
            None
        }
    }

    /// Slow path: claims the cell, runs `f`, and publishes the result.
    #[cold]
    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        /// Puts the cell back into the `UNINITIALIZED` state if the
        /// initializer unwinds, so a panic does not wedge the cell in
        /// `INITIALIZING` forever.
        struct ResetOnUnwind<'a>(&'a AtomicU8);

        impl Drop for ResetOnUnwind<'_> {
            fn drop(&mut self) {
                self.0.store(UNINITIALIZED, Ordering::Release);
            }
        }

        match self.state.compare_exchange(
            UNINITIALIZED,
            INITIALIZING,
            Ordering::Acquire,
            Ordering::Acquire,
        ) {
            Ok(UNINITIALIZED) => {
                let unwind_guard = ResetOnUnwind(&self.state);
                let outcome = f();
                // `f` returned normally; the state transition is handled
                // explicitly below, so disarm the guard.
                core::mem::forget(unwind_guard);

                match outcome {
                    Ok(val) => {
                        // SAFETY: this call moved the state to `INITIALIZING`,
                        // so nothing else reads or writes `val` until
                        // `INITIALIZED` is published below.
                        let ret = unsafe { &*(*self.val.get()).write(val) };
                        let prev = self.state.swap(INITIALIZED, Ordering::Release);
                        assert_eq!(prev, INITIALIZING);
                        Ok(ret)
                    }
                    Err(e) => match self.state.swap(UNINITIALIZED, Ordering::Release) {
                        INITIALIZING => Err(e),
                        _ => unreachable!(),
                    },
                }
            }
            Err(INITIALIZING) => panic!("concurrent initialization only allowed with `std`"),
            Err(INITIALIZED) => Ok(self.get().unwrap()),
            _ => unreachable!(),
        }
    }
}

impl<T> Default for OnceLock<T> {
    /// Equivalent to [`OnceLock::new`].
    fn default() -> OnceLock<T> {
        OnceLock::new()
    }
}

impl<T> Drop for OnceLock<T> {
    /// Drops the stored value, if one was ever written.
    fn drop(&mut self) {
        if *self.state.get_mut() == INITIALIZED {
            // SAFETY: `INITIALIZED` means `val` holds a live `T`, and
            // `&mut self` guarantees no outstanding references to it.
            unsafe { self.val.get_mut().assume_init_drop() };
        }
    }
}
101
/// A reader-writer lock that never blocks: a request that cannot be granted
/// immediately panics instead of waiting.
///
/// `state` is `0` when unlocked, `u32::MAX` while a writer holds the lock,
/// and otherwise the number of outstanding read guards.
#[derive(Debug, Default)]
pub struct RwLock<T> {
    val: UnsafeCell<T>,
    state: AtomicU32,
}

unsafe impl<T: Send> Send for RwLock<T> {}
unsafe impl<T: Send + Sync> Sync for RwLock<T> {}

impl<T> RwLock<T> {
    /// Wraps `val` in a new, unlocked lock.
    pub const fn new(val: T) -> RwLock<T> {
        RwLock {
            val: UnsafeCell::new(val),
            state: AtomicU32::new(0),
        }
    }

    /// Acquires shared access, released when the returned guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the lock is currently held for writing, or if taking one
    /// more reader would reach the internal reader limit.
    pub fn read(&self) -> impl Deref<Target = T> + '_ {
        const READER_LIMIT: u32 = u32::MAX / 2;

        // Computes the next state for one additional reader, or `None` when
        // the request has to be refused.
        let add_reader = |current: u32| -> Option<u32> {
            if current == u32::MAX {
                return None;
            }
            let bumped = current + 1;
            (bumped < READER_LIMIT).then_some(bumped)
        };

        let outcome = self
            .state
            .fetch_update(Ordering::Acquire, Ordering::Acquire, add_reader);
        if outcome.is_err() {
            panic!(
                "concurrent read request while locked for writing, must use `std` to avoid panic"
            );
        }
        RwLockReadGuard { lock: self }
    }

    /// Acquires exclusive access, released when the returned guard is
    /// dropped.
    ///
    /// # Panics
    ///
    /// Panics if any read or write guard is currently outstanding.
    pub fn write(&self) -> impl DerefMut<Target = T> + '_ {
        let claimed = self
            .state
            .compare_exchange(0, u32::MAX, Ordering::Acquire, Ordering::Relaxed);
        if !matches!(claimed, Ok(0)) {
            panic!("concurrent write request, must use `std` to avoid panicking");
        }
        RwLockWriteGuard { lock: self }
    }
}

/// Shared-access guard; dropping it decrements the reader count.
struct RwLockReadGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<'a, T> Deref for RwLockReadGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let shared: *const T = self.lock.val.get();
        // SAFETY: while a read guard exists `state` is a non-zero reader
        // count, so `write` refuses to hand out a mutable reference.
        unsafe { &*shared }
    }
}

impl<'a, T> Drop for RwLockReadGuard<'a, T> {
    fn drop(&mut self) {
        self.lock.state.fetch_sub(1, Ordering::Release);
    }
}

/// Exclusive-access guard; dropping it returns the lock to the unlocked
/// state.
struct RwLockWriteGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<'a, T> Deref for RwLockWriteGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let exclusive: *const T = self.lock.val.get();
        // SAFETY: `state` is `u32::MAX` for as long as this guard lives, so
        // no other guard can be created.
        unsafe { &*exclusive }
    }
}

impl<'a, T> DerefMut for RwLockWriteGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        let exclusive: *mut T = self.lock.val.get();
        // SAFETY: same invariant as `deref`, and `&mut self` prevents
        // aliasing through this guard.
        unsafe { &mut *exclusive }
    }
}

impl<'a, T> Drop for RwLockWriteGuard<'a, T> {
    fn drop(&mut self) {
        let previous = self.lock.state.swap(0, Ordering::Release);
        if previous != u32::MAX {
            unreachable!();
        }
    }
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers the single-threaded `OnceLock` paths: first writer wins,
    /// fallible init success, and retry after a failed init.
    #[test]
    fn smoke_once_lock() {
        // Once a value is stored, every later initializer is ignored.
        let cell = OnceLock::new();
        assert!(cell.get().is_none());
        for candidate in 1..=3 {
            assert_eq!(*cell.get_or_init(|| candidate), 1);
        }
        assert_eq!(cell.get_or_try_init::<()>(|| Ok(3)), Ok(&1));

        // A successful fallible init behaves like an infallible one.
        let cell = OnceLock::new();
        assert_eq!(cell.get_or_try_init::<()>(|| Ok(3)), Ok(&3));
        assert_eq!(*cell.get_or_init(|| 1), 3);

        // A failed init leaves the cell empty so the next init succeeds.
        let cell = OnceLock::new();
        assert_eq!(cell.get_or_try_init(|| Err(())), Err(()));
        assert_eq!(*cell.get_or_init(|| 1), 1);
    }

    /// Covers non-overlapping reads and writes plus overlapping reads.
    #[test]
    fn smoke_rwlock() {
        let shared = RwLock::new(1);
        assert_eq!(*shared.read(), 1);

        {
            let first = shared.read();
            let second = shared.read();
            assert_eq!(*first, 1);
            assert_eq!(*second, 1);
        }

        assert_eq!(*shared.write(), 1);

        *shared.write() = 4;
        assert_eq!(*shared.read(), 4);
        assert_eq!(*shared.write(), 4);

        {
            let first = shared.read();
            let second = shared.read();
            assert_eq!(*first, 4);
            assert_eq!(*second, 4);
        }
    }

    /// A writer must be refused while a reader is outstanding.
    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_read_then_write() {
        let shared = RwLock::new(1);
        let _reader = shared.read();
        let _writer = shared.write();
    }

    /// A reader must be refused while a writer is outstanding.
    #[test]
    #[should_panic(expected = "concurrent read request")]
    fn rwlock_panic_write_then_read() {
        let shared = RwLock::new(1);
        let _writer = shared.write();
        let _reader = shared.read();
    }

    /// A second writer must be refused while the first is outstanding.
    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_write_then_write() {
        let shared = RwLock::new(1);
        let _first = shared.write();
        let _second = shared.write();
    }
}