Skip to main content

wasmtime/runtime/component/concurrent/
abort.rs

1use crate::try_mutex::TryMutex;
2use alloc::sync::Arc;
3use core::mem::{self, ManuallyDrop};
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6
/// Handle to a task which may be used to join on the completion of it.
///
/// This represents a handle to a running task which can be cancelled with
/// [`JoinHandle::abort`]. The completion and drop of the task can be
/// observed by `await`-ing this handle, which resolves to `()` once the
/// task's future has been destroyed.
///
/// Note that dropping this handle does not affect the running task it's
/// connected to. A manual invocation of [`JoinHandle::abort`] is required to
/// affect the task.
pub struct JoinHandle {
    // State shared with the task future created by `JoinHandle::run`.
    //
    // Note that, as this type is currently used within Wasmtime's async
    // implementation, this lock is not expected to ever be contended.
    // Everything's bound to the store and the lock is largely just used to
    // satisfy compiler bounds. Hence, lock operations in this module all use
    // `try_lock` and panic (via `expect`) if the lock cannot be acquired.
    state: Arc<TryMutex<JoinState>>,
}
24
/// State shared between a `JoinHandle` and the task future it is connected
/// to.
///
/// Within this module the state moves from `Running` to `AbortRequested`
/// when `JoinHandle::abort` is called, and to `Complete` (from either of the
/// other two states) when the task future is dropped.
enum JoinState {
    /// The task this is connected to is still running and has not completed or
    /// been dropped.
    Running {
        /// The waker that the running task has registered which is signaled
        /// upon abort.
        waiting_for_abort_signal: Option<Waker>,

        /// The waker that the `JoinHandle` has registered to await
        /// destruction of the running task itself.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// An abort has been requested through a `JoinHandle`. The waker stored
    /// here is the one registered by `impl Future for JoinHandle`, and it is
    /// notified once the task has been dropped.
    AbortRequested {
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// The running task has been dropped, so there is no need to abort it and
    /// nothing else needs to wait.
    Complete,
}
48
49impl JoinHandle {
50    /// Abort the task.
51    ///
52    /// This flags the connected task should abort in the near future, but note
53    /// that if this is called while the future is being polled then that call
54    /// will still complete.
55    ///
56    /// Note that this `JoinHandle` is itself a `Future` and can be used to
57    /// await the result and destruction of the task that this is associated
58    /// with.
59    pub fn abort(&self) {
60        let mut state = self.state.try_lock().expect("should not be contended");
61
62        match &mut *state {
63            // If this task is still running, then fall through to below to
64            // transition it into the `AbortRequested` state. If present the
65            // waker for the running task is notified to indicate that an abort
66            // signal has been received.
67            JoinState::Running {
68                waiting_for_abort_signal,
69                waiting_for_abort_to_complete,
70            } => {
71                if let Some(task) = waiting_for_abort_signal.take() {
72                    task.wake();
73                }
74
75                *state = JoinState::AbortRequested {
76                    waiting_for_abort_to_complete: waiting_for_abort_to_complete.take(),
77                };
78            }
79
80            // If this task has already been aborted or has completed, nothing
81            // is left to do.
82            JoinState::AbortRequested { .. } | JoinState::Complete => {}
83        }
84    }
85
86    /// Wraps the `future` provided in a new future which is "abortable" where
87    /// if the returned `JoinHandle` is flagged then the future will resolve
88    /// ASAP with `None` and drop the provided `future`.
89    pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>)
90    where
91        F: Future,
92    {
93        let handle = JoinHandle {
94            state: Arc::new(TryMutex::new(JoinState::Running {
95                waiting_for_abort_signal: None,
96                waiting_for_abort_to_complete: None,
97            })),
98        };
99        let future = JoinHandleFuture {
100            future: ManuallyDrop::new(future),
101            state: handle.state.clone(),
102        };
103        (handle, future)
104    }
105}
106
107impl Future for JoinHandle {
108    type Output = ();
109
110    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
111        let mut state = self
112            .state
113            .try_lock()
114            .expect("this lock should not be contended");
115        match &mut *state {
116            // If this task is running or still only has requested an abort,
117            // wait further for the task to get dropped.
118            JoinState::Running {
119                waiting_for_abort_to_complete,
120                ..
121            }
122            | JoinState::AbortRequested {
123                waiting_for_abort_to_complete,
124            } => {
125                *waiting_for_abort_to_complete = Some(cx.waker().clone());
126                Poll::Pending
127            }
128
129            // The task is dropped, done!
130            JoinState::Complete => Poll::Ready(()),
131        }
132    }
133}
134
/// The "abortable" wrapper future created by `JoinHandle::run`.
///
/// Polling this resolves to `None` once an abort has been requested and
/// otherwise forwards to the wrapped future. Dropping it drops the wrapped
/// future and then flags the shared state as complete.
struct JoinHandleFuture<F> {
    // The wrapped future. This is stored in `ManuallyDrop` so that the
    // destructor of this type can drop it explicitly *before* marking the
    // shared state as complete and waking whoever awaits the `JoinHandle`.
    future: ManuallyDrop<F>,
    // State shared with the `JoinHandle` that controls this future.
    state: Arc<TryMutex<JoinState>>,
}
139
140impl<F> Future for JoinHandleFuture<F>
141where
142    F: Future,
143{
144    type Output = Option<F::Output>;
145
146    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
147        // SAFETY: this is a pin-projection from `Self` to the state and `Pin`
148        // of the internal future. This is the exclusive access of these fields
149        // apart from the destructor and should be safe.
150        let (state, future) = unsafe {
151            let me = self.get_unchecked_mut();
152            (&me.state, Pin::new_unchecked(&mut *me.future))
153        };
154
155        // First, before polling the future, check to see if we've been
156        // aborted. If not register our task as awaiting such an abort.
157        {
158            let mut state = state.try_lock().expect("this lock should not be contended");
159            match &mut *state {
160                JoinState::Running {
161                    waiting_for_abort_signal,
162                    ..
163                } => {
164                    *waiting_for_abort_signal = Some(cx.waker().clone());
165                }
166                JoinState::AbortRequested { .. } | JoinState::Complete => {
167                    return Poll::Ready(None);
168                }
169            }
170        }
171
172        future.poll(cx).map(Some)
173    }
174}
175
176impl<F> Drop for JoinHandleFuture<F> {
177    fn drop(&mut self) {
178        // SAFETY: this is the exclusive owner of this future and it's safe to
179        // drop here during the owning destructor.
180        //
181        // Note that this explicitly happens before notifying the abort handle
182        // that the task completed so that when the notification goes through
183        // it's guaranteed that the future has been destroyed.
184        unsafe {
185            ManuallyDrop::drop(&mut self.future);
186        }
187
188        // After the future dropped see if there was a task awaiting its
189        // destruction. Simultaneously flag this state as complete.
190        let prev = mem::replace(
191            &mut *self.state.try_lock().expect("should not be contended"),
192            JoinState::Complete,
193        );
194        let task = match prev {
195            JoinState::Running {
196                waiting_for_abort_to_complete,
197                ..
198            }
199            | JoinState::AbortRequested {
200                waiting_for_abort_to_complete,
201            } => waiting_for_abort_to_complete,
202            JoinState::Complete => None,
203        };
204        if let Some(task) = task {
205            task.wake();
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::JoinHandle;
    use std::pin::{Pin, pin};
    use std::task::{Context, Poll, Waker};
    use tokio::sync::oneshot;

    /// Polls `future` exactly once with a no-op waker and reports whether it
    /// resolved.
    fn is_ready<F>(future: Pin<&mut F>) -> bool
    where
        F: Future,
    {
        let mut cx = Context::from_waker(Waker::noop());
        matches!(future.poll(&mut cx), Poll::Ready(_))
    }

    /// Aborting a pending task resolves the task immediately, but the handle
    /// only resolves once the task has actually been dropped.
    #[tokio::test]
    async fn abort_in_progress() {
        let (sender, receiver) = oneshot::channel::<()>();
        let (mut raw_handle, task) = JoinHandle::run(receiver);
        let mut handle = Pin::new(&mut raw_handle);
        {
            let mut task = pin!(task);
            assert!(!is_ready(task.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            handle.abort();
            assert!(is_ready(task.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            assert!(!sender.is_closed());
        }
        assert!(is_ready(handle.as_mut()));
        assert!(sender.is_closed());
    }

    /// A task that finishes on its own completes the handle once dropped, and
    /// aborting afterwards is a no-op.
    #[tokio::test]
    async fn abort_complete() {
        let (sender, receiver) = oneshot::channel::<()>();
        let (mut raw_handle, task) = JoinHandle::run(receiver);
        let mut handle = Pin::new(&mut raw_handle);
        sender.send(()).unwrap();
        assert!(!is_ready(handle.as_mut()));
        {
            let mut task = pin!(task);
            assert!(is_ready(task.as_mut()));
            assert!(!is_ready(handle.as_mut()));
        }
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
    }

    /// Dropping a task that was never polled still completes the handle and
    /// destroys the wrapped future.
    #[tokio::test]
    async fn abort_dropped() {
        let (sender, receiver) = oneshot::channel::<()>();
        let (mut raw_handle, task) = JoinHandle::run(receiver);
        let mut handle = Pin::new(&mut raw_handle);
        drop(task);
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
        assert!(sender.is_closed());
    }

    /// Awaiting the handle finishes once a spawned task runs to completion.
    #[tokio::test]
    async fn await_completion() {
        let (sender, receiver) = oneshot::channel::<()>();
        sender.send(()).unwrap();
        let (handle, task) = JoinHandle::run(receiver);
        let spawned = tokio::task::spawn(task);
        handle.await;
        spawned.await.unwrap();
    }

    /// Awaiting the handle finishes when the abort was requested before the
    /// task was ever spawned.
    #[tokio::test]
    async fn await_abort() {
        let (sender, receiver) = oneshot::channel::<()>();
        sender.send(()).unwrap();
        let (handle, task) = JoinHandle::run(receiver);
        handle.abort();
        let spawned = tokio::task::spawn(task);
        handle.await;
        spawned.await.unwrap();
    }
}