use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};
use core::sync::atomic::{AtomicU32, AtomicU8, Ordering};

/// A cell that can be initialized at most once, usable without `std`.
/// The stored value is never dropped: this type has no `Drop` impl.
pub struct OnceLock<T> {
    val: UnsafeCell<MaybeUninit<T>>,
    state: AtomicU8,
}

unsafe impl<T: Send> Send for OnceLock<T> {}
// `T: Send` is also required (as for `std::sync::OnceLock`): a shared
// `&OnceLock<T>` lets another thread move a `T` into the cell.
unsafe impl<T: Send + Sync> Sync for OnceLock<T> {}

// Initialization state machine kept in `state`.
const UNINITIALIZED: u8 = 0;
const INITIALIZING: u8 = 1;
const INITIALIZED: u8 = 2;

impl<T> OnceLock<T> {
    pub const fn new() -> OnceLock<T> {
        OnceLock {
            state: AtomicU8::new(UNINITIALIZED),
            val: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// Returns the value, running `f` to initialize it if it is not yet set.
    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
        if let Some(ret) = self.get() {
            return ret;
        }
        self.try_init::<()>(|| Ok(f())).unwrap()
    }

    /// Returns the value, running `f` to initialize it if it is not yet set.
    /// An `Err` from `f` leaves the lock uninitialized.
    pub fn get_or_try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        if let Some(ret) = self.get() {
            return Ok(ret);
        }
        self.try_init(f)
    }

    fn get(&self) -> Option<&T> {
        if self.state.load(Ordering::Acquire) == INITIALIZED {
            // SAFETY: `INITIALIZED` is only stored (with `Release`) after
            // `val` has been written.
            Some(unsafe { (*self.val.get()).assume_init_ref() })
        } else {
            None
        }
    }

    #[cold]
    fn try_init<E>(&self, f: impl FnOnce() -> Result<T, E>) -> Result<&T, E> {
        match self.state.compare_exchange(
            UNINITIALIZED,
            INITIALIZING,
            Ordering::Acquire,
            Ordering::Acquire,
        ) {
            // We won the race and are responsible for running `f`.
            Ok(UNINITIALIZED) => match f() {
                Ok(val) => {
                    let ret = unsafe { &*(*self.val.get()).write(val) };
                    let prev = self.state.swap(INITIALIZED, Ordering::Release);
                    assert_eq!(prev, INITIALIZING);
                    Ok(ret)
                }
                // Initialization failed; roll back so a later call can retry.
                Err(e) => match self.state.swap(UNINITIALIZED, Ordering::Release) {
                    INITIALIZING => Err(e),
                    _ => unreachable!(),
                },
            },
            // Without `std` there is no way to block until the other
            // initializer finishes, so this is a panic instead.
            Err(INITIALIZING) => panic!("concurrent initialization only allowed with `std`"),
            Err(INITIALIZED) => Ok(self.get().unwrap()),
            _ => unreachable!(),
        }
    }
}

impl<T> Default for OnceLock<T> {
    fn default() -> OnceLock<T> {
        OnceLock::new()
    }
}

/// A minimal reader-writer lock, usable without `std`. Lock contention is
/// a panic rather than a block: `read` panics while a writer is active and
/// `write` panics unless the lock is completely free.
#[derive(Debug, Default)]
pub struct RwLock<T> {
    val: UnsafeCell<T>,
    state: AtomicU32,
}

unsafe impl<T: Send> Send for RwLock<T> {}
unsafe impl<T: Send + Sync> Sync for RwLock<T> {}

impl<T> RwLock<T> {
    pub const fn new(val: T) -> RwLock<T> {
        RwLock {
            val: UnsafeCell::new(val),
            state: AtomicU32::new(0),
        }
    }

    /// Acquires a read guard; panics if the lock is held for writing.
    pub fn read(&self) -> impl Deref<Target = T> + '_ {
        // `state` is the number of active readers, or `u32::MAX` while a
        // writer is active. Readers are capped well below `u32::MAX` so the
        // two encodings can never collide.
        const READER_LIMIT: u32 = u32::MAX / 2;
        match self
            .state
            .fetch_update(Ordering::Acquire, Ordering::Acquire, |x| match x {
                u32::MAX => None,
                n => {
                    let next = n + 1;
                    if next < READER_LIMIT {
                        Some(next)
                    } else {
                        None
                    }
                }
            }) {
            Ok(_) => RwLockReadGuard { lock: self },
            Err(_) => panic!(
                "concurrent read request while locked for writing, must use `std` to avoid panic"
            ),
        }
    }

    /// Acquires a write guard; panics if any reader or writer is active.
    pub fn write(&self) -> impl DerefMut<Target = T> + '_ {
        match self
            .state
            .compare_exchange(0, u32::MAX, Ordering::Acquire, Ordering::Relaxed)
        {
            Ok(0) => RwLockWriteGuard { lock: self },
            _ => panic!("concurrent write request, must use `std` to avoid panicking"),
        }
    }
}

struct RwLockReadGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<T> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        unsafe { &*self.lock.val.get() }
    }
}

impl<T> Drop for RwLockReadGuard<'_, T> {
    fn drop(&mut self) {
        // Release one reader.
        self.lock.state.fetch_sub(1, Ordering::Release);
    }
}

struct RwLockWriteGuard<'a, T> {
    lock: &'a RwLock<T>,
}

impl<T> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        unsafe { &*self.lock.val.get() }
    }
}

impl<T> DerefMut for RwLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        unsafe { &mut *self.lock.val.get() }
    }
}

impl<T> Drop for RwLockWriteGuard<'_, T> {
    fn drop(&mut self) {
        // Release the write lock; the state must still be the writer marker.
        match self.lock.state.swap(0, Ordering::Release) {
            u32::MAX => {}
            _ => unreachable!(),
        }
    }
}
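
// A hedged sketch, not part of the original API: compile-time checks that
// the manual `unsafe impl`s above keep both types `Send + Sync` for an
// ordinary payload. If the bounds regress, this function stops compiling.
#[allow(dead_code)]
fn _assert_send_sync() {
    fn check<T: Send + Sync>() {}
    check::<OnceLock<u8>>();
    check::<RwLock<u8>>();
}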

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn smoke_once_lock() {
        let lock = OnceLock::new();
        assert!(lock.get().is_none());
        assert_eq!(*lock.get_or_init(|| 1), 1);
        assert_eq!(*lock.get_or_init(|| 2), 1);
        assert_eq!(*lock.get_or_init(|| 3), 1);
        assert_eq!(lock.get_or_try_init::<()>(|| Ok(3)), Ok(&1));

        let lock = OnceLock::new();
        assert_eq!(lock.get_or_try_init::<()>(|| Ok(3)), Ok(&3));
        assert_eq!(*lock.get_or_init(|| 1), 3);

        let lock = OnceLock::new();
        assert_eq!(lock.get_or_try_init(|| Err(())), Err(()));
        assert_eq!(*lock.get_or_init(|| 1), 1);
    }
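
    // A hedged sketch of the reentrancy behavior: without `std` the lock
    // cannot block on an in-progress initialization, so `get_or_init`
    // called from its own initializer panics instead of deadlocking.
    #[test]
    #[should_panic(expected = "concurrent initialization")]
    fn once_lock_panic_reentrant_init() {
        let lock = OnceLock::new();
        lock.get_or_init(|| *lock.get_or_init(|| 1));
    }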

    #[test]
    fn smoke_rwlock() {
        let lock = RwLock::new(1);
        assert_eq!(*lock.read(), 1);

        let a = lock.read();
        let b = lock.read();
        assert_eq!(*a, 1);
        assert_eq!(*b, 1);
        drop((a, b));

        assert_eq!(*lock.write(), 1);

        *lock.write() = 4;
        assert_eq!(*lock.read(), 4);
        assert_eq!(*lock.write(), 4);

        let a = lock.read();
        let b = lock.read();
        assert_eq!(*a, 4);
        assert_eq!(*b, 4);
        drop((a, b));
    }
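
    // A hedged extra check: any number of simultaneous readers is allowed;
    // each guard decrements the reader count on drop, after which a writer
    // can lock again.
    #[test]
    fn rwlock_many_readers() {
        let lock = RwLock::new(7);
        let guards = [lock.read(), lock.read(), lock.read()];
        for g in &guards {
            assert_eq!(**g, 7);
        }
        drop(guards);
        *lock.write() += 1;
        assert_eq!(*lock.read(), 8);
    }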

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_read_then_write() {
        let lock = RwLock::new(1);
        let _a = lock.read();
        let _b = lock.write();
    }

    #[test]
    #[should_panic(expected = "concurrent read request")]
    fn rwlock_panic_write_then_read() {
        let lock = RwLock::new(1);
        let _a = lock.write();
        let _b = lock.read();
    }

    #[test]
    #[should_panic(expected = "concurrent write request")]
    fn rwlock_panic_write_then_write() {
        let lock = RwLock::new(1);
        let _a = lock.write();
        let _b = lock.write();
    }
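
    // A hedged sketch of the intended `static` use: `new` is `const`, so a
    // `OnceLock` can back a lazily initialized global. `GLOBAL` and
    // `global_value` are illustrative names, not part of the original API.
    static GLOBAL: OnceLock<u32> = OnceLock::new();

    fn global_value() -> u32 {
        *GLOBAL.get_or_init(|| 42)
    }

    #[test]
    fn once_lock_static_usage() {
        assert_eq!(global_value(), 42);
        // The first initializer wins; later ones are ignored.
        assert_eq!(*GLOBAL.get_or_init(|| 7), 42);
    }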
}