wasmtime_wasi/cli/
worker_thread_stdin.rs1use crate::cli::{IsTerminal, StdinStream};
27use bytes::{Bytes, BytesMut};
28use std::io::Read;
29use std::mem;
30use std::pin::Pin;
31use std::sync::{Condvar, Mutex, OnceLock};
32use std::task::{Context, Poll};
33use tokio::io::{self, AsyncRead, ReadBuf};
34use tokio::sync::Notify;
35use tokio::sync::futures::Notified;
36use wasmtime_wasi_io::{
37 poll::Pollable,
38 streams::{InputStream, StreamError},
39};
40
41impl IsTerminal for tokio::io::Stdin {
43 fn is_terminal(&self) -> bool {
44 std::io::stdin().is_terminal()
45 }
46}
47impl StdinStream for tokio::io::Stdin {
48 fn p2_stream(&self) -> Box<dyn InputStream> {
49 Box::new(WasiStdin)
50 }
51 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
52 Box::new(WasiStdinAsyncRead::Ready)
53 }
54}
55
56impl IsTerminal for std::io::Stdin {
58 fn is_terminal(&self) -> bool {
59 std::io::IsTerminal::is_terminal(self)
60 }
61}
62impl StdinStream for std::io::Stdin {
63 fn p2_stream(&self) -> Box<dyn InputStream> {
64 Box::new(WasiStdin)
65 }
66 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
67 Box::new(WasiStdinAsyncRead::Ready)
68 }
69}
70
/// Process-wide stdin state.
///
/// It is shared by the worker thread spawned in `create` and by every stdin
/// stream handle (`WasiStdin`, `WasiStdinAsyncRead`), all of which reach it
/// through `GlobalStdin::get`.
#[derive(Default)]
struct GlobalStdin {
    /// Current read state. Every transition happens while holding this lock.
    state: Mutex<StdinState>,
    /// Signaled by readers after they move `state` to `ReadRequested`.
    /// The worker thread waits on it before each blocking read.
    read_requested: Condvar,
    /// Notified by the worker thread (via `notify_waiters`) after it stores
    /// the result of a read into `state`.
    read_completed: Notify,
}
77
/// State machine shared between the stdin worker thread and its readers.
#[derive(Default, Debug)]
enum StdinState {
    /// No read is outstanding and nothing is buffered.
    ///
    /// A reader that finds this state moves it to `ReadRequested` and signals
    /// `read_requested`.
    #[default]
    ReadNotRequested,
    /// A reader has asked for more input.
    ///
    /// The worker thread performs one blocking read and then replaces this
    /// state with the outcome.
    ReadRequested,
    /// Bytes read from stdin that have not been consumed yet.
    ///
    /// Readers take a prefix and return the state to `ReadNotRequested` once
    /// the buffer is drained.
    Data(BytesMut),
    /// The worker's read failed, and the worker thread stops after storing this.
    ///
    /// The first reader to observe the error reports it and moves the state
    /// to `Closed`.
    Error(std::io::Error),
    /// End-of-file was reached, or an earlier error has already been reported.
    ///
    /// Readers leave this state unchanged.
    Closed,
}
87
88impl GlobalStdin {
89 fn get() -> &'static GlobalStdin {
90 static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
91 STDIN.get_or_init(|| create())
92 }
93}
94
95fn create() -> GlobalStdin {
96 std::thread::spawn(|| {
97 let state = GlobalStdin::get();
98 loop {
99 let mut lock = state.state.lock().unwrap();
102 lock = state
103 .read_requested
104 .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
105 .unwrap();
106 drop(lock);
107
108 let mut bytes = BytesMut::zeroed(1024);
109 let (new_state, done) = match std::io::stdin().read(&mut bytes) {
110 Ok(0) => (StdinState::Closed, true),
111 Ok(nbytes) => {
112 bytes.truncate(nbytes);
113 (StdinState::Data(bytes), false)
114 }
115 Err(e) => (StdinState::Error(e), true),
116 };
117
118 debug_assert!(matches!(
121 *state.state.lock().unwrap(),
122 StdinState::ReadRequested
123 ));
124 let mut lock = state.state.lock().unwrap();
125 *lock = new_state;
126 state.read_completed.notify_waiters();
127 if done {
128 break;
129 }
130 }
131 });
132
133 GlobalStdin::default()
134}
135
/// WASIp2 stdin stream handle.
///
/// It holds no state of its own; every operation goes through the
/// process-wide [`GlobalStdin`].
struct WasiStdin;
137
138#[async_trait::async_trait]
139impl InputStream for WasiStdin {
140 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
141 let g = GlobalStdin::get();
142 let mut locked = g.state.lock().unwrap();
143 match mem::replace(&mut *locked, StdinState::ReadRequested) {
144 StdinState::ReadNotRequested => {
145 g.read_requested.notify_one();
146 Ok(Bytes::new())
147 }
148 StdinState::ReadRequested => Ok(Bytes::new()),
149 StdinState::Data(mut data) => {
150 let size = data.len().min(size);
151 let bytes = data.split_to(size);
152 *locked = if data.is_empty() {
153 StdinState::ReadNotRequested
154 } else {
155 StdinState::Data(data)
156 };
157 Ok(bytes.freeze())
158 }
159 StdinState::Error(e) => {
160 *locked = StdinState::Closed;
161 Err(StreamError::LastOperationFailed(e.into()))
162 }
163 StdinState::Closed => {
164 *locked = StdinState::Closed;
165 Err(StreamError::Closed)
166 }
167 }
168 }
169}
170
171#[async_trait::async_trait]
172impl Pollable for WasiStdin {
173 async fn ready(&mut self) {
174 let g = GlobalStdin::get();
175
176 let notified = {
179 let mut locked = g.state.lock().unwrap();
180 match *locked {
181 StdinState::ReadNotRequested => {
183 g.read_requested.notify_one();
184 *locked = StdinState::ReadRequested;
185 g.read_completed.notified()
186 }
187 StdinState::ReadRequested => g.read_completed.notified(),
188 StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
189 }
190 };
191
192 notified.await;
193 }
194}
195
/// Tokio `AsyncRead` adapter over the process-wide [`GlobalStdin`].
enum WasiStdinAsyncRead {
    /// Not currently waiting on the worker thread.
    Ready,
    /// Waiting for the worker thread to finish an outstanding read.
    ///
    /// The `Notified` future is created from `read_completed` in `poll_read`
    /// while the state lock is held. It is polled through a pinned reference
    /// obtained from `notified_future`.
    Waiting(Notified<'static>),
}
200
impl AsyncRead for WasiStdinAsyncRead {
    /// Copies buffered stdin bytes into `buf`.
    ///
    /// When nothing is buffered, it requests a read from the worker thread if
    /// none is outstanding, then waits on `read_completed`. Results:
    /// - End-of-file: `Ready(Ok(()))` with nothing written to `buf`, which is
    ///   the `AsyncRead` end-of-stream convention.
    /// - Worker read error: `Ready(Err(_))` once; the state then becomes
    ///   `Closed`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let g = GlobalStdin::get();

