Skip to main content

wasmtime_wasi/cli/
worker_thread_stdin.rs

1//! Handling for standard in using a worker task.
2//!
3//! Standard input is a global singleton resource for the entire program which
4//! needs special care. Currently this implementation adheres to a few
5//! constraints which make this nontrivial to implement.
6//!
7//! * Any number of guest wasm programs can read stdin. While this doesn't make
8//!   a ton of sense semantically they shouldn't block forever. Instead it's a
9//!   race to see who actually reads which parts of stdin.
10//!
11//! * Data from stdin isn't actually read unless requested. This is done to try
12//!   to be a good neighbor to others running in the process. Under the
13//!   assumption that most programs have one "thing" which reads stdin the
14//!   actual consumption of bytes is delayed until the wasm guest is dynamically
15//!   chosen to be that "thing". Before that data from stdin is not consumed to
16//!   avoid taking it from other components in the process.
17//!
18//! * Tokio's documentation indicates that "interactive stdin" is best done with
19//!   a helper thread to avoid blocking shutdown of the event loop. That's
20//!   respected here where all stdin reading happens on a blocking helper thread
21//!   that, at this time, is never shut down.
22//!
23//! This module is one that's likely to change over time though as new systems
24//! are encountered along with preexisting bugs.
25
26use crate::cli::{IsTerminal, StdinStream};
27use bytes::{Bytes, BytesMut};
28use std::io::Read;
29use std::mem;
30use std::pin::Pin;
31use std::sync::{Condvar, Mutex, OnceLock};
32use std::task::{Context, Poll};
33use tokio::io::{self, AsyncRead, ReadBuf};
34use tokio::sync::Notify;
35use tokio::sync::futures::Notified;
36use wasmtime_wasi_io::{
37    poll::Pollable,
38    streams::{InputStream, StreamError},
39};
40
41use crate::MAX_READ_SIZE_ALLOC;
42
43// Implementation for tokio::io::Stdin
44impl IsTerminal for tokio::io::Stdin {
45    fn is_terminal(&self) -> bool {
46        std::io::stdin().is_terminal()
47    }
48}
49impl StdinStream for tokio::io::Stdin {
50    fn p2_stream(&self) -> Box<dyn InputStream> {
51        Box::new(WasiStdin)
52    }
53    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
54        Box::new(WasiStdinAsyncRead::Ready)
55    }
56}
57
58// Implementation for std::io::Stdin
59impl IsTerminal for std::io::Stdin {
60    fn is_terminal(&self) -> bool {
61        std::io::IsTerminal::is_terminal(self)
62    }
63}
64impl StdinStream for std::io::Stdin {
65    fn p2_stream(&self) -> Box<dyn InputStream> {
66        Box::new(WasiStdin)
67    }
68    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
69        Box::new(WasiStdinAsyncRead::Ready)
70    }
71}
72
73#[derive(Default)]
74struct GlobalStdin {
75    state: Mutex<StdinState>,
76    read_requested: Condvar,
77    read_completed: Notify,
78}
79
80#[derive(Default, Debug)]
81enum StdinState {
82    #[default]
83    ReadNotRequested,
84    ReadRequested(usize),
85    Data(BytesMut),
86    Error(std::io::Error),
87    Closed,
88}
89
90impl GlobalStdin {
91    fn get() -> &'static GlobalStdin {
92        static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
93        STDIN.get_or_init(|| create())
94    }
95}
96
97fn create() -> GlobalStdin {
98    std::thread::spawn(|| {
99        let state = GlobalStdin::get();
100        loop {
101            // Wait for a read to be requested, but don't hold the lock across
102            // the blocking read.
103            let mut lock = state.state.lock().unwrap();
104            lock = state
105                .read_requested
106                .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested(_)))
107                .unwrap();
108
109            // Extract the size hint from the request and cap it to `MAX_READ_SIZE_ALLOC`
110            // to avoid guest-controlled unbounded allocation.
111            // The `.max(1)` ensures a zero-length read is never misinterpreted as EOF.
112            let size_hint = match *lock {
113                StdinState::ReadRequested(size) => size.min(MAX_READ_SIZE_ALLOC).max(1),
114                _ => unreachable!(),
115            };
116            drop(lock);
117
118            let mut bytes = BytesMut::zeroed(size_hint);
119            let (new_state, done) = match std::io::stdin().read(&mut bytes) {
120                Ok(0) => (StdinState::Closed, true),
121                Ok(nbytes) => {
122                    bytes.truncate(nbytes);
123                    (StdinState::Data(bytes), false)
124                }
125                Err(e) => (StdinState::Error(e), true),
126            };
127
128            // After the blocking read completes the state should not have been
129            // tampered with.
130            debug_assert!(matches!(
131                *state.state.lock().unwrap(),
132                StdinState::ReadRequested(_)
133            ));
134            let mut lock = state.state.lock().unwrap();
135            *lock = new_state;
136            state.read_completed.notify_waiters();
137            if done {
138                break;
139            }
140        }
141    });
142
143    GlobalStdin::default()
144}
145
/// Handle to the shared stdin worker. It carries no data of its own; all
/// state lives in the value returned by `GlobalStdin::get()`.
struct WasiStdin;
147
148#[async_trait::async_trait]
149impl InputStream for WasiStdin {
150    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
151        if size == 0 {
152            return Ok(Bytes::new());
153        }
154        let g = GlobalStdin::get();
155        let mut locked = g.state.lock().unwrap();
156        match mem::replace(&mut *locked, StdinState::ReadRequested(size)) {
157            StdinState::ReadNotRequested => {
158                g.read_requested.notify_one();
159                Ok(Bytes::new())
160            }
161            StdinState::ReadRequested(prev_size) => {
162                // Preserve the larger of the two requested sizes
163                // so the worker thread allocates an adequate buffer.
164                *locked = StdinState::ReadRequested(prev_size.max(size));
165                Ok(Bytes::new())
166            }
167            StdinState::Data(mut data) => {
168                let size = data.len().min(size);
169                let bytes = data.split_to(size);
170                *locked = if data.is_empty() {
171                    StdinState::ReadNotRequested
172                } else {
173                    StdinState::Data(data)
174                };
175                Ok(bytes.freeze())
176            }
177            StdinState::Error(e) => {
178                *locked = StdinState::Closed;
179                Err(StreamError::LastOperationFailed(e.into()))
180            }
181            StdinState::Closed => {
182                *locked = StdinState::Closed;
183                Err(StreamError::Closed)
184            }
185        }
186    }
187}
188
189#[async_trait::async_trait]
190impl Pollable for WasiStdin {
191    async fn ready(&mut self) {
192        let g = GlobalStdin::get();
193
194        // Scope the synchronous `state.lock()` to this block which does not
195        // `.await` inside of it.
196        let notified = {
197            let mut locked = g.state.lock().unwrap();
198            match *locked {
199                // If a read isn't requested yet, use `MAX_READ_SIZE_ALLOC`
200                // as the buffer size since `ready()` doesn't know what size
201                // will be requested by the subsequent `read()` call.
202                StdinState::ReadNotRequested => {
203                    g.read_requested.notify_one();
204                    *locked = StdinState::ReadRequested(MAX_READ_SIZE_ALLOC);
205                    g.read_completed.notified()
206                }
207                StdinState::ReadRequested(_) => g.read_completed.notified(),
208                StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
209            }
210        };
211
212        notified.await;
213    }
214}
215
216enum WasiStdinAsyncRead {
217    Ready,
218    Waiting(Notified<'static>),
219}
220
221impl AsyncRead for WasiStdinAsyncRead {
222    fn poll_read(
223        mut self: Pin<&mut Self>,
224        cx: &mut Context<'_>,
225        buf: &mut ReadBuf<'_>,
226    ) -> Poll<io::Result<()>> {
227        let g = GlobalStdin::get();
228
229        // Everything below is executed under the global stdin lock. It's not
230        // going to block below so that's semantically fine. Optimization-wise
231        // it's probably possible to move this within the loop around just a
232        // small part of reading/writing the state, but that was done
233        // historically and it resulted in lost wakeups with `Notify`, so this
234        // is conservatively hoisted up here.
235        let mut locked = g.state.lock().unwrap();
236
237        // Perform everything below in a `loop` to handle the case that a read
238        // was stolen by another thread, for example, or perhaps a spurious
239        // notification to `Notified`.
240        loop {
241            // If we were previously blocked on reading a "ready" notification,
242            // wait for that notification to complete.
243            if let Some(notified) = self.as_mut().notified_future() {
244                match notified.poll(cx) {
245                    Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
246                    Poll::Pending => break Poll::Pending,
247                }
248            }
249
250            assert!(matches!(*self, WasiStdinAsyncRead::Ready));
251
252            // Once we're in the "ready" state then take a look at the global
253            // state of stdin.
254            match mem::replace(&mut *locked, StdinState::ReadRequested(buf.remaining())) {
255                // If data is available then drain what we can into `buf`.
256                StdinState::Data(mut data) => {
257                    let size = data.len().min(buf.remaining());
258                    let bytes = data.split_to(size);
259                    *locked = if data.is_empty() {
260                        StdinState::ReadNotRequested
261                    } else {
262                        StdinState::Data(data)
263                    };
264                    buf.put_slice(&bytes);
265                    break Poll::Ready(Ok(()));
266                }
267
268                // If stdin failed to be read then we fail with that error and
269                // transition to "closed"
270                StdinState::Error(e) => {
271                    *locked = StdinState::Closed;
272                    break Poll::Ready(Err(e));
273                }
274
275                // If stdin is closed, keep it closed.
276                StdinState::Closed => {
277                    *locked = StdinState::Closed;
278                    break Poll::Ready(Ok(()));
279                }
280
281                // For these states we indicate that a read is requested, if it
282                // wasn't previously requested, and then we transition to
283                // `Waiting` below by falling through outside this `match`.
284                StdinState::ReadNotRequested => {
285                    g.read_requested.notify_one();
286                }
287                StdinState::ReadRequested(prev_size) => {
288                    // Preserve the larger of the previous and current size hint
289                    *locked = StdinState::ReadRequested(prev_size.max(buf.remaining()));
290                }
291            }
292
293            self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
294        }
295    }
296}
297
298impl WasiStdinAsyncRead {
299    fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
300        // SAFETY: this is a pin-projection from `self` to the field `Notified`
301        // internally. Given that `self` is pinned it should be safe to acquire
302        // a pinned version of the internal field.
303        unsafe {
304            match self.get_unchecked_mut() {
305                WasiStdinAsyncRead::Ready => None,
306                WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
307            }
308        }
309    }
310}