// wasmtime/runtime/component/concurrent/abort.rs
use crate::try_mutex::TryMutex;
use alloc::sync::Arc;
use core::mem::{self, ManuallyDrop};
use core::pin::Pin;
use core::task::{Context, Poll, Waker};

7pub struct JoinHandle {
17 state: Arc<TryMutex<JoinState>>,
23}
24
/// Lifecycle of a wrapped future, shared between a [`JoinHandle`] and its
/// paired wrapper future.
enum JoinState {
    /// The wrapper future is alive and no abort has been requested yet.
    Running {
        /// Waker of the task polling the wrapper future. It is stored on every
        /// poll and woken by `JoinHandle::abort` so that the wrapper notices
        /// the abort request promptly.
        waiting_for_abort_signal: Option<Waker>,

        /// Waker of the task awaiting the `JoinHandle`. It is woken when the
        /// wrapper future is dropped.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// `JoinHandle::abort` was called while the wrapper future was still
    /// alive. The wrapper now resolves to `None`, but the inner future is not
    /// dropped until the wrapper itself is dropped.
    AbortRequested {
        /// Waker of the task awaiting the `JoinHandle`. It is woken when the
        /// wrapper future is dropped.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// The wrapper future (and with it the inner future) has been dropped.
    Complete,
}

49impl JoinHandle {
50 pub fn abort(&self) {
60 let mut state = self.state.try_lock().expect("should not be contended");
61
62 match &mut *state {
63 JoinState::Running {
68 waiting_for_abort_signal,
69 waiting_for_abort_to_complete,
70 } => {
71 if let Some(task) = waiting_for_abort_signal.take() {
72 task.wake();
73 }
74
75 *state = JoinState::AbortRequested {
76 waiting_for_abort_to_complete: waiting_for_abort_to_complete.take(),
77 };
78 }
79
80 JoinState::AbortRequested { .. } | JoinState::Complete => {}
83 }
84 }
85
86 pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>)
90 where
91 F: Future,
92 {
93 let handle = JoinHandle {
94 state: Arc::new(TryMutex::new(JoinState::Running {
95 waiting_for_abort_signal: None,
96 waiting_for_abort_to_complete: None,
97 })),
98 };
99 let future = JoinHandleFuture {
100 future: ManuallyDrop::new(future),
101 state: handle.state.clone(),
102 };
103 (handle, future)
104 }
105}
106
107impl Future for JoinHandle {
108 type Output = ();
109
110 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
111 let mut state = self
112 .state
113 .try_lock()
114 .expect("this lock should not be contended");
115 match &mut *state {
116 JoinState::Running {
119 waiting_for_abort_to_complete,
120 ..
121 }
122 | JoinState::AbortRequested {
123 waiting_for_abort_to_complete,
124 } => {
125 *waiting_for_abort_to_complete = Some(cx.waker().clone());
126 Poll::Pending
127 }
128
129 JoinState::Complete => Poll::Ready(()),
131 }
132 }
133}
134
135struct JoinHandleFuture<F> {
136 future: ManuallyDrop<F>,
137 state: Arc<TryMutex<JoinState>>,
138}
139
140impl<F> Future for JoinHandleFuture<F>
141where
142 F: Future,
143{
144 type Output = Option<F::Output>;
145
146 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
147 let (state, future) = unsafe {
151 let me = self.get_unchecked_mut();
152 (&me.state, Pin::new_unchecked(&mut *me.future))
153 };
154
155 {
158 let mut state = state.try_lock().expect("this lock should not be contended");
159 match &mut *state {
160 JoinState::Running {
161 waiting_for_abort_signal,
162 ..
163 } => {
164 *waiting_for_abort_signal = Some(cx.waker().clone());
165 }
166 JoinState::AbortRequested { .. } | JoinState::Complete => {
167 return Poll::Ready(None);
168 }
169 }
170 }
171
172 future.poll(cx).map(Some)
173 }
174}
175
176impl<F> Drop for JoinHandleFuture<F> {
177 fn drop(&mut self) {
178 unsafe {
185 ManuallyDrop::drop(&mut self.future);
186 }
187
188 let prev = mem::replace(
191 &mut *self.state.try_lock().expect("should not be contended"),
192 JoinState::Complete,
193 );
194 let task = match prev {
195 JoinState::Running {
196 waiting_for_abort_to_complete,
197 ..
198 }
199 | JoinState::AbortRequested {
200 waiting_for_abort_to_complete,
201 } => waiting_for_abort_to_complete,
202 JoinState::Complete => None,
203 };
204 if let Some(task) = task {
205 task.wake();
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::JoinHandle;
    use std::pin::{Pin, pin};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::{Context, Poll, Wake, Waker};
    use tokio::sync::oneshot;

    /// Polls `future` once with a no-op waker and reports whether it is ready.
    fn is_ready<F>(future: Pin<&mut F>) -> bool
    where
        F: Future,
    {
        match future.poll(&mut Context::from_waker(Waker::noop())) {
            Poll::Ready(_) => true,
            Poll::Pending => false,
        }
    }

    /// Waker that records whether it has been woken.
    struct FlagWaker(AtomicBool);

    impl FlagWaker {
        fn new() -> Arc<Self> {
            Arc::new(FlagWaker(AtomicBool::new(false)))
        }

        fn woken(&self) -> bool {
            self.0.load(Ordering::SeqCst)
        }
    }

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            self.wake_by_ref();
        }

        fn wake_by_ref(self: &Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn abort_in_progress() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        {
            let mut future = pin!(future);
            assert!(!is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            handle.abort();
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            assert!(!tx.is_closed());
        }
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn abort_complete() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        tx.send(()).unwrap();
        assert!(!is_ready(handle.as_mut()));
        {
            let mut future = pin!(future);
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
        }
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
    }

    #[tokio::test]
    async fn abort_dropped() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        drop(future);
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    /// `abort` must wake the task parked on the wrapper future so it can
    /// observe the request.
    #[tokio::test]
    async fn abort_wakes_pending_future() {
        let (_tx, rx) = oneshot::channel::<()>();
        let (handle, future) = JoinHandle::run(rx);
        let mut future = pin!(future);
        let flag = FlagWaker::new();
        let waker = Waker::from(flag.clone());
        let mut cx = Context::from_waker(&waker);
        assert!(future.as_mut().poll(&mut cx).is_pending());
        assert!(!flag.woken());
        handle.abort();
        assert!(flag.woken());
        assert!(matches!(future.as_mut().poll(&mut cx), Poll::Ready(None)));
    }

    /// Dropping the wrapper future must wake the task awaiting the handle.
    #[tokio::test]
    async fn drop_wakes_waiting_handle() {
        let (_tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let flag = FlagWaker::new();
        let waker = Waker::from(flag.clone());
        let mut cx = Context::from_waker(&waker);
        assert!(Pin::new(&mut handle).poll(&mut cx).is_pending());
        assert!(!flag.woken());
        drop(future);
        assert!(flag.woken());
        assert!(Pin::new(&mut handle).poll(&mut cx).is_ready());
    }

    #[tokio::test]
    async fn await_completion() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        let task = tokio::task::spawn(future);
        handle.await;
        task.await.unwrap();
    }

    #[tokio::test]
    async fn await_abort() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        handle.abort();
        let task = tokio::task::spawn(future);
        handle.await;
        task.await.unwrap();
    }
}