Skip to main content

wasmtime_wasi/p2/
pipe.rs

//! Virtual pipes.
//!
//! These types provide easy implementations of [`InputStream`] and [`OutputStream`] that mimic
//! much of the behavior of Unix pipes. These are particularly helpful for redirecting WASI stdio
//! handles to destinations other than OS files.
//!
//! Some convenience constructors are included for common backing types like `Vec<u8>` and
//! `String`, but the virtual pipes can also wrap any [`tokio::io::AsyncRead`] or
//! [`tokio::io::AsyncWrite`] type.
10use bytes::Bytes;
11use std::pin::{Pin, pin};
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll};
14use tokio::io::{self, AsyncRead, AsyncWrite};
15use tokio::sync::mpsc;
16use wasmtime::format_err;
17use wasmtime_wasi_io::{
18    poll::Pollable,
19    streams::{InputStream, OutputStream, StreamError},
20};
21
22pub use crate::p2::write_stream::AsyncWriteStream;
23
24#[derive(Debug, Clone)]
25pub struct MemoryInputPipe {
26    buffer: Arc<Mutex<Bytes>>,
27}
28
29impl MemoryInputPipe {
30    pub fn new(bytes: impl Into<Bytes>) -> Self {
31        Self {
32            buffer: Arc::new(Mutex::new(bytes.into())),
33        }
34    }
35
36    pub fn is_empty(&self) -> bool {
37        self.buffer.lock().unwrap().is_empty()
38    }
39}
40
41#[async_trait::async_trait]
42impl InputStream for MemoryInputPipe {
43    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
44        let mut buffer = self.buffer.lock().unwrap();
45        if buffer.is_empty() {
46            return Err(StreamError::Closed);
47        }
48
49        let size = size.min(buffer.len());
50        let read = buffer.split_to(size);
51        Ok(read)
52    }
53}
54
55#[async_trait::async_trait]
56impl Pollable for MemoryInputPipe {
57    async fn ready(&mut self) {}
58}
59
60impl AsyncRead for MemoryInputPipe {
61    fn poll_read(
62        self: Pin<&mut Self>,
63        _cx: &mut Context<'_>,
64        buf: &mut io::ReadBuf<'_>,
65    ) -> Poll<io::Result<()>> {
66        let mut buffer = self.buffer.lock().unwrap();
67        let size = buf.remaining().min(buffer.len());
68        let read = buffer.split_to(size);
69        buf.put_slice(&read);
70        Poll::Ready(Ok(()))
71    }
72}
73
74#[derive(Debug, Clone)]
75pub struct MemoryOutputPipe {
76    capacity: usize,
77    buffer: Arc<Mutex<bytes::BytesMut>>,
78}
79
80impl MemoryOutputPipe {
81    pub fn new(capacity: usize) -> Self {
82        MemoryOutputPipe {
83            capacity,
84            buffer: std::sync::Arc::new(std::sync::Mutex::new(bytes::BytesMut::new())),
85        }
86    }
87
88    pub fn contents(&self) -> bytes::Bytes {
89        self.buffer.lock().unwrap().clone().freeze()
90    }
91
92    pub fn try_into_inner(self) -> Option<bytes::BytesMut> {
93        std::sync::Arc::into_inner(self.buffer).map(|m| m.into_inner().unwrap())
94    }
95}
96
97#[async_trait::async_trait]
98impl OutputStream for MemoryOutputPipe {
99    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
100        let mut buf = self.buffer.lock().unwrap();
101        if bytes.len() > self.capacity - buf.len() {
102            return Err(StreamError::Trap(format_err!(
103                "write beyond capacity of MemoryOutputPipe"
104            )));
105        }
106        buf.extend_from_slice(bytes.as_ref());
107        // Always ready for writing
108        Ok(())
109    }
110    fn flush(&mut self) -> Result<(), StreamError> {
111        // This stream is always flushed
112        Ok(())
113    }
114    fn check_write(&mut self) -> Result<usize, StreamError> {
115        let consumed = self.buffer.lock().unwrap().len();
116        if consumed < self.capacity {
117            Ok(self.capacity - consumed)
118        } else {
119            // Since the buffer is full, no more bytes will ever be written
120            Err(StreamError::Closed)
121        }
122    }
123}
124
125#[async_trait::async_trait]
126impl Pollable for MemoryOutputPipe {
127    async fn ready(&mut self) {}
128}
129
130impl AsyncWrite for MemoryOutputPipe {
131    fn poll_write(
132        self: Pin<&mut Self>,
133        _cx: &mut Context<'_>,
134        buf: &[u8],
135    ) -> Poll<io::Result<usize>> {
136        let mut buffer = self.buffer.lock().unwrap();
137        let amt = buf.len().min(self.capacity - buffer.len());
138        buffer.extend_from_slice(&buf[..amt]);
139        Poll::Ready(Ok(amt))
140    }
141    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142        Poll::Ready(Ok(()))
143    }
144    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145        Poll::Ready(Ok(()))
146    }
147}
148
149/// Provides a [`InputStream`] impl from a [`tokio::io::AsyncRead`] impl
150pub struct AsyncReadStream {
151    closed: bool,
152    buffer: Option<Result<Bytes, StreamError>>,
153    receiver: mpsc::Receiver<Result<Bytes, StreamError>>,
154    join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
155}
156
157impl AsyncReadStream {
158    /// Create a [`AsyncReadStream`]. In order to use the [`InputStream`] impl
159    /// provided by this struct, the argument must impl [`tokio::io::AsyncRead`].
160    pub fn new<T: AsyncRead + Send + 'static>(reader: T) -> Self {
161        let (sender, receiver) = mpsc::channel(1);
162        let join_handle = crate::runtime::spawn(async move {
163            let mut reader = pin!(reader);
164            loop {
165                use tokio::io::AsyncReadExt;
166                let mut buf = bytes::BytesMut::with_capacity(crate::MAX_READ_SIZE_ALLOC);
167                let sent = match reader.read_buf(&mut buf).await {
168                    Ok(nbytes) if nbytes == 0 => sender.send(Err(StreamError::Closed)).await,
169                    Ok(_) => sender.send(Ok(buf.freeze())).await,
170                    Err(e) => {
171                        sender
172                            .send(Err(StreamError::LastOperationFailed(e.into())))
173                            .await
174                    }
175                };
176                if sent.is_err() {
177                    // no more receiver - stop trying to read
178                    break;
179                }
180            }
181        });
182        AsyncReadStream {
183            closed: false,
184            buffer: None,
185            receiver,
186            join_handle: Some(join_handle),
187        }
188    }
189    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
190        if self.buffer.is_some() || self.closed {
191            return Poll::Ready(());
192        }
193        match self.receiver.poll_recv(cx) {
194            Poll::Ready(Some(res)) => {
195                self.buffer = Some(res);
196                Poll::Ready(())
197            }
198            Poll::Ready(None) => {
199                panic!("no more sender for an open AsyncReadStream - should be impossible")
200            }
201            Poll::Pending => Poll::Pending,
202        }
203    }
204}
205
206#[async_trait::async_trait]
207impl InputStream for AsyncReadStream {
208    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
209        use mpsc::error::TryRecvError;
210
211        match self.buffer.take() {
212            Some(Ok(mut bytes)) => {
213                // TODO: de-duplicate the buffer management with the case below
214                let len = bytes.len().min(size);
215                let rest = bytes.split_off(len);
216                if !rest.is_empty() {
217                    self.buffer = Some(Ok(rest));
218                }
219                return Ok(bytes);
220            }
221            Some(Err(e)) => {
222                self.closed = true;
223                return Err(e);
224            }
225            None => {}
226        }
227
228        match self.receiver.try_recv() {
229            Ok(Ok(mut bytes)) => {
230                let len = bytes.len().min(size);
231                let rest = bytes.split_off(len);
232                if !rest.is_empty() {
233                    self.buffer = Some(Ok(rest));
234                }
235
236                Ok(bytes)
237            }
238            Ok(Err(e)) => {
239                self.closed = true;
240                Err(e)
241            }
242            Err(TryRecvError::Empty) => Ok(Bytes::new()),
243            Err(TryRecvError::Disconnected) => Err(StreamError::Trap(format_err!(
244                "AsyncReadStream sender died - should be impossible"
245            ))),
246        }
247    }
248
249    async fn cancel(&mut self) {
250        match self.join_handle.take() {
251            Some(task) => _ = task.cancel().await,
252            None => {}
253        }
254    }
255}
256
257#[async_trait::async_trait]
258impl Pollable for AsyncReadStream {
259    async fn ready(&mut self) {
260        std::future::poll_fn(|cx| self.poll_ready(cx)).await
261    }
262}
263
/// An output stream that consumes all input written to it, and is always ready.
///
/// Written bytes are discarded.
#[derive(Clone, Copy)]
pub struct SinkOutputStream;
267
268#[async_trait::async_trait]
269impl OutputStream for SinkOutputStream {
270    fn write(&mut self, _buf: Bytes) -> Result<(), StreamError> {
271        Ok(())
272    }
273    fn flush(&mut self) -> Result<(), StreamError> {
274        // This stream is always flushed
275        Ok(())
276    }
277
278    fn check_write(&mut self) -> Result<usize, StreamError> {
279        // This stream is always ready for writing.
280        Ok(usize::MAX)
281    }
282}
283
284#[async_trait::async_trait]
285impl Pollable for SinkOutputStream {
286    async fn ready(&mut self) {}
287}
288
/// A stream that is ready immediately, but will always report that it's closed.
///
/// Every `read` returns [`StreamError::Closed`].
#[derive(Clone, Copy)]
pub struct ClosedInputStream;
292
293#[async_trait::async_trait]
294impl InputStream for ClosedInputStream {
295    fn read(&mut self, _size: usize) -> Result<Bytes, StreamError> {
296        Err(StreamError::Closed)
297    }
298}
299
300#[async_trait::async_trait]
301impl Pollable for ClosedInputStream {
302    async fn ready(&mut self) {}
303}
304
/// An output stream that is always closed.
///
/// Every operation returns [`StreamError::Closed`].
#[derive(Clone, Copy)]
pub struct ClosedOutputStream;
308
309#[async_trait::async_trait]
310impl OutputStream for ClosedOutputStream {
311    fn write(&mut self, _: Bytes) -> Result<(), StreamError> {
312        Err(StreamError::Closed)
313    }
314    fn flush(&mut self) -> Result<(), StreamError> {
315        Err(StreamError::Closed)
316    }
317
318    fn check_write(&mut self) -> Result<usize, StreamError> {
319        Err(StreamError::Closed)
320    }
321}
322
323#[async_trait::async_trait]
324impl Pollable for ClosedOutputStream {
325    async fn ready(&mut self) {}
326}
327
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    // This is a gross way to handle CI running under qemu for non-x86 architectures.
    #[cfg(not(target_arch = "x86_64"))]
    const TEST_ITERATIONS: usize = 10;

    #[cfg(target_arch = "x86_64")]
    const TEST_ITERATIONS: usize = 100;

    /// Awaits `fut`, failing the test if it does not complete within a generous timeout.
    async fn resolves_immediately<F, O>(fut: F) -> O
    where
        F: futures::Future<Output = O>,
    {
        // The input `fut` should resolve immediately, but in case it
        // accidentally doesn't, don't hang the test indefinitely. Provide a
        // generous timeout to account for CI sensitivity and various systems.
        tokio::time::timeout(Duration::from_secs(2), fut)
            .await
            .expect("operation timed out")
    }

    /// Asserts that `fut` does not complete within a short window.
    async fn never_resolves<F: futures::Future>(fut: F) {
        // The input `fut` should never resolve, so only give it a small window
        // of budget before we time out. If `fut` is actually resolved this
        // should show up as a flaky test.
        tokio::time::timeout(Duration::from_millis(10), fut)
            .await
            .err()
            .expect("operation should time out");
    }

    /// Creates a one-directional in-memory pipe built on `tokio::io::duplex(size)`: bytes
    /// written to the returned writer can be read from the returned reader.
    pub fn simplex(size: usize) -> (impl AsyncRead, impl AsyncWrite) {
        let (a, b) = tokio::io::duplex(size);
        let (_read_half, write_half) = tokio::io::split(a);
        let (read_half, _write_half) = tokio::io::split(b);
        (read_half, write_half)
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn empty_read_stream() {
        let mut reader = AsyncReadStream::new(tokio::io::empty());

        // In a multi-threaded context, the value of state is not deterministic -- the spawned
        // reader task may run on a different thread.
        match reader.read(10) {
            // The reader task ran before we tried to read, and noticed that the input was empty.
            Err(StreamError::Closed) => {}

            // The reader task hasn't run yet. Call `ready` to await and fill the buffer.
            Ok(bs) => {
                assert!(bs.is_empty());
                resolves_immediately(reader.ready()).await;
                assert!(matches!(reader.read(0), Err(StreamError::Closed)));
            }
            res => panic!("unexpected: {res:?}"),
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn infinite_read_stream() {
        let mut reader = AsyncReadStream::new(tokio::io::repeat(0));

        let bs = reader.read(10).unwrap();
        if bs.is_empty() {
            // Reader task hasn't run yet. Call `ready` to await and fill the buffer.
            resolves_immediately(reader.ready()).await;
            // Now a read should succeed
            let bs = reader.read(10).unwrap();
            assert_eq!(bs.len(), 10);
        } else {
            assert_eq!(bs.len(), 10);
        }

        // Subsequent reads should succeed
        let bs = reader.read(10).unwrap();
        assert_eq!(bs.len(), 10);

        // Even 0-length reads should succeed and show its open
        let bs = reader.read(0).unwrap();
        assert_eq!(bs.len(), 0);
    }

    /// Returns a reader that yields exactly `contents` and then reports EOF.
    async fn finite_async_reader(contents: &[u8]) -> impl AsyncRead + Send + 'static + use<> {
        let (r, mut w) = simplex(contents.len());
        w.write_all(contents).await.unwrap();
        r
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn finite_read_stream() {
        let mut reader = AsyncReadStream::new(finite_async_reader(&[1; 123]).await);

        let bs = reader.read(123).unwrap();
        if bs.is_empty() {
            // Reader task hasn't run yet. Call `ready` to await and fill the buffer.
            resolves_immediately(reader.ready()).await;
            // Now a read should succeed
            let bs = reader.read(123).unwrap();
            assert_eq!(bs.len(), 123);
        } else {
            assert_eq!(bs.len(), 123);
        }

        // The AsyncRead's should be empty now, but we have a race where the reader task hasn't
        // yet sent that to the AsyncReadStream.
        match reader.read(0) {
            Err(StreamError::Closed) => {} // Correct!
            Ok(bs) => {
                assert!(bs.is_empty());
                // Need to await to give this side time to catch up
                resolves_immediately(reader.ready()).await;
                // Now a read should show closed
                assert!(matches!(reader.read(0), Err(StreamError::Closed)));
            }
            res => panic!("unexpected: {res:?}"),
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    // Test that you can write items into the stream, and they get read out in the order they were
    // written, with the proper indications of readiness for reading:
    async fn multiple_chunks_read_stream() {
        let (r, mut w) = simplex(1024);
        let mut reader = AsyncReadStream::new(r);

        w.write_all(&[123]).await.unwrap();

        let bs = reader.read(1).unwrap();
        if bs.is_empty() {
            // Reader task hasn't run yet. Call `ready` to await and fill the buffer.
            resolves_immediately(reader.ready()).await;
            // Now a read should succeed
            let bs = reader.read(1).unwrap();
            assert_eq!(*bs, [123u8]);
        } else {
            assert_eq!(*bs, [123u8]);
        }

        // The stream should be empty and open now:
        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        // We can wait on readiness and it will time out:
        never_resolves(reader.ready()).await;

        // Still open and empty:
        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        // Put something else in the stream:
        w.write_all(&[45]).await.unwrap();

        // Wait readiness (yes we could possibly win the race and read it out faster, leaving that
        // out of the test for simplicity)
        resolves_immediately(reader.ready()).await;

        // read the something else back out:
        let bs = reader.read(1).unwrap();
        assert_eq!(*bs, [45u8]);

        // nothing else in there:
        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        // We can wait on readiness and it will time out:
        never_resolves(reader.ready()).await;

        // nothing else in there:
        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        // Now close the pipe:
        drop(w);

        // Wait readiness (yes we could possibly win the race and read it out faster, leaving that
        // out of the test for simplicity)
        resolves_immediately(reader.ready()).await;

        // empty and now closed:
        assert!(matches!(reader.read(1), Err(StreamError::Closed)));
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    // At the moment AsyncReadStream's worker reads at most `crate::MAX_READ_SIZE_ALLOC` bytes per
    // chunk. This isn't a suitable design for all applications, and we will probably make a knob
    // or change the behavior at some point, but this test shows the behavior as it is implemented:
    async fn backpressure_read_stream() {
        let (r, mut w) = simplex(4 * crate::MAX_READ_SIZE_ALLOC); // Make sure this buffer isn't a bottleneck
        let mut reader = AsyncReadStream::new(r);

        let writer_task = tokio::task::spawn(async move {
            // Write twice as much as we can buffer up in an AsyncReadStream:
            w.write_all(&[123; 2 * crate::MAX_READ_SIZE_ALLOC])
                .await
                .unwrap();
            w
        });

        resolves_immediately(reader.ready()).await;

        // Now we expect the reader task has sent `MAX_READ_SIZE_ALLOC` bytes from the stream to
        // the reader. Try to read out one more than the buffer available:
        let bs = reader.read(crate::MAX_READ_SIZE_ALLOC + 1).unwrap();
        assert_eq!(bs.len(), crate::MAX_READ_SIZE_ALLOC);

        // Allow the crank to turn more:
        resolves_immediately(reader.ready()).await;

        // Again we expect the reader task has sent `MAX_READ_SIZE_ALLOC` bytes from the stream to
        // the reader. Try to read out one more than the buffer available:
        let bs = reader.read(crate::MAX_READ_SIZE_ALLOC + 1).unwrap();
        assert_eq!(bs.len(), crate::MAX_READ_SIZE_ALLOC);

        // The writer task is now finished - join with it:
        let w = resolves_immediately(writer_task).await;

        // And close the pipe:
        drop(w);

        // Allow the crank to turn more:
        resolves_immediately(reader.ready()).await;

        // Now we expect the reader to be empty, and the stream closed:
        assert!(matches!(
            reader.read(crate::MAX_READ_SIZE_ALLOC + 1),
            Err(StreamError::Closed)
        ));
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn sink_write_stream() {
        let mut writer = AsyncWriteStream::new(2048, tokio::io::sink());
        let chunk = Bytes::from_static(&[0; 1024]);

        let readiness = resolves_immediately(writer.write_ready())
            .await
            .expect("write_ready does not trap");
        assert_eq!(readiness, 2048);
        // I can write whatever:
        writer.write(chunk.clone()).expect("write does not error");

        // This may consume 1k of the buffer:
        let readiness = resolves_immediately(writer.write_ready())
            .await
            .expect("write_ready does not trap");
        assert!(
            readiness == 1024 || readiness == 2048,
            "readiness should be 1024 or 2048, got {readiness}"
        );

        if readiness == 1024 {
            writer.write(chunk.clone()).expect("write does not error");

            let readiness = resolves_immediately(writer.write_ready())
                .await
                .expect("write_ready does not trap");
            assert!(
                readiness == 1024 || readiness == 2048,
                "readiness should be 1024 or 2048, got {readiness}"
            );
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn closed_write_stream() {
        // Run many times because the test is nondeterministic:
        for n in 0..TEST_ITERATIONS {
            closed_write_stream_(n).await
        }
    }
    #[tracing::instrument]
    async fn closed_write_stream_(n: usize) {
        let (reader, writer) = simplex(1);
        let mut writer = AsyncWriteStream::new(1024, writer);

        // Drop the reader to allow the worker to transition to the closed state eventually.
        drop(reader);

        // First the api is going to report the last operation failed, then subsequently
        // it will be reported as closed. We set this flag once we see LastOperationFailed.
        let mut should_be_closed = false;

        // Write some data to the stream to ensure we have data that cannot be flushed.
        let chunk = Bytes::from_static(&[0; 1]);
        writer
            .write(chunk.clone())
            .expect("first write should succeed");

        // The rest of this test should be valid whether or not we check write readiness:
        let mut write_ready_res = None;
        if n % 2 == 0 {
            let r = resolves_immediately(writer.write_ready()).await;
            // Check write readiness:
            match r {
                // worker hasn't processed write yet:
                Ok(1023) => {}
                // worker reports failure:
                Err(StreamError::LastOperationFailed(_)) => {
                    tracing::debug!("discovered stream failure in first write_ready");
                    should_be_closed = true;
                }
                r => panic!("unexpected write_ready: {r:?}"),
            }
            write_ready_res = Some(r);
        }

        // When we drop the simplex reader, that causes the simplex writer to return BrokenPipe on
        // its write. Now that the buffering crank has turned, our next write will give BrokenPipe.
        let flush_res = writer.flush();
        match flush_res {
            // worker reports failure:
            Err(StreamError::LastOperationFailed(_)) => {
                tracing::debug!("discovered stream failure trying to flush");
                assert!(!should_be_closed);
                should_be_closed = true;
            }
            // Already reported failure, now closed
            Err(StreamError::Closed) => {
                assert!(
                    should_be_closed,
                    "expected a LastOperationFailed before we see Closed. {write_ready_res:?}"
                );
            }
            // Also possible the worker hasn't processed write yet:
            Ok(()) => {}
            Err(e) => panic!("unexpected flush error: {e:?} {write_ready_res:?}"),
        }

        // Waiting for the flush to complete should always indicate that the channel has been
        // closed.
        match resolves_immediately(writer.write_ready()).await {
            // worker reports failure:
            Err(StreamError::LastOperationFailed(_)) => {
                tracing::debug!("discovered stream failure trying to flush");
                assert!(!should_be_closed);
            }
            // Already reported failure, now closed
            Err(StreamError::Closed) => {
                assert!(should_be_closed);
            }
            r => {
                panic!(
                    "stream should be reported closed by the end of write_ready after flush, got {r:?}. {write_ready_res:?} {flush_res:?}"
                )
            }
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn multiple_chunks_write_stream() {
        // Run many times because the test is nondeterministic:
        for n in 0..TEST_ITERATIONS {
            multiple_chunks_write_stream_aux(n).await
        }
    }
    #[tracing::instrument]
    async fn multiple_chunks_write_stream_aux(_: usize) {
        use std::ops::Deref;

        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(1024, writer);

        // Write a chunk:
        let chunk = Bytes::from_static(&[123; 1]);

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write does not trap");

        // At this point the message will either be waiting for the worker to process the write, or
        // it will be buffered in the simplex channel.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert!(matches!(permit, 1023 | 1024));

        let mut read_buf = vec![0; chunk.len()];
        let read_len = reader.read_exact(&mut read_buf).await.unwrap();
        assert_eq!(read_len, chunk.len());
        assert_eq!(read_buf.as_slice(), chunk.deref());

        // Write a second, different chunk:
        let chunk2 = Bytes::from_static(&[45; 1]);

        // We're only guaranteed to see a consistent write budget if we flush.
        writer.flush().expect("channel is still alive");

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk2.clone()).expect("write does not trap");

        // At this point the message will either be waiting for the worker to process the write, or
        // it will be buffered in the simplex channel.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert!(matches!(permit, 1023 | 1024));

        let mut read2_buf = vec![0; chunk2.len()];
        let read2_len = reader.read_exact(&mut read2_buf).await.unwrap();
        assert_eq!(read2_len, chunk2.len());
        assert_eq!(read2_buf.as_slice(), chunk2.deref());

        // We're only guaranteed to see a consistent write budget if we flush.
        writer.flush().expect("channel is still alive");

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn backpressure_write_stream() {
        // Run many times because the test is nondeterministic:
        for n in 0..TEST_ITERATIONS {
            backpressure_write_stream_aux(n).await
        }
    }
    #[tracing::instrument]
    async fn backpressure_write_stream_aux(_: usize) {
        use futures::future::poll_immediate;

        // The channel can buffer up to 1k, plus another 1k in the stream, before not
        // accepting more input:
        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(1024, writer);

        let chunk = Bytes::from_static(&[0; 1024]);

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write succeeds");

        // We might still be waiting for the worker to process the message, or the worker may have
        // processed it and released all the budget back to us.
        let permit = poll_immediate(writer.write_ready()).await;
        assert!(matches!(permit, None | Some(Ok(1024))));

        // Given a little time, the worker will process the message and release all the budget
        // back.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        // Now fill the buffer between here and the writer task. This should always indicate
        // back-pressure because now both buffers (simplex and worker) are full.
        writer.write(chunk.clone()).expect("write does not trap");

        // Try shoving even more down there, and it shouldn't accept more input:
        writer
            .write(chunk.clone())
            .err()
            .expect("unpermitted write does trap");

        // No amount of waiting will resolve the situation, as nothing is emptying the simplex
        // buffer.
        never_resolves(writer.write_ready()).await;

        // There is 2k buffered between the simplex and worker buffers. I should be able to read
        // all of it out:
        let mut buf = [0; 2048];
        reader.read_exact(&mut buf).await.unwrap();

        // and no more:
        never_resolves(reader.read(&mut buf)).await;

        // Now the backpressure should be cleared, and an additional write should be accepted.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("ready is ok");
        assert_eq!(permit, 1024);

        // and the write succeeds:
        writer.write(chunk.clone()).expect("write does not trap");
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn backpressure_write_stream_with_flush() {
        for n in 0..TEST_ITERATIONS {
            backpressure_write_stream_with_flush_aux(n).await;
        }
    }

    async fn backpressure_write_stream_with_flush_aux(_: usize) {
        // The channel can buffer up to 1k, plus another 1k in the stream, before not
        // accepting more input:
        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(1024, writer);

        let chunk = Bytes::from_static(&[0; 1024]);

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write succeeds");

        writer.flush().expect("flush succeeds");

        // Waiting for write_ready to resolve after a flush should always show that we have the
        // full budget available, as the message will have flushed to the simplex channel.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write_ready succeeds");
        assert_eq!(permit, 1024);

        // Write enough to fill the simplex buffer:
        writer.write(chunk.clone()).expect("write does not trap");

        // Writes should be refused until this flush succeeds.
        writer.flush().expect("flush succeeds");

        // Try shoving even more down there, and it shouldn't accept more input:
        writer
            .write(chunk.clone())
            .err()
            .expect("unpermitted write does trap");

        // No amount of waiting will resolve the situation, as nothing is emptying the simplex
        // buffer.
        never_resolves(writer.write_ready()).await;

        // There is 2k buffered between the simplex and worker buffers. I should be able to read
        // all of it out:
        let mut buf = [0; 2048];
        reader.read_exact(&mut buf).await.unwrap();

        // and no more:
        never_resolves(reader.read(&mut buf)).await;

        // Now the backpressure should be cleared, and an additional write should be accepted.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("ready is ok");
        assert_eq!(permit, 1024);

        // and the write succeeds:
        writer.write(chunk.clone()).expect("write does not trap");

        writer.flush().expect("flush succeeds");

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("ready is ok");
        assert_eq!(permit, 1024);
    }
}