Skip to main content

wasmtime_wasi/cli/
worker_thread_stdin.rs

1//! Handling for standard in using a worker task.
2//!
3//! Standard input is a global singleton resource for the entire program which
4//! needs special care. Currently this implementation adheres to a few
5//! constraints which make this nontrivial to implement.
6//!
7//! * Any number of guest wasm programs can read stdin. While this doesn't make
8//!   a ton of sense semantically they shouldn't block forever. Instead it's a
9//!   race to see who actually reads which parts of stdin.
10//!
11//! * Data from stdin isn't actually read unless requested. This is done to try
12//!   to be a good neighbor to others running in the process. Under the
13//!   assumption that most programs have one "thing" which reads stdin the
14//!   actual consumption of bytes is delayed until the wasm guest is dynamically
15//!   chosen to be that "thing". Before that data from stdin is not consumed to
16//!   avoid taking it from other components in the process.
17//!
18//! * Tokio's documentation indicates that "interactive stdin" is best done with
19//!   a helper thread to avoid blocking shutdown of the event loop. That's
20//!   respected here where all stdin reading happens on a blocking helper thread
21//!   that, at this time, is never shut down.
22//!
23//! This module is one that's likely to change over time though as new systems
24//! are encountered along with preexisting bugs.
25
26use crate::cli::{IsTerminal, StdinStream};
27use bytes::{Bytes, BytesMut};
28use std::io::Read;
29use std::mem;
30use std::pin::Pin;
31use std::sync::{Condvar, Mutex, OnceLock};
32use std::task::{Context, Poll};
33use tokio::io::{self, AsyncRead, ReadBuf};
34use tokio::sync::Notify;
35use tokio::sync::futures::Notified;
36use wasmtime_wasi_io::{
37    poll::Pollable,
38    streams::{InputStream, StreamError},
39};
40
41// Implementation for tokio::io::Stdin
42impl IsTerminal for tokio::io::Stdin {
43    fn is_terminal(&self) -> bool {
44        std::io::stdin().is_terminal()
45    }
46}
47impl StdinStream for tokio::io::Stdin {
48    fn p2_stream(&self) -> Box<dyn InputStream> {
49        Box::new(WasiStdin)
50    }
51    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
52        Box::new(WasiStdinAsyncRead::Ready)
53    }
54}
55
56// Implementation for std::io::Stdin
57impl IsTerminal for std::io::Stdin {
58    fn is_terminal(&self) -> bool {
59        std::io::IsTerminal::is_terminal(self)
60    }
61}
62impl StdinStream for std::io::Stdin {
63    fn p2_stream(&self) -> Box<dyn InputStream> {
64        Box::new(WasiStdin)
65    }
66    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
67        Box::new(WasiStdinAsyncRead::Ready)
68    }
69}
70
71#[derive(Default)]
72struct GlobalStdin {
73    state: Mutex<StdinState>,
74    read_requested: Condvar,
75    read_completed: Notify,
76}
77
78#[derive(Default, Debug)]
79enum StdinState {
80    #[default]
81    ReadNotRequested,
82    ReadRequested,
83    Data(BytesMut),
84    Error(std::io::Error),
85    Closed,
86}
87
88impl GlobalStdin {
89    fn get() -> &'static GlobalStdin {
90        static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
91        STDIN.get_or_init(|| create())
92    }
93}
94
95fn create() -> GlobalStdin {
96    std::thread::spawn(|| {
97        let state = GlobalStdin::get();
98        loop {
99            // Wait for a read to be requested, but don't hold the lock across
100            // the blocking read.
101            let mut lock = state.state.lock().unwrap();
102            lock = state
103                .read_requested
104                .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
105                .unwrap();
106            drop(lock);
107
108            let mut bytes = BytesMut::zeroed(1024);
109            let (new_state, done) = match std::io::stdin().read(&mut bytes) {
110                Ok(0) => (StdinState::Closed, true),
111                Ok(nbytes) => {
112                    bytes.truncate(nbytes);
113                    (StdinState::Data(bytes), false)
114                }
115                Err(e) => (StdinState::Error(e), true),
116            };
117
118            // After the blocking read completes the state should not have been
119            // tampered with.
120            debug_assert!(matches!(
121                *state.state.lock().unwrap(),
122                StdinState::ReadRequested
123            ));
124            let mut lock = state.state.lock().unwrap();
125            *lock = new_state;
126            state.read_completed.notify_waiters();
127            if done {
128                break;
129            }
130        }
131    });
132
133    GlobalStdin::default()
134}
135
136struct WasiStdin;
137
138#[async_trait::async_trait]
139impl InputStream for WasiStdin {
140    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
141        let g = GlobalStdin::get();
142        let mut locked = g.state.lock().unwrap();
143        match mem::replace(&mut *locked, StdinState::ReadRequested) {
144            StdinState::ReadNotRequested => {
145                g.read_requested.notify_one();
146                Ok(Bytes::new())
147            }
148            StdinState::ReadRequested => Ok(Bytes::new()),
149            StdinState::Data(mut data) => {
150                let size = data.len().min(size);
151                let bytes = data.split_to(size);
152                *locked = if data.is_empty() {
153                    StdinState::ReadNotRequested
154                } else {
155                    StdinState::Data(data)
156                };
157                Ok(bytes.freeze())
158            }
159            StdinState::Error(e) => {
160                *locked = StdinState::Closed;
161                Err(StreamError::LastOperationFailed(e.into()))
162            }
163            StdinState::Closed => {
164                *locked = StdinState::Closed;
165                Err(StreamError::Closed)
166            }
167        }
168    }
169}
170
171#[async_trait::async_trait]
172impl Pollable for WasiStdin {
173    async fn ready(&mut self) {
174        let g = GlobalStdin::get();
175
176        // Scope the synchronous `state.lock()` to this block which does not
177        // `.await` inside of it.
178        let notified = {
179            let mut locked = g.state.lock().unwrap();
180            match *locked {
181                // If a read isn't requested yet
182                StdinState::ReadNotRequested => {
183                    g.read_requested.notify_one();
184                    *locked = StdinState::ReadRequested;
185                    g.read_completed.notified()
186                }
187                StdinState::ReadRequested => g.read_completed.notified(),
188                StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
189            }
190        };
191
192        notified.await;
193    }
194}
195
196enum WasiStdinAsyncRead {
197    Ready,
198    Waiting(Notified<'static>),
199}
200
201impl AsyncRead for WasiStdinAsyncRead {
202    fn poll_read(
203        mut self: Pin<&mut Self>,
204        cx: &mut Context<'_>,
205        buf: &mut ReadBuf<'_>,
206    ) -> Poll<io::Result<()>> {
207        let g = GlobalStdin::get();
208
209        // Everything below is executed under the global stdin lock. It's not
210        // going to block below so that's semantically fine. Optimization-wise
211        // it's probably possible to move this within the loop around just a
212        // small part of reading/writing the state, but that was done
213        // historically and it resulted in lost wakeups with `Notify`, so this
214        // is conservatively hoisted up here.
215        let mut locked = g.state.lock().unwrap();
216
217        // Perform everything below in a `loop` to handle the case that a read
218        // was stolen by another thread, for example, or perhaps a spurious
219        // notification to `Notified`.
220        loop {
221            // If we were previously blocked on reading a "ready" notification,
222            // wait for that notification to complete.
223            if let Some(notified) = self.as_mut().notified_future() {
224                match notified.poll(cx) {
225                    Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
226                    Poll::Pending => break Poll::Pending,
227                }
228            }
229
230            assert!(matches!(*self, WasiStdinAsyncRead::Ready));
231
232            // Once we're in the "ready" state then take a look at the global
233            // state of stdin.
234            match mem::replace(&mut *locked, StdinState::ReadRequested) {
235                // If data is available then drain what we can into `buf`.
236                StdinState::Data(mut data) => {
237                    let size = data.len().min(buf.remaining());
238                    let bytes = data.split_to(size);
239                    *locked = if data.is_empty() {
240                        StdinState::ReadNotRequested
241                    } else {
242                        StdinState::Data(data)
243                    };
244                    buf.put_slice(&bytes);
245                    break Poll::Ready(Ok(()));
246                }
247
248                // If stdin failed to be read then we fail with that error and
249                // transition to "closed"
250                StdinState::Error(e) => {
251                    *locked = StdinState::Closed;
252                    break Poll::Ready(Err(e));
253                }
254
255                // If stdin is closed, keep it closed.
256                StdinState::Closed => {
257                    *locked = StdinState::Closed;
258                    break Poll::Ready(Ok(()));
259                }
260
261                // For these states we indicate that a read is requested, if it
262                // wasn't previously requested, and then we transition to
263                // `Waiting` below by falling through outside this `match`.
264                StdinState::ReadNotRequested => {
265                    g.read_requested.notify_one();
266                }
267                StdinState::ReadRequested => {}
268            }
269
270            self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
271        }
272    }
273}
274
275impl WasiStdinAsyncRead {
276    fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
277        // SAFETY: this is a pin-projection from `self` to the field `Notified`
278        // internally. Given that `self` is pinned it should be safe to acquire
279        // a pinned version of the internal field.
280        unsafe {
281            match self.get_unchecked_mut() {
282                WasiStdinAsyncRead::Ready => None,
283                WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
284            }
285        }
286    }
287}