wasmtime_wasi/cli/
locked_async.rs

1use crate::cli::{IsTerminal, StdinStream, StdoutStream};
2use crate::p2;
3use bytes::Bytes;
4use std::mem;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, ready};
8use tokio::io::{self, AsyncRead, AsyncWrite};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10use wasmtime_wasi_io::streams::{InputStream, OutputStream};
11
/// Readiness hook for the pipe streams shared through [`StdioHandle`].
///
/// `poll_ready` reports whether the wrapped stream is ready for its next
/// operation; `StdioHandle::poll` calls it while holding the stream's lock.
trait SharedHandleReady: 'static + Send + Sync {
    fn poll_ready(&mut self, context: &mut Context<'_>) -> Poll<()>;
}
15
16impl SharedHandleReady for p2::pipe::AsyncWriteStream {
17    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
18        <Self>::poll_ready(self, cx)
19    }
20}
21
22impl SharedHandleReady for p2::pipe::AsyncReadStream {
23    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
24        <Self>::poll_ready(self, cx)
25    }
26}
27
28/// An impl of [`StdinStream`] built on top of [`AsyncRead`].
29//
30// Note the usage of `tokio::sync::Mutex` here as opposed to a
31// `std::sync::Mutex`. This is intentionally done to implement the `Pollable`
32// variant of this trait. Note that in doing so we're left with the quandry of
33// how to implement methods of `InputStream` since those methods are not
34// `async`. They're currently implemented with `try_lock`, which then raises the
35// question of what to do on contention. Currently traps are returned.
36//
37// Why should it be ok to return a trap? In general concurrency/contention
38// shouldn't return a trap since it should be able to happen normally. The
39// current assumption, though, is that WASI stdin/stdout streams are special
40// enough that the contention case should never come up in practice. Currently
41// in WASI there is no actually concurrency, there's just the items in a single
42// `Store` and that store owns all of its I/O in a single Tokio task. There's no
43// means to actually spawn multiple Tokio tasks that use the same store. This
44// means at the very least that there's zero parallelism. Due to the lack of
45// multiple tasks that also means that there's no concurrency either.
46//
47// This `AsyncStdinStream` wrapper is only intended to be used by the WASI
48// bindings themselves. It's possible for the host to take this and work with it
49// on its own task, but that's niche enough it's not designed for.
50//
51// Overall that means that the guest is either calling `Pollable` or
52// `InputStream` methods. This means that there should never be contention
53// between the two at this time. This may all change in the future with WASI
54// 0.3, but perhaps we'll have a better story for stdio at that time (see the
55// doc block on the `OutputStream` impl below)
56pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>);
57
58impl AsyncStdinStream {
59    pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self {
60        Self(Arc::new(Mutex::new(p2::pipe::AsyncReadStream::new(s))))
61    }
62}
63
64impl StdinStream for AsyncStdinStream {
65    fn p2_stream(&self) -> Box<dyn InputStream> {
66        Box::new(Self(self.0.clone()))
67    }
68    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
69        Box::new(StdioHandle::Ready(self.0.clone()))
70    }
71}
72
73impl IsTerminal for AsyncStdinStream {
74    fn is_terminal(&self) -> bool {
75        false
76    }
77}
78
79#[async_trait::async_trait]
80impl InputStream for AsyncStdinStream {
81    fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> {
82        match self.0.try_lock() {
83            Ok(mut stream) => stream.read(size),
84            Err(_) => Err(p2::StreamError::trap("concurrent reads are not supported")),
85        }
86    }
87    fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> {
88        match self.0.try_lock() {
89            Ok(mut stream) => stream.skip(size),
90            Err(_) => Err(p2::StreamError::trap("concurrent skips are not supported")),
91        }
92    }
93    async fn cancel(&mut self) {
94        // Cancel the inner stream if we're the last reference to it:
95        if let Some(mutex) = Arc::get_mut(&mut self.0) {
96            match mutex.try_lock() {
97                Ok(mut stream) => stream.cancel().await,
98                Err(_) => {}
99            }
100        }
101    }
102}
103
104#[async_trait::async_trait]
105impl p2::Pollable for AsyncStdinStream {
106    async fn ready(&mut self) {
107        self.0.lock().await.ready().await
108    }
109}
110
111impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> {
112    fn poll_read(
113        mut self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115        buf: &mut io::ReadBuf<'_>,
116    ) -> Poll<io::Result<()>> {
117        match ready!(self.as_mut().poll(cx, |g| g.read(buf.remaining()))) {
118            Some(Ok(bytes)) => {
119                buf.put_slice(&bytes);
120                Poll::Ready(Ok(()))
121            }
122            Some(Err(e)) => Poll::Ready(Err(e)),
123            // If the guard can't be acquired that means that this stream is
124            // closed, so return that we're ready without filling in data.
125            None => Poll::Ready(Ok(())),
126        }
127    }
128}
129
130/// A wrapper of [`crate::p2::pipe::AsyncWriteStream`] that implements
131/// [`StdoutStream`]. Note that the [`OutputStream`] impl for this is not
132/// correct when used for interleaved async IO.
133//
134// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to
135// the `try_lock()` calls below in the implementation of `OutputStream`. For
136// more information see the documentation on `AsyncStdinStream`.
137pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>);
138
139impl AsyncStdoutStream {
140    pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self {
141        Self(Arc::new(Mutex::new(p2::pipe::AsyncWriteStream::new(
142            budget, s,
143        ))))
144    }
145}
146
147impl StdoutStream for AsyncStdoutStream {
148    fn p2_stream(&self) -> Box<dyn OutputStream> {
149        Box::new(Self(self.0.clone()))
150    }
151    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
152        Box::new(StdioHandle::Ready(self.0.clone()))
153    }
154}
155
156impl IsTerminal for AsyncStdoutStream {
157    fn is_terminal(&self) -> bool {
158        false
159    }
160}
161
162// This implementation is known to be bogus. All check-writes and writes are
163// directed at the same underlying stream. The check-write/write protocol does
164// require the size returned by a check-write to be accepted by write, even if
165// other side-effects happen between those calls, and this implementation
166// permits another view (created by StdoutStream::stream()) of the same
167// underlying stream to accept a write which will invalidate a prior
168// check-write of another view.
169// Ultimately, the Std{in,out}Stream::stream() methods exist because many
170// different places in a linked component (which may itself contain many
171// modules) may need to access stdio without any coordination to keep those
172// accesses all using pointing to the same resource. So, we allow many
173// resources to be created. We have the reasonable expectation that programs
174// won't attempt to interleave async IO from these disparate uses of stdio.
175// If that expectation doesn't turn out to be true, and you find yourself at
176// this comment to correct it: sorry about that.
177#[async_trait::async_trait]
178impl OutputStream for AsyncStdoutStream {
179    fn check_write(&mut self) -> Result<usize, p2::StreamError> {
180        match self.0.try_lock() {
181            Ok(mut stream) => stream.check_write(),
182            Err(_) => Err(p2::StreamError::trap("concurrent writes are not supported")),
183        }
184    }
185    fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> {
186        match self.0.try_lock() {
187            Ok(mut stream) => stream.write(bytes),
188            Err(_) => Err(p2::StreamError::trap("concurrent writes not supported yet")),
189        }
190    }
191    fn flush(&mut self) -> Result<(), p2::StreamError> {
192        match self.0.try_lock() {
193            Ok(mut stream) => stream.flush(),
194            Err(_) => Err(p2::StreamError::trap(
195                "concurrent flushes not supported yet",
196            )),
197        }
198    }
199    async fn cancel(&mut self) {
200        // Cancel the inner stream if we're the last reference to it:
201        if let Some(mutex) = Arc::get_mut(&mut self.0) {
202            match mutex.try_lock() {
203                Ok(mut stream) => stream.cancel().await,
204                Err(_) => {}
205            }
206        }
207    }
208}
209
210#[async_trait::async_trait]
211impl p2::Pollable for AsyncStdoutStream {
212    async fn ready(&mut self) {
213        self.0.lock().await.ready().await
214    }
215}
216
217impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> {
218    fn poll_write(
219        self: Pin<&mut Self>,
220        cx: &mut Context<'_>,
221        buf: &[u8],
222    ) -> Poll<io::Result<usize>> {
223        match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) {
224            Some(Ok(())) => Poll::Ready(Ok(buf.len())),
225            Some(Err(e)) => Poll::Ready(Err(e)),
226            None => Poll::Ready(Ok(0)),
227        }
228    }
229    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230        match ready!(self.poll(cx, |i| i.flush())) {
231            Some(result) => Poll::Ready(result),
232            None => Poll::Ready(Ok(())),
233        }
234    }
235    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
236        Poll::Ready(Ok(()))
237    }
238}
239
240/// State necessary for effectively transforming `Arc<Mutex<dyn
241/// {Input,Output}Stream>>` into `Async{Read,Write}`.
242///
243/// This is a beast and inefficient. It should get the job done in theory but
244/// one must truly ask oneself at some point "but at what cost".
245///
246/// More seriously, it's unclear if this is the best way to transform a single
247/// `AsyncRead` into a "multiple `AsyncRead`". This certainly is an attempt and
248/// the hope is that everything here is private enough that we can refactor as
249/// necessary in the future without causing much churn.
250enum StdioHandle<S> {
251    Ready(Arc<Mutex<S>>),
252    Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>),
253    Locked(OwnedMutexGuard<S>),
254    Closed,
255}
256
257impl<S> StdioHandle<S>
258where
259    S: SharedHandleReady,
260{
261    fn poll<T>(
262        mut self: Pin<&mut Self>,
263        cx: &mut Context<'_>,
264        op: impl FnOnce(&mut S) -> p2::StreamResult<T>,
265    ) -> Poll<Option<io::Result<T>>> {
266        // If we don't currently have the lock on this handle, initiate the
267        // lock acquisition.
268        if let StdioHandle::Ready(lock) = &*self {
269            self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned())));
270        }
271
272        // If we're in the process of locking this handle, wait for that to
273        // finish.
274        if let Some(lock) = self.as_mut().as_locking() {
275            let guard = ready!(lock.poll(cx));
276            self.set(StdioHandle::Locked(guard));
277        }
278
279        let mut guard = match self.as_mut().take_guard() {
280            Some(guard) => guard,
281            // If the guard can't be acquired that means that this stream is
282            // closed, so return that we're ready without filling in data.
283            None => return Poll::Ready(None),
284        };
285
286        // Wait for our locked stream to be ready, resetting to the "locked"
287        // state if it's not quite ready yet.
288        match guard.poll_ready(cx) {
289            Poll::Ready(()) => {}
290
291            // If the read isn't ready yet then restore our "locked" state
292            // since we haven't finished, then return pending.
293            Poll::Pending => {
294                self.set(StdioHandle::Locked(guard));
295                return Poll::Pending;
296            }
297        }
298
299        // Perform the I/O and delegate on the result.
300        match op(&mut guard) {
301            // The I/O succeeded so relinquish the lock on this stream by
302            // transitioning back to the "Ready" state.
303            Ok(result) => {
304                self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone()));
305                Poll::Ready(Some(Ok(result)))
306            }
307
308            // The stream is closed, and `take_guard` above already set the
309            // closed state, so return nothing indicating the closure.
310            Err(p2::StreamError::Closed) => Poll::Ready(None),
311
312            // The stream failed so propagate the error. Errors should only
313            // come from the underlying I/O object and thus should cast
314            // successfully. Additionally `take_guard` replaced our state
315            // with "closed" above which is the desired state at this point.
316            Err(p2::StreamError::LastOperationFailed(e)) => {
317                Poll::Ready(Some(Err(e.downcast().unwrap())))
318            }
319
320            // Shouldn't be possible to produce a trap here.
321            Err(p2::StreamError::Trap(_)) => unreachable!(),
322        }
323    }
324
325    fn as_locking(
326        self: Pin<&mut Self>,
327    ) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> {
328        // SAFETY: this is a pin-projection from `self` into the `Locking`
329        // field.
330        unsafe {
331            match self.get_unchecked_mut() {
332                StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)),
333                _ => None,
334            }
335        }
336    }
337
338    fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> {
339        if !matches!(*self, StdioHandle::Locked(_)) {
340            return None;
341        }
342        // SAFETY: the `Locked` arm is safe to move as it's an invariant of this
343        // type that it's not pinned.
344        unsafe {
345            match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) {
346                StdioHandle::Locked(guard) => Some(guard),
347                _ => unreachable!(),
348            }
349        }
350    }
351}