wasmtime_wasi/p2/
write_stream.rs1use crate::p2::{OutputStream, Pollable, StreamError};
2use anyhow::anyhow;
3use bytes::Bytes;
4use std::pin::pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll, Waker};
7
8#[derive(Debug)]
9struct WorkerState {
10 alive: bool,
11 items: std::collections::VecDeque<Bytes>,
12 write_budget: usize,
13 flush_pending: bool,
14 error: Option<anyhow::Error>,
15 write_ready_changed: Option<Waker>,
16}
17
18impl WorkerState {
19 fn check_error(&mut self) -> Result<(), StreamError> {
20 if let Some(e) = self.error.take() {
21 return Err(StreamError::LastOperationFailed(e));
22 }
23 if !self.alive {
24 return Err(StreamError::Closed);
25 }
26 Ok(())
27 }
28}
29
30struct Worker {
31 state: Mutex<WorkerState>,
32 new_work: tokio::sync::Notify,
33}
34
35enum Job {
36 Flush,
37 Write(Bytes),
38}
39
40impl Worker {
41 fn new(write_budget: usize) -> Self {
42 Self {
43 state: Mutex::new(WorkerState {
44 alive: true,
45 items: std::collections::VecDeque::new(),
46 write_budget,
47 flush_pending: false,
48 error: None,
49 write_ready_changed: None,
50 }),
51 new_work: tokio::sync::Notify::new(),
52 }
53 }
54 fn check_write(&self) -> Result<usize, StreamError> {
55 let mut state = self.state();
56 if let Err(e) = state.check_error() {
57 return Err(e);
58 }
59
60 if state.flush_pending || state.write_budget == 0 {
61 return Ok(0);
62 }
63
64 Ok(state.write_budget)
65 }
66 fn state(&self) -> std::sync::MutexGuard<'_, WorkerState> {
67 self.state.lock().unwrap()
68 }
69 fn pop(&self) -> Option<Job> {
70 let mut state = self.state();
71 if state.items.is_empty() {
72 if state.flush_pending {
73 return Some(Job::Flush);
74 }
75 } else if let Some(bytes) = state.items.pop_front() {
76 return Some(Job::Write(bytes));
77 }
78
79 None
80 }
81 fn report_error(&self, e: std::io::Error) {
82 let waker = {
83 let mut state = self.state();
84 state.alive = false;
85 state.error = Some(e.into());
86 state.flush_pending = false;
87 state.write_ready_changed.take()
88 };
89 if let Some(waker) = waker {
90 waker.wake();
91 }
92 }
93 async fn work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T) {
94 use tokio::io::AsyncWriteExt;
95 let mut writer = pin!(writer);
96 loop {
97 while let Some(job) = self.pop() {
98 match job {
99 Job::Flush => {
100 if let Err(e) = writer.flush().await {
101 self.report_error(e);
102 return;
103 }
104
105 tracing::debug!("worker marking flush complete");
106 self.state().flush_pending = false;
107 }
108
109 Job::Write(mut bytes) => {
110 tracing::debug!("worker writing: {bytes:?}");
111 let len = bytes.len();
112 match writer.write_all_buf(&mut bytes).await {
113 Err(e) => {
114 self.report_error(e);
115 return;
116 }
117 Ok(_) => {
118 self.state().write_budget += len;
119 }
120 }
121 }
122 }
123
124 let waker = self.state().write_ready_changed.take();
125 if let Some(waker) = waker {
126 waker.wake();
127 }
128 }
129 self.new_work.notified().await;
130 }
131 }
132}
133
134pub struct AsyncWriteStream {
136 worker: Arc<Worker>,
137 join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
138}
139
140impl AsyncWriteStream {
141 pub fn new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self {
144 let worker = Arc::new(Worker::new(write_budget));
145
146 let w = Arc::clone(&worker);
147 let join_handle = crate::runtime::spawn(async move { w.work(writer).await });
148
149 AsyncWriteStream {
150 worker,
151 join_handle: Some(join_handle),
152 }
153 }
154
155 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
156 let mut state = self.worker.state();
157 if state.error.is_some() || !state.alive || (!state.flush_pending && state.write_budget > 0)
158 {
159 return Poll::Ready(());
160 }
161 state.write_ready_changed = Some(cx.waker().clone());
162 Poll::Pending
163 }
164}
165
166#[async_trait::async_trait]
167impl OutputStream for AsyncWriteStream {
168 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
169 let mut state = self.worker.state();
170 state.check_error()?;
171 if state.flush_pending {
172 return Err(StreamError::Trap(anyhow!(
173 "write not permitted while flush pending"
174 )));
175 }
176 match state.write_budget.checked_sub(bytes.len()) {
177 Some(remaining_budget) => {
178 state.write_budget = remaining_budget;
179 state.items.push_back(bytes);
180 }
181 None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),
182 }
183 drop(state);
184 self.worker.new_work.notify_one();
185 Ok(())
186 }
187 fn flush(&mut self) -> Result<(), StreamError> {
188 let mut state = self.worker.state();
189 state.check_error()?;
190
191 state.flush_pending = true;
192 self.worker.new_work.notify_one();
193
194 Ok(())
195 }
196
197 fn check_write(&mut self) -> Result<usize, StreamError> {
198 self.worker.check_write()
199 }
200
201 async fn cancel(&mut self) {
202 match self.join_handle.take() {
203 Some(task) => _ = task.cancel().await,
204 None => {}
205 }
206 }
207}
208#[async_trait::async_trait]
209impl Pollable for AsyncWriteStream {
210 async fn ready(&mut self) {
211 std::future::poll_fn(|cx| self.poll_ready(cx)).await
212 }
213}