wasmtime_wasi/cli/
worker_thread_stdin.rs

1//! Handling for standard in using a worker task.
2//!
3//! Standard input is a global singleton resource for the entire program which
4//! needs special care. Currently this implementation adheres to a few
5//! constraints which make this nontrivial to implement.
6//!
7//! * Any number of guest wasm programs can read stdin. While this doesn't make
8//!   a ton of sense semantically they shouldn't block forever. Instead it's a
9//!   race to see who actually reads which parts of stdin.
10//!
11//! * Data from stdin isn't actually read unless requested. This is done to try
12//!   to be a good neighbor to others running in the process. Under the
13//!   assumption that most programs have one "thing" which reads stdin the
14//!   actual consumption of bytes is delayed until the wasm guest is dynamically
15//!   chosen to be that "thing". Before that data from stdin is not consumed to
16//!   avoid taking it from other components in the process.
17//!
18//! * Tokio's documentation indicates that "interactive stdin" is best done with
19//!   a helper thread to avoid blocking shutdown of the event loop. That's
20//!   respected here where all stdin reading happens on a blocking helper thread
21//!   that, at this time, is never shut down.
22//!
23//! This module is one that's likely to change over time though as new systems
24//! are encountered along with preexisting bugs.
25
26use crate::cli::{IsTerminal, StdinStream};
27use bytes::{Bytes, BytesMut};
28use std::io::Read;
29use std::mem;
30use std::pin::Pin;
31use std::sync::{Condvar, Mutex, OnceLock};
32use std::task::{Context, Poll};
33use tokio::io::{self, AsyncRead, ReadBuf};
34use tokio::sync::Notify;
35use tokio::sync::futures::Notified;
36use wasmtime_wasi_io::{
37    poll::Pollable,
38    streams::{InputStream, StreamError},
39};
40
41// Implementation for tokio::io::Stdin
42impl IsTerminal for tokio::io::Stdin {
43    fn is_terminal(&self) -> bool {
44        std::io::stdin().is_terminal()
45    }
46}
47impl StdinStream for tokio::io::Stdin {
48    fn p2_stream(&self) -> Box<dyn InputStream> {
49        Box::new(WasiStdin)
50    }
51    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
52        Box::new(WasiStdinAsyncRead::Ready)
53    }
54}
55
56// Implementation for std::io::Stdin
57impl IsTerminal for std::io::Stdin {
58    fn is_terminal(&self) -> bool {
59        std::io::IsTerminal::is_terminal(self)
60    }
61}
62impl StdinStream for std::io::Stdin {
63    fn p2_stream(&self) -> Box<dyn InputStream> {
64        Box::new(WasiStdin)
65    }
66    fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
67        Box::new(WasiStdinAsyncRead::Ready)
68    }
69}
70
/// Process-wide stdin state shared between the worker thread and all readers.
#[derive(Default)]
struct GlobalStdin {
    /// Current phase of the read protocol, guarded by a synchronous mutex.
    state: Mutex<StdinState>,
    /// Signalled by readers to wake the worker thread once a read is wanted.
    read_requested: Condvar,
    /// Signalled by the worker thread after each blocking read finishes.
    read_completed: Notify,
}
77
/// Phases of the shared stdin read protocol.
#[derive(Default, Debug)]
enum StdinState {
    /// No reader has asked for data; the worker thread stays parked.
    #[default]
    ReadNotRequested,
    /// A reader has asked for data and the result has not been published yet.
    ReadRequested,
    /// Bytes produced by the worker thread that have not been consumed yet.
    Data(BytesMut),
    /// The worker's read failed; the error is handed to the next reader.
    Error(std::io::Error),
    /// Stdin reached end-of-file, or its error has already been reported.
    Closed,
}
87
88impl GlobalStdin {
89    fn get() -> &'static GlobalStdin {
90        static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
91        STDIN.get_or_init(|| create())
92    }
93}
94
95fn create() -> GlobalStdin {
96    std::thread::spawn(|| {
97        let state = GlobalStdin::get();
98        loop {
99            // Wait for a read to be requested, but don't hold the lock across
100            // the blocking read.
101            let mut lock = state.state.lock().unwrap();
102            lock = state
103                .read_requested
104                .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
105                .unwrap();
106            drop(lock);
107
108            let mut bytes = BytesMut::zeroed(1024);
109            let (new_state, done) = match std::io::stdin().read(&mut bytes) {
110                Ok(0) => (StdinState::Closed, true),
111                Ok(nbytes) => {
112                    bytes.truncate(nbytes);
113                    (StdinState::Data(bytes), false)
114                }
115                Err(e) => (StdinState::Error(e), true),
116            };
117
118            // After the blocking read completes the state should not have been
119            // tampered with.
120            debug_assert!(matches!(
121                *state.state.lock().unwrap(),
122                StdinState::ReadRequested
123            ));
124            *state.state.lock().unwrap() = new_state;
125            state.read_completed.notify_waiters();
126            if done {
127                break;
128            }
129        }
130    });
131
132    GlobalStdin::default()
133}
134
/// Zero-sized handle to the process-wide stdin; all state lives in `GlobalStdin`.
struct WasiStdin;
136
137#[async_trait::async_trait]
138impl InputStream for WasiStdin {
139    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
140        let g = GlobalStdin::get();
141        let mut locked = g.state.lock().unwrap();
142        match mem::replace(&mut *locked, StdinState::ReadRequested) {
143            StdinState::ReadNotRequested => {
144                g.read_requested.notify_one();
145                Ok(Bytes::new())
146            }
147            StdinState::ReadRequested => Ok(Bytes::new()),
148            StdinState::Data(mut data) => {
149                let size = data.len().min(size);
150                let bytes = data.split_to(size);
151                *locked = if data.is_empty() {
152                    StdinState::ReadNotRequested
153                } else {
154                    StdinState::Data(data)
155                };
156                Ok(bytes.freeze())
157            }
158            StdinState::Error(e) => {
159                *locked = StdinState::Closed;
160                Err(StreamError::LastOperationFailed(e.into()))
161            }
162            StdinState::Closed => {
163                *locked = StdinState::Closed;
164                Err(StreamError::Closed)
165            }
166        }
167    }
168}
169
170#[async_trait::async_trait]
171impl Pollable for WasiStdin {
172    async fn ready(&mut self) {
173        let g = GlobalStdin::get();
174
175        // Scope the synchronous `state.lock()` to this block which does not
176        // `.await` inside of it.
177        let notified = {
178            let mut locked = g.state.lock().unwrap();
179            match *locked {
180                // If a read isn't requested yet
181                StdinState::ReadNotRequested => {
182                    g.read_requested.notify_one();
183                    *locked = StdinState::ReadRequested;
184                    g.read_completed.notified()
185                }
186                StdinState::ReadRequested => g.read_completed.notified(),
187                StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
188            }
189        };
190
191        notified.await;
192    }
193}
194
/// `AsyncRead` adapter over the process-wide stdin state.
enum WasiStdinAsyncRead {
    /// Not waiting on anything; the next poll inspects the global state.
    Ready,
    /// Waiting on a `read_completed` notification from the global state.
    Waiting(Notified<'static>),
}
199
200impl AsyncRead for WasiStdinAsyncRead {
201    fn poll_read(
202        mut self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204        buf: &mut ReadBuf<'_>,
205    ) -> Poll<io::Result<()>> {
206        let g = GlobalStdin::get();
207
208        // Perform everything below in a `loop` to handle the case that a read
209        // was stolen by another thread, for example, or perhaps a spurious
210        // notification to `Notified`.
211        loop {
212            // If we were previously blocked on reading a "ready" notification,
213            // wait for that notification to complete.
214            if let Some(notified) = self.as_mut().notified_future() {
215                match notified.poll(cx) {
216                    Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
217                    Poll::Pending => break Poll::Pending,
218                }
219            }
220
221            assert!(matches!(*self, WasiStdinAsyncRead::Ready));
222
223            // Once we're in the "ready" state then take a look at the global
224            // state of stdin.
225            let mut locked = g.state.lock().unwrap();
226            match mem::replace(&mut *locked, StdinState::ReadRequested) {
227                // If data is available then drain what we can into `buf`.
228                StdinState::Data(mut data) => {
229                    let size = data.len().min(buf.remaining());
230                    let bytes = data.split_to(size);
231                    *locked = if data.is_empty() {
232                        StdinState::ReadNotRequested
233                    } else {
234                        StdinState::Data(data)
235                    };
236                    buf.put_slice(&bytes);
237                    break Poll::Ready(Ok(()));
238                }
239
240                // If stdin failed to be read then we fail with that error and
241                // transition to "closed"
242                StdinState::Error(e) => {
243                    *locked = StdinState::Closed;
244                    break Poll::Ready(Err(e));
245                }
246
247                // If stdin is closed, keep it closed.
248                StdinState::Closed => {
249                    *locked = StdinState::Closed;
250                    break Poll::Ready(Ok(()));
251                }
252
253                // For these states we indicate that a read is requested, if it
254                // wasn't previously requested, and then we transition to
255                // `Waiting` below by falling through outside this `match`.
256                StdinState::ReadNotRequested => {
257                    g.read_requested.notify_one();
258                }
259                StdinState::ReadRequested => {}
260            }
261
262            self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
263
264            // Intentionally drop the lock after the `notified()` future
265            // creation just above as to work correctly this needs to happen
266            // within the lock.
267            drop(locked);
268        }
269    }
270}
271
272impl WasiStdinAsyncRead {
273    fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
274        // SAFETY: this is a pin-projection from `self` to the field `Notified`
275        // internally. Given that `self` is pinned it should be safe to acquire
276        // a pinned version of the internal field.
277        unsafe {
278            match self.get_unchecked_mut() {
279                WasiStdinAsyncRead::Ready => None,
280                WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
281            }
282        }
283    }
284}