// wasmtime_wasi/cli/worker_thread_stdin.rs
use crate::cli::{IsTerminal, StdinStream};
use bytes::{Bytes, BytesMut};
use std::io::{ErrorKind, Read};
use std::mem;
use std::pin::Pin;
use std::sync::{Condvar, Mutex, OnceLock};
use std::task::{Context, Poll};
use tokio::io::{self, AsyncRead, ReadBuf};
use tokio::sync::Notify;
use tokio::sync::futures::Notified;
use wasmtime_wasi_io::{
    poll::Pollable,
    streams::{InputStream, StreamError},
};
40
41impl IsTerminal for tokio::io::Stdin {
43 fn is_terminal(&self) -> bool {
44 std::io::stdin().is_terminal()
45 }
46}
47impl StdinStream for tokio::io::Stdin {
48 fn p2_stream(&self) -> Box<dyn InputStream> {
49 Box::new(WasiStdin)
50 }
51 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
52 Box::new(WasiStdinAsyncRead::Ready)
53 }
54}
55
56impl IsTerminal for std::io::Stdin {
58 fn is_terminal(&self) -> bool {
59 std::io::IsTerminal::is_terminal(self)
60 }
61}
62impl StdinStream for std::io::Stdin {
63 fn p2_stream(&self) -> Box<dyn InputStream> {
64 Box::new(WasiStdin)
65 }
66 fn async_stream(&self) -> Box<dyn AsyncRead + Send + Sync> {
67 Box::new(WasiStdinAsyncRead::Ready)
68 }
69}
70
71#[derive(Default)]
72struct GlobalStdin {
73 state: Mutex<StdinState>,
74 read_requested: Condvar,
75 read_completed: Notify,
76}
77
78#[derive(Default, Debug)]
79enum StdinState {
80 #[default]
81 ReadNotRequested,
82 ReadRequested,
83 Data(BytesMut),
84 Error(std::io::Error),
85 Closed,
86}
87
88impl GlobalStdin {
89 fn get() -> &'static GlobalStdin {
90 static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
91 STDIN.get_or_init(|| create())
92 }
93}
94
95fn create() -> GlobalStdin {
96 std::thread::spawn(|| {
97 let state = GlobalStdin::get();
98 loop {
99 let mut lock = state.state.lock().unwrap();
102 lock = state
103 .read_requested
104 .wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
105 .unwrap();
106 drop(lock);
107
108 let mut bytes = BytesMut::zeroed(1024);
109 let (new_state, done) = match std::io::stdin().read(&mut bytes) {
110 Ok(0) => (StdinState::Closed, true),
111 Ok(nbytes) => {
112 bytes.truncate(nbytes);
113 (StdinState::Data(bytes), false)
114 }
115 Err(e) => (StdinState::Error(e), true),
116 };
117
118 debug_assert!(matches!(
121 *state.state.lock().unwrap(),
122 StdinState::ReadRequested
123 ));
124 *state.state.lock().unwrap() = new_state;
125 state.read_completed.notify_waiters();
126 if done {
127 break;
128 }
129 }
130 });
131
132 GlobalStdin::default()
133}
134
/// The `InputStream` returned by `p2_stream`. It is a stateless handle onto the
/// process-wide `GlobalStdin`.
struct WasiStdin;
136
137#[async_trait::async_trait]
138impl InputStream for WasiStdin {
139 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
140 let g = GlobalStdin::get();
141 let mut locked = g.state.lock().unwrap();
142 match mem::replace(&mut *locked, StdinState::ReadRequested) {
143 StdinState::ReadNotRequested => {
144 g.read_requested.notify_one();
145 Ok(Bytes::new())
146 }
147 StdinState::ReadRequested => Ok(Bytes::new()),
148 StdinState::Data(mut data) => {
149 let size = data.len().min(size);
150 let bytes = data.split_to(size);
151 *locked = if data.is_empty() {
152 StdinState::ReadNotRequested
153 } else {
154 StdinState::Data(data)
155 };
156 Ok(bytes.freeze())
157 }
158 StdinState::Error(e) => {
159 *locked = StdinState::Closed;
160 Err(StreamError::LastOperationFailed(e.into()))
161 }
162 StdinState::Closed => {
163 *locked = StdinState::Closed;
164 Err(StreamError::Closed)
165 }
166 }
167 }
168}
169
170#[async_trait::async_trait]
171impl Pollable for WasiStdin {
172 async fn ready(&mut self) {
173 let g = GlobalStdin::get();
174
175 let notified = {
178 let mut locked = g.state.lock().unwrap();
179 match *locked {
180 StdinState::ReadNotRequested => {
182 g.read_requested.notify_one();
183 *locked = StdinState::ReadRequested;
184 g.read_completed.notified()
185 }
186 StdinState::ReadRequested => g.read_completed.notified(),
187 StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return,
188 }
189 };
190
191 notified.await;
192 }
193}
194
195enum WasiStdinAsyncRead {
196 Ready,
197 Waiting(Notified<'static>),
198}
199
200impl AsyncRead for WasiStdinAsyncRead {
201 fn poll_read(
202 mut self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 buf: &mut ReadBuf<'_>,
205 ) -> Poll<io::Result<()>> {
206 let g = GlobalStdin::get();
207
208 loop {
212 if let Some(notified) = self.as_mut().notified_future() {
215 match notified.poll(cx) {
216 Poll::Ready(()) => self.set(WasiStdinAsyncRead::Ready),
217 Poll::Pending => break Poll::Pending,
218 }
219 }
220
221 assert!(matches!(*self, WasiStdinAsyncRead::Ready));
222
223 let mut locked = g.state.lock().unwrap();
226 match mem::replace(&mut *locked, StdinState::ReadRequested) {
227 StdinState::Data(mut data) => {
229 let size = data.len().min(buf.remaining());
230 let bytes = data.split_to(size);
231 *locked = if data.is_empty() {
232 StdinState::ReadNotRequested
233 } else {
234 StdinState::Data(data)
235 };
236 buf.put_slice(&bytes);
237 break Poll::Ready(Ok(()));
238 }
239
240 StdinState::Error(e) => {
243 *locked = StdinState::Closed;
244 break Poll::Ready(Err(e));
245 }
246
247 StdinState::Closed => {
249 *locked = StdinState::Closed;
250 break Poll::Ready(Ok(()));
251 }
252
253 StdinState::ReadNotRequested => {
257 g.read_requested.notify_one();
258 }
259 StdinState::ReadRequested => {}
260 }
261
262 self.set(WasiStdinAsyncRead::Waiting(g.read_completed.notified()));
263
264 drop(locked);
268 }
269 }
270}
271
272impl WasiStdinAsyncRead {
273 fn notified_future(self: Pin<&mut Self>) -> Option<Pin<&mut Notified<'static>>> {
274 unsafe {
278 match self.get_unchecked_mut() {
279 WasiStdinAsyncRead::Ready => None,
280 WasiStdinAsyncRead::Waiting(notified) => Some(Pin::new_unchecked(notified)),
281 }
282 }
283 }
284}