wasmtime_wasi/p2/
write_stream.rs

1use crate::p2::{OutputStream, Pollable, StreamError};
2use anyhow::anyhow;
3use bytes::Bytes;
4use std::pin::pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll, Waker};
7
8#[derive(Debug)]
9struct WorkerState {
10    alive: bool,
11    items: std::collections::VecDeque<Bytes>,
12    write_budget: usize,
13    flush_pending: bool,
14    error: Option<anyhow::Error>,
15    write_ready_changed: Option<Waker>,
16}
17
18impl WorkerState {
19    fn check_error(&mut self) -> Result<(), StreamError> {
20        if let Some(e) = self.error.take() {
21            return Err(StreamError::LastOperationFailed(e));
22        }
23        if !self.alive {
24            return Err(StreamError::Closed);
25        }
26        Ok(())
27    }
28}
29
30struct Worker {
31    state: Mutex<WorkerState>,
32    new_work: tokio::sync::Notify,
33}
34
35enum Job {
36    Flush,
37    Write(Bytes),
38}
39
40impl Worker {
41    fn new(write_budget: usize) -> Self {
42        Self {
43            state: Mutex::new(WorkerState {
44                alive: true,
45                items: std::collections::VecDeque::new(),
46                write_budget,
47                flush_pending: false,
48                error: None,
49                write_ready_changed: None,
50            }),
51            new_work: tokio::sync::Notify::new(),
52        }
53    }
54    fn check_write(&self) -> Result<usize, StreamError> {
55        let mut state = self.state();
56        if let Err(e) = state.check_error() {
57            return Err(e);
58        }
59
60        if state.flush_pending || state.write_budget == 0 {
61            return Ok(0);
62        }
63
64        Ok(state.write_budget)
65    }
66    fn state(&self) -> std::sync::MutexGuard<'_, WorkerState> {
67        self.state.lock().unwrap()
68    }
69    fn pop(&self) -> Option<Job> {
70        let mut state = self.state();
71        if state.items.is_empty() {
72            if state.flush_pending {
73                return Some(Job::Flush);
74            }
75        } else if let Some(bytes) = state.items.pop_front() {
76            return Some(Job::Write(bytes));
77        }
78
79        None
80    }
81    fn report_error(&self, e: std::io::Error) {
82        let waker = {
83            let mut state = self.state();
84            state.alive = false;
85            state.error = Some(e.into());
86            state.flush_pending = false;
87            state.write_ready_changed.take()
88        };
89        if let Some(waker) = waker {
90            waker.wake();
91        }
92    }
93    async fn work<T: tokio::io::AsyncWrite + Send + 'static>(&self, writer: T) {
94        use tokio::io::AsyncWriteExt;
95        let mut writer = pin!(writer);
96        loop {
97            while let Some(job) = self.pop() {
98                match job {
99                    Job::Flush => {
100                        if let Err(e) = writer.flush().await {
101                            self.report_error(e);
102                            return;
103                        }
104
105                        tracing::debug!("worker marking flush complete");
106                        self.state().flush_pending = false;
107                    }
108
109                    Job::Write(mut bytes) => {
110                        tracing::debug!("worker writing: {bytes:?}");
111                        let len = bytes.len();
112                        match writer.write_all_buf(&mut bytes).await {
113                            Err(e) => {
114                                self.report_error(e);
115                                return;
116                            }
117                            Ok(_) => {
118                                self.state().write_budget += len;
119                            }
120                        }
121                    }
122                }
123
124                let waker = self.state().write_ready_changed.take();
125                if let Some(waker) = waker {
126                    waker.wake();
127                }
128            }
129            self.new_work.notified().await;
130        }
131    }
132}
133
134/// Provides a [`OutputStream`] impl from a [`tokio::io::AsyncWrite`] impl
135pub struct AsyncWriteStream {
136    worker: Arc<Worker>,
137    join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
138}
139
140impl AsyncWriteStream {
141    /// Create a [`AsyncWriteStream`]. In order to use the [`OutputStream`] impl
142    /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`].
143    pub fn new<T: tokio::io::AsyncWrite + Send + 'static>(write_budget: usize, writer: T) -> Self {
144        let worker = Arc::new(Worker::new(write_budget));
145
146        let w = Arc::clone(&worker);
147        let join_handle = crate::runtime::spawn(async move { w.work(writer).await });
148
149        AsyncWriteStream {
150            worker,
151            join_handle: Some(join_handle),
152        }
153    }
154
155    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
156        let mut state = self.worker.state();
157        if state.error.is_some() || !state.alive || (!state.flush_pending && state.write_budget > 0)
158        {
159            return Poll::Ready(());
160        }
161        state.write_ready_changed = Some(cx.waker().clone());
162        Poll::Pending
163    }
164}
165
166#[async_trait::async_trait]
167impl OutputStream for AsyncWriteStream {
168    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
169        let mut state = self.worker.state();
170        state.check_error()?;
171        if state.flush_pending {
172            return Err(StreamError::Trap(anyhow!(
173                "write not permitted while flush pending"
174            )));
175        }
176        match state.write_budget.checked_sub(bytes.len()) {
177            Some(remaining_budget) => {
178                state.write_budget = remaining_budget;
179                state.items.push_back(bytes);
180            }
181            None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),
182        }
183        drop(state);
184        self.worker.new_work.notify_one();
185        Ok(())
186    }
187    fn flush(&mut self) -> Result<(), StreamError> {
188        let mut state = self.worker.state();
189        state.check_error()?;
190
191        state.flush_pending = true;
192        self.worker.new_work.notify_one();
193
194        Ok(())
195    }
196
197    fn check_write(&mut self) -> Result<usize, StreamError> {
198        self.worker.check_write()
199    }
200
201    async fn cancel(&mut self) {
202        match self.join_handle.take() {
203            Some(task) => _ = task.cancel().await,
204            None => {}
205        }
206    }
207}
208#[async_trait::async_trait]
209impl Pollable for AsyncWriteStream {
210    async fn ready(&mut self) {
211        std::future::poll_fn(|cx| self.poll_ready(cx)).await
212    }
213}