wasmtime/runtime/component/concurrent/
abort.rs1use std::mem::{self, ManuallyDrop};
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll, Waker};
5
/// Handle to a task created by [`JoinHandle::run`].
///
/// The handle can request that the task abort via [`JoinHandle::abort`], and is
/// itself a future that resolves once the task's future has been dropped.
///
/// Dropping the handle does not request an abort; only an explicit call to
/// `abort` moves the task into the abort-requested state.
pub struct JoinHandle {
    /// State shared with the future returned alongside this handle by
    /// `JoinHandle::run`.
    state: Arc<Mutex<JoinState>>,
}
18
/// State shared between a [`JoinHandle`] and the future returned alongside it
/// by [`JoinHandle::run`].
enum JoinState {
    /// The wrapper future has not been dropped and no abort has been requested.
    ///
    /// Note that this remains the state even after the wrapped future has
    /// returned `Ready`; only dropping the wrapper future moves to `Complete`.
    Running {
        /// Waker recorded by every poll of the wrapper future. `abort` takes
        /// and wakes it so the task is polled again and observes the abort.
        waiting_for_abort_signal: Option<Waker>,

        /// Waker recorded by polling the `JoinHandle`; woken when the wrapper
        /// future is dropped.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// `JoinHandle::abort` was called while `Running`. The next poll of the
    /// wrapper future returns `Ready(None)` without polling the wrapped future,
    /// which stays alive until the wrapper future is dropped.
    AbortRequested {
        /// Waker recorded by polling the `JoinHandle`; woken when the wrapper
        /// future is dropped.
        waiting_for_abort_to_complete: Option<Waker>,
    },

    /// The wrapper future (and with it the wrapped future) has been dropped.
    /// Polling the `JoinHandle` now returns `Ready(())`, and `abort` is a no-op.
    Complete,
}
42
43impl JoinHandle {
44 pub fn abort(&self) {
54 let mut state = self.state.lock().unwrap();
55
56 match &mut *state {
57 JoinState::Running {
62 waiting_for_abort_signal,
63 waiting_for_abort_to_complete,
64 } => {
65 if let Some(task) = waiting_for_abort_signal.take() {
66 task.wake();
67 }
68
69 *state = JoinState::AbortRequested {
70 waiting_for_abort_to_complete: waiting_for_abort_to_complete.take(),
71 };
72 }
73
74 JoinState::AbortRequested { .. } | JoinState::Complete => {}
77 }
78 }
79
80 pub(crate) fn run<F>(future: F) -> (JoinHandle, impl Future<Output = Option<F::Output>>)
84 where
85 F: Future,
86 {
87 let handle = JoinHandle {
88 state: Arc::new(Mutex::new(JoinState::Running {
89 waiting_for_abort_signal: None,
90 waiting_for_abort_to_complete: None,
91 })),
92 };
93 let future = JoinHandleFuture {
94 future: ManuallyDrop::new(future),
95 state: handle.state.clone(),
96 };
97 (handle, future)
98 }
99}
100
101impl Future for JoinHandle {
102 type Output = ();
103
104 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105 let mut state = self.state.lock().unwrap();
106 match &mut *state {
107 JoinState::Running {
110 waiting_for_abort_to_complete,
111 ..
112 }
113 | JoinState::AbortRequested {
114 waiting_for_abort_to_complete,
115 } => {
116 *waiting_for_abort_to_complete = Some(cx.waker().clone());
117 Poll::Pending
118 }
119
120 JoinState::Complete => Poll::Ready(()),
122 }
123 }
124}
125
/// Future returned by [`JoinHandle::run`] that drives the wrapped future
/// unless an abort has been requested through the paired [`JoinHandle`].
struct JoinHandleFuture<F> {
    /// The wrapped future. Held in `ManuallyDrop` so the `Drop` impl can drop
    /// it in place before moving the shared state to `Complete`.
    future: ManuallyDrop<F>,
    /// State shared with the paired `JoinHandle`.
    state: Arc<Mutex<JoinState>>,
}
130
131impl<F> Future for JoinHandleFuture<F>
132where
133 F: Future,
134{
135 type Output = Option<F::Output>;
136
137 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 let (state, future) = unsafe {
142 let me = self.get_unchecked_mut();
143 (&me.state, Pin::new_unchecked(&mut *me.future))
144 };
145
146 {
149 let mut state = state.lock().unwrap();
150 match &mut *state {
151 JoinState::Running {
152 waiting_for_abort_signal,
153 ..
154 } => {
155 *waiting_for_abort_signal = Some(cx.waker().clone());
156 }
157 JoinState::AbortRequested { .. } | JoinState::Complete => {
158 return Poll::Ready(None);
159 }
160 }
161 }
162
163 future.poll(cx).map(Some)
164 }
165}
166
167impl<F> Drop for JoinHandleFuture<F> {
168 fn drop(&mut self) {
169 unsafe {
176 ManuallyDrop::drop(&mut self.future);
177 }
178
179 let prev = mem::replace(&mut *self.state.lock().unwrap(), JoinState::Complete);
182 let task = match prev {
183 JoinState::Running {
184 waiting_for_abort_to_complete,
185 ..
186 }
187 | JoinState::AbortRequested {
188 waiting_for_abort_to_complete,
189 } => waiting_for_abort_to_complete,
190 JoinState::Complete => None,
191 };
192 if let Some(task) = task {
193 task.wake();
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    // Tests for `JoinHandle::run` / `JoinHandle::abort`. The first three drive
    // the futures by hand with a no-op waker; the last two run the wrapper
    // future on a tokio task and `await` the handle.
    use super::JoinHandle;
    use std::pin::{Pin, pin};
    use std::task::{Context, Poll, Waker};
    use tokio::sync::oneshot;

    /// Polls `future` once with a no-op waker and reports whether it resolved.
    fn is_ready<F>(future: Pin<&mut F>) -> bool
    where
        F: Future,
    {
        match future.poll(&mut Context::from_waker(Waker::noop())) {
            Poll::Ready(_) => true,
            Poll::Pending => false,
        }
    }

    // Aborting a still-pending task: the wrapper resolves on its next poll, but
    // the handle (and the channel's closed state) only changes once the
    // wrapper, and the receiver inside it, is dropped at the end of the scope.
    #[tokio::test]
    async fn abort_in_progress() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        {
            let mut future = pin!(future);
            assert!(!is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            handle.abort();
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
            assert!(!tx.is_closed());
        }
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    // A task whose inner future completes: the handle stays pending until the
    // wrapper is dropped, and a later `abort` leaves it resolved.
    #[tokio::test]
    async fn abort_complete() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        tx.send(()).unwrap();
        assert!(!is_ready(handle.as_mut()));
        {
            let mut future = pin!(future);
            assert!(is_ready(future.as_mut()));
            assert!(!is_ready(handle.as_mut()));
        }
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
    }

    // Dropping the wrapper without ever polling it resolves the handle and
    // drops the receiver; a subsequent `abort` is a no-op.
    #[tokio::test]
    async fn abort_dropped() {
        let (tx, rx) = oneshot::channel::<()>();
        let (mut handle, future) = JoinHandle::run(rx);
        let mut handle = Pin::new(&mut handle);
        drop(future);
        assert!(is_ready(handle.as_mut()));
        handle.abort();
        assert!(is_ready(handle.as_mut()));
        assert!(tx.is_closed());
    }

    // Awaiting the handle while the wrapper runs to completion on a tokio task.
    #[tokio::test]
    async fn await_completion() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        let task = tokio::task::spawn(future);
        handle.await;
        task.await.unwrap();
    }

    // Aborting before the spawned task first polls the wrapper; awaiting the
    // handle still completes once the wrapper is dropped.
    #[tokio::test]
    async fn await_abort() {
        let (tx, rx) = oneshot::channel::<()>();
        tx.send(()).unwrap();
        let (handle, future) = JoinHandle::run(rx);
        handle.abort();
        let task = tokio::task::spawn(future);
        handle.await;
        task.await.unwrap();
    }
}