wasmtime/runtime/component/concurrent/
abort.rs

1use std::mem::{self, ManuallyDrop};
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll, Waker};
5
/// Handle to a task which may be used to join on the result of executing it.
///
/// This represents a handle to a running task which can be cancelled with
/// [`JoinHandle::abort`]. `await`-ing this handle resolves (with `()`) once
/// the wrapped future of the connected task has been dropped, regardless of
/// whether it ran to completion or was aborted first.
///
/// Note that dropping this handle does not affect the running task it's
/// connected to. A manual invocation of [`JoinHandle::abort`] is required to
/// affect the task.
pub struct JoinHandle {
    /// State shared with the wrapped future created by `JoinHandle::run`.
    state: Arc<Mutex<JoinState>>,
}
18
/// State shared between a [`JoinHandle`] and the wrapped future it is
/// connected to.
enum JoinState {
    /// The task this is connected to is still running and has not completed or
    /// been dropped.
    Running {
        /// The waker that the running task has registered which is signaled
        /// upon abort.
        waiting_for_abort_signal: Option<Waker>,

        /// The waker that the `JoinHandle` has registered to await
        /// destruction of the running task itself.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// An abort has been requested through a `JoinHandle`. The waker stored
    /// here is the one registered by `Future for JoinHandle`; it is woken
    /// once the wrapped future has been dropped.
    AbortRequested {
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// The wrapped future has been dropped, so there is no need to abort it
    /// and nothing else needs to wait.
    Complete,
}
42
43impl JoinHandle {
44    /// Abort the task.
45    ///
46    /// This flags the connected task should abort in the near future, but note
47    /// that if this is called while the future is being polled then that call
48    /// will still complete.
49    ///
50    /// Note that this `JoinHandle` is itself a `Future` and can be used to
51    /// await the result and destruction of the task that this is associated
52    /// with.
53    pub fn abort(&self) {
54        let mut state = self.state.lock().unwrap();
55
56        match &mut *state {
57            // If this task is still running, then fall through to below to
58            // transition it into the `AbortRequested` state. If present the
59            // waker for the running task is notified to indicate that an abort
60            // signal has been received.
61            JoinState::Running {
62                waiting_for_abort_signal,
63                waiting_for_abort_to_complete,
64            } => {
65                if let Some(task) = waiting_for_abort_signal.take() {
66                    task.wake();
67                }
68
69                *state = JoinState::AbortRequested {
70                    waiting_for_abort_to_complete: waiting_for_abort_to_complete.take(),
71                };
72            }
73
74            // If this task has already been aborted or has completed, nothing
75            // is left to do.
76            JoinState::AbortRequested { .. } | JoinState::Complete => {}
77        }
78    }
79
80    /// Wraps the `future` provided in a new future which is "abortable" where
81    /// if the returned `JoinHandle` is flagged then the future will resolve
82    /// ASAP with `None` and drop the provided `future`.
83    pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>)
84    where
85        F: Future,
86    {
87        let handle = JoinHandle {
88            state: Arc::new(Mutex::new(JoinState::Running {
89                waiting_for_abort_signal: None,
90                waiting_for_abort_to_complete: None,
91            })),
92        };
93        let future = JoinHandleFuture {
94            future: ManuallyDrop::new(future),
95            state: handle.state.clone(),
96        };
97        (handle, future)
98    }
99}
100
101impl Future for JoinHandle {
102    type Output = ();
103
104    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105        let mut state = self.state.lock().unwrap();
106        match &mut *state {
107            // If this task is running or still only has requested an abort,
108            // wait further for the task to get dropped.
109            JoinState::Running {
110                waiting_for_abort_to_complete,
111                ..
112            }
113            | JoinState::AbortRequested {
114                waiting_for_abort_to_complete,
115            } => {
116                *waiting_for_abort_to_complete = Some(cx.waker().clone());
117                Poll::Pending
118            }
119
120            // The task is dropped, done!
121            JoinState::Complete => Poll::Ready(()),
122        }
123    }
124}
125
/// Future returned by `JoinHandle::run` which wraps another future and makes
/// it abortable through the connected `JoinHandle`.
struct JoinHandleFuture<F> {
    /// The wrapped future. Held in `ManuallyDrop` so the destructor can
    /// destroy it explicitly before flagging the shared state as complete.
    future: ManuallyDrop<F>,
    /// State shared with the connected `JoinHandle`.
    state: Arc<Mutex<JoinState>>,
}
130
131impl<F> Future for JoinHandleFuture<F>
132where
133    F: Future,
134{
135    type Output = Option<F::Output>;
136
137    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138        // SAFETY: this is a pin-projection from `Self` to the state and `Pin`
139        // of the internal future. This is the exclusive access of these fields
140        // apart from the destructor and should be safe.
141        let (state, future) = unsafe {
142            let me = self.get_unchecked_mut();
143            (&me.state, Pin::new_unchecked(&mut *me.future))
144        };
145
146        // First, before polling the future, check to see if we've been
147        // aborted. If not register our task as awaiting such an abort.
148        {
149            let mut state = state.lock().unwrap();
150            match &mut *state {
151                JoinState::Running {
152                    waiting_for_abort_signal,
153                    ..
154                } => {
155                    *waiting_for_abort_signal = Some(cx.waker().clone());
156                }
157                JoinState::AbortRequested { .. } | JoinState::Complete => {
158                    return Poll::Ready(None);
159                }
160            }
161        }
162
163        future.poll(cx).map(Some)
164    }
165}
166
167impl<F> Drop for JoinHandleFuture<F> {
168    fn drop(&mut self) {
169        // SAFETY: this is the exclusive owner of this future and it's safe to
170        // drop here during the owning destructor.
171        //
172        // Note that this explicitly happens before notifying the abort handle
173        // that the task completed so that when the notification goes through
174        // it's guaranteed that the future has been destroyed.
175        unsafe {
176            ManuallyDrop::drop(&mut self.future);
177        }
178
179        // After the future dropped see if there was a task awaiting its
180        // destruction. Simultaneously flag this state as complete.
181        let prev = mem::replace(&mut *self.state.lock().unwrap(), JoinState::Complete);
182        let task = match prev {
183            JoinState::Running {
184                waiting_for_abort_to_complete,
185                ..
186            }
187            | JoinState::AbortRequested {
188                waiting_for_abort_to_complete,
189            } => waiting_for_abort_to_complete,
190            JoinState::Complete => None,
191        };
192        if let Some(task) = task {
193            task.wake();
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::JoinHandle;
    use std::pin::{Pin, pin};
    use std::task::{Context, Poll, Waker};
    use tokio::sync::oneshot;

    /// Polls `future` exactly once with a no-op waker and reports whether it
    /// resolved.
    fn is_ready<F>(future: Pin<&mut F>) -> bool
    where
        F: Future,
    {
        match future.poll(&mut Context::from_waker(Waker::noop())) {
            Poll::Ready(_) => true,
            Poll::Pending => false,
        }
    }

    /// Aborting a pending task makes the wrapper resolve, but the handle only
    /// resolves (and the inner future is only destroyed) once the wrapper is
    /// dropped.
    #[tokio::test]
    async fn abort_in_progress() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        {
            let mut future = pin!(future);
            assert!(!is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            handle.abort();
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            assert!(!tx.is_closed());
        }
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    /// Aborting after the task has completed and been dropped is a no-op.
    #[tokio::test]
    async fn abort_complete() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        tx.send(()).unwrap();
        assert!(!is_ready(handle.as_mut()));
        {
            let mut future = pin!(future);
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
        }
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
    }

    /// Dropping the wrapper without ever polling it completes the handle.
    #[tokio::test]
    async fn abort_dropped() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        drop(future);
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    /// Awaiting the handle of a task that runs to completion resolves, and
    /// the task yields the inner future's output wrapped in `Some`.
    #[tokio::test]
    async fn await_completion() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        let task = tokio::task::spawn(future);
        handle.await;
        assert!(matches!(task.await.unwrap(), Some(Ok(()))));
    }

    /// Awaiting the handle of a task aborted before its first poll resolves,
    /// and the task yields `None` even though the inner future was ready.
    #[tokio::test]
    async fn await_abort() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        handle.abort();
        let task = tokio::task::spawn(future);
        handle.await;
        assert!(task.await.unwrap().is_none());
    }
}