wasmtime_wasi/p2/stdio/
worker_thread_stdin.rs

1//! Handling for standard in using a worker task.
2//!
3//! Standard input is a global singleton resource for the entire program which
4//! needs special care. Currently this implementation adheres to a few
5//! constraints which make this nontrivial to implement.
6//!
7//! * Any number of guest wasm programs can read stdin. While this doesn't make
8//!   a ton of sense semantically they shouldn't block forever. Instead it's a
9//!   race to see who actually reads which parts of stdin.
10//!
11//! * Data from stdin isn't actually read unless requested. This is done to try
12//!   to be a good neighbor to others running in the process. Under the
13//!   assumption that most programs have one "thing" which reads stdin the
14//!   actual consumption of bytes is delayed until the wasm guest is dynamically
15//!   chosen to be that "thing". Before that data from stdin is not consumed to
16//!   avoid taking it from other components in the process.
17//!
18//! * Tokio's documentation indicates that "interactive stdin" is best done with
19//!   a helper thread to avoid blocking shutdown of the event loop. That's
20//!   respected here where all stdin reading happens on a blocking helper thread
21//!   that, at this time, is never shut down.
22//!
23//! This module is one that's likely to change over time though as new systems
24//! are encountered along with preexisting bugs.
25
26use crate::cli::IsTerminal;
27use crate::p2::stdio::StdinStream;
28use bytes::{Bytes, BytesMut};
29use std::io::Read;
30use std::mem;
31use std::sync::{Condvar, Mutex, OnceLock};
32use tokio::sync::Notify;
33use wasmtime_wasi_io::{
34    poll::Pollable,
35    streams::{InputStream, StreamError},
36};
37
38#[derive(Default)]
39struct GlobalStdin {
40    state: Mutex<StdinState>,
41    read_requested: Condvar,
42    read_completed: Notify,
43}
44
45#[derive(Default, Debug)]
46enum StdinState {
47    #[default]
48    ReadNotRequested,
49    ReadRequested,
50    Data(BytesMut),
51    Error(std::io::Error),
52    Closed,
53}
54
55impl GlobalStdin {
56    fn get() -> &'static GlobalStdin {
57        static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
58        STDIN.get_or_init(|| create())
59    }
60}
61
62fn create() -> GlobalStdin {
63    std::thread::spawn(|| {
64        let state = GlobalStdin::get();
65        loop {
66            // Wait for a read to be requested, but don't hold the lock across
67            // the blocking read.
68            let mut lock = state.state.lock().unwrap();
69            lock = state
70                .read_requested
71                .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
72                .unwrap();
73            drop(lock);
74
75            let mut bytes = BytesMut::zeroed(1024);
76            let (new_state, done) = match std::io::stdin().read(&mut bytes) {
77                Ok(0) => (StdinState::Closed, true),
78                Ok(nbytes) => {
79                    bytes.truncate(nbytes);
80                    (StdinState::Data(bytes), false)
81                }
82                Err(e) => (StdinState::Error(e), true),
83            };
84
85            // After the blocking read completes the state should not have been
86            // tampered with.
87            debug_assert!(matches!(
88                *state.state.lock().unwrap(),
89                StdinState::ReadRequested
90            ));
91            *state.state.lock().unwrap() = new_state;
92            state.read_completed.notify_waiters();
93            if done {
94                break;
95            }
96        }
97    });
98
99    GlobalStdin::default()
100}
101
/// Zero-sized handle to the host's standard input.
///
/// All state lives in a process-wide singleton, so any number of handles may
/// exist at once. The only public interface is the [`InputStream`] impl.
#[derive(Clone, Debug)]
pub struct Stdin;
105
106/// Returns a stream that represents the host's standard input.
107///
108/// Suitable for passing to
109/// [`WasiCtxBuilder::stdin`](crate::p2::WasiCtxBuilder::stdin).
110pub fn stdin() -> Stdin {
111    Stdin
112}
113
114impl StdinStream for Stdin {
115    fn stream(&self) -> Box<dyn InputStream> {
116        Box::new(Stdin)
117    }
118}
119
120impl IsTerminal for Stdin {
121    fn is_terminal(&self) -> bool {
122        std::io::stdin().is_terminal()
123    }
124}
125
126#[async_trait::async_trait]
127impl InputStream for Stdin {
128    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
129        let g = GlobalStdin::get();
130        let mut locked = g.state.lock().unwrap();
131        match mem::replace(&mut *locked, StdinState::ReadRequested) {
132            StdinState::ReadNotRequested => {
133                g.read_requested.notify_one();
134                Ok(Bytes::new())
135            }
136            StdinState::ReadRequested => Ok(Bytes::new()),
137            StdinState::Data(mut data) => {
138                let size = data.len().min(size);
139                let bytes = data.split_to(size);
140                *locked = if data.is_empty() {
141                    StdinState::ReadNotRequested
142                } else {
143                    StdinState::Data(data)
144                };
145                Ok(bytes.freeze())
146            }
147            StdinState::Error(e) => {
148                *locked = StdinState::Closed;
149                Err(StreamError::LastOperationFailed(e.into()))
150            }
151            StdinState::Closed => {
152                *locked = StdinState::Closed;
153                Err(StreamError::Closed)
154            }
155        }
156    }
157}
158
159#[async_trait::async_trait]
160impl Pollable for Stdin {
161    async fn ready(&mut self) {
162        let g = GlobalStdin::get();
163
164        // Scope the synchronous `state.lock()` to this block which does not
165        // `.await` inside of it.
166        let notified = {
167            let mut locked = g.state.lock().unwrap();
168            match *locked {
169                // If a read isn't requested yet
170                StdinState::ReadNotRequested => {
171                    g.read_requested.notify_one();
172                    *locked = StdinState::ReadRequested;
173                    g.read_completed.notified()
174                }
175                StdinState::ReadRequested => g.read_completed.notified(),
176                StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
177            }
178        };
179
180        notified.await;
181    }
182}