wasmtime_wasi/cli/
worker_thread_stdin.rs1use crate::cli::{IsTerminal, StdinStream};
27use bytes::{Bytes, BytesMut};
28use std::io::Read;
29use std::mem;
30use std::pin::Pin;
31use std::sync::{Condvar, Mutex, OnceLock};
32use std::task::{Context, Poll};
33use tokio::io::{self, AsyncRead, ReadBuf};
34use tokio::sync::Notify;
35use tokio::sync::futures::Notified;
36use wasmtime_wasi_io::{
37 poll::Pollable,
38 streams::{InputStream, StreamError},
39};
40
41use crate::MAX_READ_SIZE_ALLOC;
42
43impl IsTerminal for tokio::io::Stdin {
45 fn is_terminal(&self) -> bool {
46 std::io::stdin().is_terminal()
47 }
48}
49impl StdinStream for tokio::io::Stdin {
50 fn p2_stream(&self) -> Box<dyn InputStream> {
51 Box::new(WasiStdin)
52 }
53 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
54 Box::new(WasiStdinAsyncRead::Ready)
55 }
56}
57
58impl IsTerminal for std::io::Stdin {
60 fn is_terminal(&self) -> bool {
61 std::io::IsTerminal::is_terminal(self)
62 }
63}
64impl StdinStream for std::io::Stdin {
65 fn p2_stream(&self) -> Box<dyn InputStream> {
66 Box::new(WasiStdin)
67 }
68 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
69 Box::new(WasiStdinAsyncRead::Ready)
70 }
71}
72
73#[derive(Default)]
74struct GlobalStdin {
75 state: Mutex<StdinState>,
76 read_requested: Condvar,
77 read_completed: Notify,
78}
79
80#[derive(Default, Debug)]
81enum StdinState {
82 #[default]
83 ReadNotRequested,
84 ReadRequested(usize),
85 Data(BytesMut),
86 Error(std::io::Error),
87 Closed,
88}
89
90impl GlobalStdin {
91 fn get() -> &'static GlobalStdin {
92 static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
93 STDIN.get_or_init(|| create())
94 }
95}
96
97fn create() -> GlobalStdin {
98 std::thread::spawn(|| {
99 let state = GlobalStdin::get();
100 loop {
101 let mut lock = state.state.lock().unwrap();
104 lock = state
105 .read_requested
106 .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested(_)))
107 .unwrap();
108
109 let size_hint = match *lock {
113 StdinState::ReadRequested(size) => size.min(MAX_READ_SIZE_ALLOC).max(1),
114 _ => unreachable!(),
115 };
116 drop(lock);
117
118 let mut bytes = BytesMut::zeroed(size_hint);
119 let (new_state, done) = match std::io::stdin().read(&mut bytes) {
120 Ok(0) => (StdinState::Closed, true),
121 Ok(nbytes) => {
122 bytes.truncate(nbytes);
123 (StdinState::Data(bytes), false)
124 }
125 Err(e) => (StdinState::Error(e), true),
126 };
127
128 debug_assert!(matches!(
131 *state.state.lock().unwrap(),
132 StdinState::ReadRequested(_)
133 ));
134 let mut lock = state.state.lock().unwrap();
135 *lock = new_state;
136 state.read_completed.notify_waiters();
137 if done {
138 break;
139 }
140 }
141 });
142
143 GlobalStdin::default()
144}
145
/// Stateless handle that implements the WASIp2 `InputStream` and `Pollable`
/// traits on top of the shared global stdin.
struct WasiStdin;
147
148#[async_trait::async_trait]
149impl InputStream for WasiStdin {
150 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
151 if size == 0 {
152 return Ok(Bytes::new());
153 }
154 let g = GlobalStdin::get();
155 let mut locked = g.state.lock().unwrap();
156 match mem::replace(&mut *locked, StdinState::ReadRequested(size)) {
157 StdinState::ReadNotRequested => {
158 g.read_requested.notify_one();
159 Ok(Bytes::new())
160 }
161 StdinState::ReadRequested(prev_size) => {
162 *locked = StdinState::ReadRequested(prev_size.max(size));
165 Ok(Bytes::new())
166 }
167 StdinState::Data(mut data) => {
168 let size = data.len().min(size);
169 let bytes = data.split_to(size);
170 *locked = if data.is_empty() {
171 StdinState::ReadNotRequested
172 } else {
173 StdinState::Data(data)
174 };
175 Ok(bytes.freeze())
176 }
177 StdinState::Error(e) => {
178 *locked = StdinState::Closed;
179 Err(StreamError::LastOperationFailed(e.into()))
180 }
181 StdinState::Closed => {
182 *locked = StdinState::Closed;
183 Err(StreamError::Closed)
184 }
185 }
186 }
187}
188
189#[async_trait::async_trait]
190impl Pollable for WasiStdin {
191 async fn ready(&mut self) {
192 let g = GlobalStdin::get();
193
194 let notified = {
197 let mut locked = g.state.lock().unwrap();
198 match *locked {
199 StdinState::ReadNotRequested => {
203 g.read_requested.notify_one();
204 *locked = StdinState::ReadRequested(MAX_READ_SIZE_ALLOC);
205 g.read_completed.notified()
206 }
207 StdinState::ReadRequested(_) => g.read_completed.notified(),
208 StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
209 }
210 };
211
212 notified.await;
213 }
214}
215
216enum WasiStdinAsyncRead {
217 Ready,
218 Waiting(Notified<'static>),
219}
220
221impl AsyncRead for WasiStdinAsyncRead {
222 fn poll_read(
223 mut self: Pin<&mut Self>,
224 cx: &mut Context<'_>,
225 buf: &mut ReadBuf<'_>,
226 ) -> Poll<io::Result<()>> {
227 let g = GlobalStdin::get();
228
229 let mut locked = g.state.lock().unwrap();
236
237 loop {
241 if let Some(notified) = self.as_mut().notified_future() {
244 match notified.poll(cx) {
245 Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
246 Poll::Pending => break Poll::Pending,
247 }
248 }
249
250 assert!(matches!(*self, WasiStdinAsyncRead::Ready));
251
252 match mem::replace(&mut *locked, StdinState::ReadRequested(buf.remaining())) {
255 StdinState::Data(mut data) => {
257 let size = data.len().min(buf.remaining());
258 let bytes = data.split_to(size);
259 *locked = if data.is_empty() {
260 StdinState::ReadNotRequested
261 } else {
262 StdinState::Data(data)
263 };
264 buf.put_slice(&bytes);
265 break Poll::Ready(Ok(()));
266 }
267
268 StdinState::Error(e) => {
271 *locked = StdinState::Closed;
272 break Poll::Ready(Err(e));
273 }
274
275 StdinState::Closed => {
277 *locked = StdinState::Closed;
278 break Poll::Ready(Ok(()));
279 }
280
281 StdinState::ReadNotRequested => {
285 g.read_requested.notify_one();
286 }
287 StdinState::ReadRequested(prev_size) => {
288 *locked = StdinState::ReadRequested(prev_size.max(buf.remaining()));
290 }
291 }
292
293 self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
294 }
295 }
296}
297
298impl WasiStdinAsyncRead {
299 fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
300 unsafe {
304 match self.get_unchecked_mut() {
305 WasiStdinAsyncRead::Ready => None,
306 WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
307 }
308 }
309 }
310}