wasmtime_wasi/cli/
locked_async.rs1use crate::cli::{IsTerminal, StdinStream, StdoutStream};
2use crate::p2;
3use bytes::Bytes;
4use std::mem;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll, ready};
8use tokio::io::{self, AsyncRead, AsyncWrite};
9use tokio::sync::{Mutex, OwnedMutexGuard};
10use wasmtime_wasi_io::streams::{InputStream, OutputStream};
11
/// Readiness hook implemented by the p2 pipe stream types that can sit
/// behind a [`StdioHandle`].
///
/// `StdioHandle::poll` calls this while holding the stream's lock, before it
/// runs an operation on the stream.
trait SharedHandleReady: Send + Sync + 'static {
    /// Returns `Poll::Ready(())` once the stream is ready; otherwise returns
    /// `Poll::Pending` and arranges for `ctx`'s waker to be notified.
    fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<()>;
}
15
16impl SharedHandleReady for p2::pipe::AsyncWriteStream {
17 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
18 <Self>::poll_ready(self, cx)
19 }
20}
21
22impl SharedHandleReady for p2::pipe::AsyncReadStream {
23 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
24 <Self>::poll_ready(self, cx)
25 }
26}
27
28pub struct AsyncStdinStream(Arc<Mutex<p2::pipe::AsyncReadStream>>);
57
58impl AsyncStdinStream {
59 pub fn new(s: impl AsyncRead + Send + Sync + 'static) -> Self {
60 Self(Arc::new(Mutex::new(p2::pipe::AsyncReadStream::new(s))))
61 }
62}
63
64impl StdinStream for AsyncStdinStream {
65 fn p2_stream(&self) -> Box<dyn InputStream> {
66 Box::new(Self(self.0.clone()))
67 }
68 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
69 Box::new(StdioHandle::Ready(self.0.clone()))
70 }
71}
72
/// Stdin built from an arbitrary `AsyncRead` never claims to be a terminal.
impl IsTerminal for AsyncStdinStream {
    /// Always returns `false`.
    fn is_terminal(&self) -> bool {
        false
    }
}
78
79#[async_trait::async_trait]
80impl InputStream for AsyncStdinStream {
81 fn read(&mut self, size: usize) -> Result<bytes::Bytes, p2::StreamError> {
82 match self.0.try_lock() {
83 Ok(mut stream) => stream.read(size),
84 Err(_) => Err(p2::StreamError::trap("concurrent reads are not supported")),
85 }
86 }
87 fn skip(&mut self, size: usize) -> Result<usize, p2::StreamError> {
88 match self.0.try_lock() {
89 Ok(mut stream) => stream.skip(size),
90 Err(_) => Err(p2::StreamError::trap("concurrent skips are not supported")),
91 }
92 }
93 async fn cancel(&mut self) {
94 if let Some(mutex) = Arc::get_mut(&mut self.0) {
96 match mutex.try_lock() {
97 Ok(mut stream) => stream.cancel().await,
98 Err(_) => {}
99 }
100 }
101 }
102}
103
104#[async_trait::async_trait]
105impl p2::Pollable for AsyncStdinStream {
106 async fn ready(&mut self) {
107 self.0.lock().await.ready().await
108 }
109}
110
111impl AsyncRead for StdioHandle<p2::pipe::AsyncReadStream> {
112 fn poll_read(
113 mut self: Pin<&mut Self>,
114 cx: &mut Context<'_>,
115 buf: &mut io::ReadBuf<'_>,
116 ) -> Poll<io::Result<()>> {
117 match ready!(self.as_mut().poll(cx, |g| g.read(buf.remaining()))) {
118 Some(Ok(bytes)) => {
119 buf.put_slice(&bytes);
120 Poll::Ready(Ok(()))
121 }
122 Some(Err(e)) => Poll::Ready(Err(e)),
123 None => Poll::Ready(Ok(())),
126 }
127 }
128}
129
130pub struct AsyncStdoutStream(Arc<Mutex<p2::pipe::AsyncWriteStream>>);
138
139impl AsyncStdoutStream {
140 pub fn new(budget: usize, s: impl AsyncWrite + Send + Sync + 'static) -> Self {
141 Self(Arc::new(Mutex::new(p2::pipe::AsyncWriteStream::new(
142 budget, s,
143 ))))
144 }
145}
146
147impl StdoutStream for AsyncStdoutStream {
148 fn p2_stream(&self) -> Box<dyn OutputStream> {
149 Box::new(Self(self.0.clone()))
150 }
151 fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
152 Box::new(StdioHandle::Ready(self.0.clone()))
153 }
154}
155
/// Output built from an arbitrary `AsyncWrite` never claims to be a terminal.
impl IsTerminal for AsyncStdoutStream {
    /// Always returns `false`.
    fn is_terminal(&self) -> bool {
        false
    }
}
161
162#[async_trait::async_trait]
178impl OutputStream for AsyncStdoutStream {
179 fn check_write(&mut self) -> Result<usize, p2::StreamError> {
180 match self.0.try_lock() {
181 Ok(mut stream) => stream.check_write(),
182 Err(_) => Err(p2::StreamError::trap("concurrent writes are not supported")),
183 }
184 }
185 fn write(&mut self, bytes: Bytes) -> Result<(), p2::StreamError> {
186 match self.0.try_lock() {
187 Ok(mut stream) => stream.write(bytes),
188 Err(_) => Err(p2::StreamError::trap("concurrent writes not supported yet")),
189 }
190 }
191 fn flush(&mut self) -> Result<(), p2::StreamError> {
192 match self.0.try_lock() {
193 Ok(mut stream) => stream.flush(),
194 Err(_) => Err(p2::StreamError::trap(
195 "concurrent flushes not supported yet",
196 )),
197 }
198 }
199 async fn cancel(&mut self) {
200 if let Some(mutex) = Arc::get_mut(&mut self.0) {
202 match mutex.try_lock() {
203 Ok(mut stream) => stream.cancel().await,
204 Err(_) => {}
205 }
206 }
207 }
208}
209
210#[async_trait::async_trait]
211impl p2::Pollable for AsyncStdoutStream {
212 async fn ready(&mut self) {
213 self.0.lock().await.ready().await
214 }
215}
216
217impl AsyncWrite for StdioHandle<p2::pipe::AsyncWriteStream> {
218 fn poll_write(
219 self: Pin<&mut Self>,
220 cx: &mut Context<'_>,
221 buf: &[u8],
222 ) -> Poll<io::Result<usize>> {
223 match ready!(self.poll(cx, |i| i.write(Bytes::copy_from_slice(buf)))) {
224 Some(Ok(())) => Poll::Ready(Ok(buf.len())),
225 Some(Err(e)) => Poll::Ready(Err(e)),
226 None => Poll::Ready(Ok(0)),
227 }
228 }
229 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
230 match ready!(self.poll(cx, |i| i.flush())) {
231 Some(result) => Poll::Ready(result),
232 None => Poll::Ready(Ok(())),
233 }
234 }
235 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
236 Poll::Ready(Ok(()))
237 }
238}
239
240enum StdioHandle<S> {
251 Ready(Arc<Mutex<S>>),
252 Locking(Box<dyn Future<Output = OwnedMutexGuard<S>> + Send + Sync>),
253 Locked(OwnedMutexGuard<S>),
254 Closed,
255}
256
257impl<S> StdioHandle<S>
258where
259 S: SharedHandleReady,
260{
261 fn poll<T>(
262 mut self: Pin<&mut Self>,
263 cx: &mut Context<'_>,
264 op: impl FnOnce(&mut S) -> p2::StreamResult<T>,
265 ) -> Poll<Option<io::Result<T>>> {
266 if let StdioHandle::Ready(lock) = &*self {
269 self.set(StdioHandle::Locking(Box::new(lock.clone().lock_owned())));
270 }
271
272 if let Some(lock) = self.as_mut().as_locking() {
275 let guard = ready!(lock.poll(cx));
276 self.set(StdioHandle::Locked(guard));
277 }
278
279 let mut guard = match self.as_mut().take_guard() {
280 Some(guard) => guard,
281 None => return Poll::Ready(None),
284 };
285
286 match guard.poll_ready(cx) {
289 Poll::Ready(()) => {}
290
291 Poll::Pending => {
294 self.set(StdioHandle::Locked(guard));
295 return Poll::Pending;
296 }
297 }
298
299 match op(&mut guard) {
301 Ok(result) => {
304 self.set(StdioHandle::Ready(OwnedMutexGuard::mutex(&guard).clone()));
305 Poll::Ready(Some(Ok(result)))
306 }
307
308 Err(p2::StreamError::Closed) => Poll::Ready(None),
311
312 Err(p2::StreamError::LastOperationFailed(e)) => {
317 Poll::Ready(Some(Err(e.downcast().unwrap())))
318 }
319
320 Err(p2::StreamError::Trap(_)) => unreachable!(),
322 }
323 }
324
325 fn as_locking(
326 self: Pin<&mut Self>,
327 ) -> Option<Pin<&mut dyn Future<Output = OwnedMutexGuard<S>>>> {
328 unsafe {
331 match self.get_unchecked_mut() {
332 StdioHandle::Locking(future) => Some(Pin::new_unchecked(&mut **future)),
333 _ => None,
334 }
335 }
336 }
337
338 fn take_guard(self: Pin<&mut Self>) -> Option<OwnedMutexGuard<S>> {
339 if !matches!(*self, StdioHandle::Locked(_)) {
340 return None;
341 }
342 unsafe {
345 match mem::replace(self.get_unchecked_mut(), StdioHandle::Closed) {
346 StdioHandle::Locked(guard) => Some(guard),
347 _ => unreachable!(),
348 }
349 }
350 }
351}