        // The state lock is held for the entire poll. Any `Notified` created
        // below therefore exists before the worker thread can take the lock,
        // publish a result and call `notify_waiters`. Per tokio's `Notify`
        // docs, such a future receives that wakeup even if it has not been
        // polled yet.
        let mut locked = g.state.lock().unwrap();

        loop {
            // If a previous iteration or poll left us waiting, check whether
            // the worker has signaled completion. If it has not, the waker is
            // now registered and we return `Pending`.
            if let Some(notified) = self.as_mut().notified_future() {
                match notified.poll(cx) {
                    Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
                    Poll::Pending => break Poll::Pending,
                }
            }

            assert!(matches!(*self, WasiStdinAsyncRead::Ready));

            // Take the current state, provisionally leaving `ReadRequested`.
            // Arms that need a different state overwrite it.
            match mem::replace(&mut *locked, StdinState::ReadRequested) {
                StdinState::Data(mut data) => {
                    // Copy as much as fits; keep any remainder buffered.
                    let size = data.len().min(buf.remaining());
                    let bytes = data.split_to(size);
                    *locked = if data.is_empty() {
                        StdinState::ReadNotRequested
                    } else {
                        StdinState::Data(data)
                    };
                    buf.put_slice(&bytes);
                    break Poll::Ready(Ok(()));
                }

                StdinState::Error(e) => {
                    // Report the error once; later reads see `Closed`.
                    *locked = StdinState::Closed;
                    break Poll::Ready(Err(e));
                }

                StdinState::Closed => {
                    // End-of-file: succeed without writing into `buf`.
                    *locked = StdinState::Closed;
                    break Poll::Ready(Ok(()));
                }

                StdinState::ReadNotRequested => {
                    // The state is now `ReadRequested`; wake the worker thread
                    // so it starts a blocking read.
                    g.read_requested.notify_one();
                }
                StdinState::ReadRequested => {}
            }

            // A read is in flight. Arm a completion notification (under the
            // lock) and loop around to poll it, which registers `cx`'s waker.
            self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
        }
    }
}
274
impl WasiStdinAsyncRead {
    /// Projects the pinned enum to its pending `Notified` future, if any.
    ///
    /// Returns `None` in the `Ready` state. In the `Waiting` state it returns a
    /// pinned mutable reference to the inner future.
    fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
        // SAFETY: this is a structural pin projection from a pinned `Self` to
        // its `Notified` field. The inner future is never moved out here. The
        // visible callers only replace the whole value in place via `Pin::set`,
        // which drops the old future without moving it. Handing out a pinned
        // reference to the field therefore upholds the pinning guarantee.
        unsafe {
            match self.get_unchecked_mut() {
                WasiStdinAsyncRead::Ready => None,
                WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
            }
        }
    }
}