1use anyhow::anyhow;
11use bytes::Bytes;
12use std::pin::{Pin, pin};
13use std::sync::{Arc, Mutex};
14use std::task::{Context, Poll};
15use tokio::io::{self, AsyncRead, AsyncWrite};
16use tokio::sync::mpsc;
17use wasmtime_wasi_io::{
18 poll::Pollable,
19 streams::{InputStream, OutputStream, StreamError},
20};
21
22pub use crate::p2::write_stream::AsyncWriteStream;
23
24#[derive(Debug, Clone)]
25pub struct MemoryInputPipe {
26 buffer: Arc<Mutex<Bytes>>,
27}
28
29impl MemoryInputPipe {
30 pub fn new(bytes: impl Into<Bytes>) -> Self {
31 Self {
32 buffer: Arc::new(Mutex::new(bytes.into())),
33 }
34 }
35
36 pub fn is_empty(&self) -> bool {
37 self.buffer.lock().unwrap().is_empty()
38 }
39}
40
41#[async_trait::async_trait]
42impl InputStream for MemoryInputPipe {
43 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
44 let mut buffer = self.buffer.lock().unwrap();
45 if buffer.is_empty() {
46 return Err(StreamError::Closed);
47 }
48
49 let size = size.min(buffer.len());
50 let read = buffer.split_to(size);
51 Ok(read)
52 }
53}
54
55#[async_trait::async_trait]
56impl Pollable for MemoryInputPipe {
57 async fn ready(&mut self) {}
58}
59
60impl AsyncRead for MemoryInputPipe {
61 fn poll_read(
62 self: Pin<&mut Self>,
63 _cx: &mut Context<'_>,
64 buf: &mut io::ReadBuf<'_>,
65 ) -> Poll<io::Result<()>> {
66 let mut buffer = self.buffer.lock().unwrap();
67 let size = buf.remaining().min(buffer.len());
68 let read = buffer.split_to(size);
69 buf.put_slice(&read);
70 Poll::Ready(Ok(()))
71 }
72}
73
74#[derive(Debug, Clone)]
75pub struct MemoryOutputPipe {
76 capacity: usize,
77 buffer: Arc<Mutex<bytes::BytesMut>>,
78}
79
80impl MemoryOutputPipe {
81 pub fn new(capacity: usize) -> Self {
82 MemoryOutputPipe {
83 capacity,
84 buffer: std::sync::Arc::new(std::sync::Mutex::new(bytes::BytesMut::new())),
85 }
86 }
87
88 pub fn contents(&self) -> bytes::Bytes {
89 self.buffer.lock().unwrap().clone().freeze()
90 }
91
92 pub fn try_into_inner(self) -> Option<bytes::BytesMut> {
93 std::sync::Arc::into_inner(self.buffer).map(|m| m.into_inner().unwrap())
94 }
95}
96
97#[async_trait::async_trait]
98impl OutputStream for MemoryOutputPipe {
99 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
100 let mut buf = self.buffer.lock().unwrap();
101 if bytes.len() > self.capacity - buf.len() {
102 return Err(StreamError::Trap(anyhow!(
103 "write beyond capacity of MemoryOutputPipe"
104 )));
105 }
106 buf.extend_from_slice(bytes.as_ref());
107 Ok(())
109 }
110 fn flush(&mut self) -> Result<(), StreamError> {
111 Ok(())
113 }
114 fn check_write(&mut self) -> Result<usize, StreamError> {
115 let consumed = self.buffer.lock().unwrap().len();
116 if consumed < self.capacity {
117 Ok(self.capacity - consumed)
118 } else {
119 Err(StreamError::Closed)
121 }
122 }
123}
124
125#[async_trait::async_trait]
126impl Pollable for MemoryOutputPipe {
127 async fn ready(&mut self) {}
128}
129
130impl AsyncWrite for MemoryOutputPipe {
131 fn poll_write(
132 self: Pin<&mut Self>,
133 _cx: &mut Context<'_>,
134 buf: &[u8],
135 ) -> Poll<io::Result<usize>> {
136 let mut buffer = self.buffer.lock().unwrap();
137 let amt = buf.len().min(self.capacity - buffer.len());
138 buffer.extend_from_slice(&buf[..amt]);
139 Poll::Ready(Ok(amt))
140 }
141 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
142 Poll::Ready(Ok(()))
143 }
144 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145 Poll::Ready(Ok(()))
146 }
147}
148
149pub struct AsyncReadStream {
151 closed: bool,
152 buffer: Option<Result<Bytes, StreamError>>,
153 receiver: mpsc::Receiver<Result<Bytes, StreamError>>,
154 join_handle: Option<crate::runtime::AbortOnDropJoinHandle<()>>,
155}
156
157impl AsyncReadStream {
158 pub fn new<T: AsyncRead + Send + 'static>(reader: T) -> Self {
161 let (sender, receiver) = mpsc::channel(1);
162 let join_handle = crate::runtime::spawn(async move {
163 let mut reader = pin!(reader);
164 loop {
165 use tokio::io::AsyncReadExt;
166 let mut buf = bytes::BytesMut::with_capacity(4096);
167 let sent = match reader.read_buf(&mut buf).await {
168 Ok(nbytes) if nbytes == 0 => sender.send(Err(StreamError::Closed)).await,
169 Ok(_) => sender.send(Ok(buf.freeze())).await,
170 Err(e) => {
171 sender
172 .send(Err(StreamError::LastOperationFailed(e.into())))
173 .await
174 }
175 };
176 if sent.is_err() {
177 break;
179 }
180 }
181 });
182 AsyncReadStream {
183 closed: false,
184 buffer: None,
185 receiver,
186 join_handle: Some(join_handle),
187 }
188 }
189 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
190 if self.buffer.is_some() || self.closed {
191 return Poll::Ready(());
192 }
193 match self.receiver.poll_recv(cx) {
194 Poll::Ready(Some(res)) => {
195 self.buffer = Some(res);
196 Poll::Ready(())
197 }
198 Poll::Ready(None) => {
199 panic!("no more sender for an open AsyncReadStream - should be impossible")
200 }
201 Poll::Pending => Poll::Pending,
202 }
203 }
204}
205
206#[async_trait::async_trait]
207impl InputStream for AsyncReadStream {
208 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
209 use mpsc::error::TryRecvError;
210
211 match self.buffer.take() {
212 Some(Ok(mut bytes)) => {
213 let len = bytes.len().min(size);
215 let rest = bytes.split_off(len);
216 if !rest.is_empty() {
217 self.buffer = Some(Ok(rest));
218 }
219 return Ok(bytes);
220 }
221 Some(Err(e)) => {
222 self.closed = true;
223 return Err(e);
224 }
225 None => {}
226 }
227
228 match self.receiver.try_recv() {
229 Ok(Ok(mut bytes)) => {
230 let len = bytes.len().min(size);
231 let rest = bytes.split_off(len);
232 if !rest.is_empty() {
233 self.buffer = Some(Ok(rest));
234 }
235
236 Ok(bytes)
237 }
238 Ok(Err(e)) => {
239 self.closed = true;
240 Err(e)
241 }
242 Err(TryRecvError::Empty) => Ok(Bytes::new()),
243 Err(TryRecvError::Disconnected) => Err(StreamError::Trap(anyhow!(
244 "AsyncReadStream sender died - should be impossible"
245 ))),
246 }
247 }
248
249 async fn cancel(&mut self) {
250 match self.join_handle.take() {
251 Some(task) => _ = task.cancel().await,
252 None => {}
253 }
254 }
255}
256
257#[async_trait::async_trait]
258impl Pollable for AsyncReadStream {
259 async fn ready(&mut self) {
260 std::future::poll_fn(|cx| self.poll_ready(cx)).await
261 }
262}
263
264#[derive(Copy, Clone)]
266pub struct SinkOutputStream;
267
268#[async_trait::async_trait]
269impl OutputStream for SinkOutputStream {
270 fn write(&mut self, _buf: Bytes) -> Result<(), StreamError> {
271 Ok(())
272 }
273 fn flush(&mut self) -> Result<(), StreamError> {
274 Ok(())
276 }
277
278 fn check_write(&mut self) -> Result<usize, StreamError> {
279 Ok(usize::MAX)
281 }
282}
283
284#[async_trait::async_trait]
285impl Pollable for SinkOutputStream {
286 async fn ready(&mut self) {}
287}
288
289#[derive(Copy, Clone)]
291pub struct ClosedInputStream;
292
293#[async_trait::async_trait]
294impl InputStream for ClosedInputStream {
295 fn read(&mut self, _size: usize) -> Result<Bytes, StreamError> {
296 Err(StreamError::Closed)
297 }
298}
299
300#[async_trait::async_trait]
301impl Pollable for ClosedInputStream {
302 async fn ready(&mut self) {}
303}
304
305#[derive(Copy, Clone)]
307pub struct ClosedOutputStream;
308
309#[async_trait::async_trait]
310impl OutputStream for ClosedOutputStream {
311 fn write(&mut self, _: Bytes) -> Result<(), StreamError> {
312 Err(StreamError::Closed)
313 }
314 fn flush(&mut self) -> Result<(), StreamError> {
315 Err(StreamError::Closed)
316 }
317
318 fn check_write(&mut self) -> Result<usize, StreamError> {
319 Err(StreamError::Closed)
320 }
321}
322
323#[async_trait::async_trait]
324impl Pollable for ClosedOutputStream {
325 async fn ready(&mut self) {}
326}
327
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    // How many times each timing-sensitive write-stream scenario is repeated.
    // Fewer repetitions are used on targets other than x86_64.
    #[cfg(not(target_arch = "x86_64"))]
    const TEST_ITERATIONS: usize = 10;

    #[cfg(target_arch = "x86_64")]
    const TEST_ITERATIONS: usize = 100;

    /// Awaits `fut`, panicking if it has not completed within two seconds.
    async fn resolves_immediately<F, O>(fut: F) -> O
    where
        F: futures::Future<Output = O>,
    {
        tokio::time::timeout(Duration::from_secs(2), fut)
            .await
            .expect("operation timed out")
    }

    /// Asserts that `fut` is still pending after ten milliseconds.
    async fn never_resolves<F: futures::Future>(fut: F) {
        tokio::time::timeout(Duration::from_millis(10), fut)
            .await
            .err()
            .expect("operation should time out");
    }

    /// Returns a one-way in-memory pipe buffering up to `size` bytes: whatever
    /// is written to the second element can be read from the first.
    pub fn simplex(size: usize) -> (impl AsyncRead, impl AsyncWrite) {
        let (a, b) = tokio::io::duplex(size);
        let (_read_half, write_half) = tokio::io::split(a);
        let (read_half, _write_half) = tokio::io::split(b);
        (read_half, write_half)
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn empty_read_stream() {
        let mut reader = AsyncReadStream::new(tokio::io::empty());

        // The background task may or may not have delivered EOF yet.
        match reader.read(10) {
            Err(StreamError::Closed) => {}
            Ok(bs) => {
                // Nothing arrived yet; EOF must show up as soon as we wait.
                assert!(bs.is_empty());
                resolves_immediately(reader.ready()).await;
                assert!(matches!(reader.read(0), Err(StreamError::Closed)));
            }
            res => panic!("unexpected: {res:?}"),
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn infinite_read_stream() {
        let mut reader = AsyncReadStream::new(tokio::io::repeat(0));

        // The first read may race the background task and come back empty.
        let bs = reader.read(10).unwrap();
        if bs.is_empty() {
            resolves_immediately(reader.ready()).await;
            let bs = reader.read(10).unwrap();
            assert_eq!(bs.len(), 10);
        } else {
            assert_eq!(bs.len(), 10);
        }

        // Later reads are served from the chunk already received.
        let bs = reader.read(10).unwrap();
        assert_eq!(bs.len(), 10);

        // A zero-sized read returns no bytes.
        let bs = reader.read(0).unwrap();
        assert_eq!(bs.len(), 0);
    }

    /// Returns a reader that yields `contents` and then reaches end-of-file.
    async fn finite_async_reader(contents: &[u8]) -> impl AsyncRead + Send + 'static + use<> {
        let (r, mut w) = simplex(contents.len());
        w.write_all(contents).await.unwrap();
        r
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn finite_read_stream() {
        let mut reader = AsyncReadStream::new(finite_async_reader(&[1; 123]).await);

        // The first read may race the background task and come back empty.
        let bs = reader.read(123).unwrap();
        if bs.is_empty() {
            resolves_immediately(reader.ready()).await;
            let bs = reader.read(123).unwrap();
            assert_eq!(bs.len(), 123);
        } else {
            assert_eq!(bs.len(), 123);
        }

        // After the contents the stream reports closed, possibly after a wait.
        match reader.read(0) {
            Err(StreamError::Closed) => {}
            Ok(bs) => {
                assert!(bs.is_empty());
                resolves_immediately(reader.ready()).await;
                assert!(matches!(reader.read(0), Err(StreamError::Closed)));
            }
            res => panic!("unexpected: {res:?}"),
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn multiple_chunks_read_stream() {
        let (r, mut w) = simplex(1024);
        let mut reader = AsyncReadStream::new(r);

        w.write_all(&[123]).await.unwrap();

        let bs = reader.read(1).unwrap();
        if bs.is_empty() {
            resolves_immediately(reader.ready()).await;
            let bs = reader.read(1).unwrap();
            assert_eq!(*bs, [123u8]);
        } else {
            assert_eq!(*bs, [123u8]);
        }

        // Nothing else has been written: reads come back empty and the
        // stream does not become ready.
        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        never_resolves(reader.ready()).await;

        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        // A second write makes the stream ready again.
        w.write_all(&[45]).await.unwrap();

        resolves_immediately(reader.ready()).await;

        let bs = reader.read(1).unwrap();
        assert_eq!(*bs, [45u8]);

        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        never_resolves(reader.ready()).await;

        let bs = reader.read(1).unwrap();
        assert!(bs.is_empty());

        // Dropping the writer produces end-of-file, surfaced as `Closed`.
        drop(w);

        resolves_immediately(reader.ready()).await;

        assert!(matches!(reader.read(1), Err(StreamError::Closed)));
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn backpressure_read_stream() {
        // The pipe has room for everything the writer task sends.
        let (r, mut w) = simplex(16 * 1024);
        let mut reader = AsyncReadStream::new(r);

        let writer_task = tokio::task::spawn(async move {
            w.write_all(&[123; 8192]).await.unwrap();
            w
        });

        resolves_immediately(reader.ready()).await;

        // Reads are capped at the background task's 4096-byte chunk size.
        let bs = reader.read(4097).unwrap();
        assert_eq!(bs.len(), 4096);

        resolves_immediately(reader.ready()).await;

        let bs = reader.read(4097).unwrap();
        assert_eq!(bs.len(), 4096);

        // Wait for the writer task and drop what it returns, closing the pipe.
        let w = resolves_immediately(writer_task).await;

        drop(w);

        resolves_immediately(reader.ready()).await;

        assert!(matches!(reader.read(4097), Err(StreamError::Closed)));
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn sink_write_stream() {
        let mut writer = AsyncWriteStream::new(2048, tokio::io::sink());
        let chunk = Bytes::from_static(&[0; 1024]);

        let readiness = resolves_immediately(writer.write_ready())
            .await
            .expect("write_ready does not trap");
        assert_eq!(readiness, 2048);
        writer.write(chunk.clone()).expect("write does not error");

        // The sink may or may not have consumed the first chunk yet.
        let readiness = resolves_immediately(writer.write_ready())
            .await
            .expect("write_ready does not trap");
        assert!(
            readiness == 1024 || readiness == 2048,
            "readiness should be 1024 or 2048, got {readiness}"
        );

        if readiness == 1024 {
            writer.write(chunk.clone()).expect("write does not error");

            let readiness = resolves_immediately(writer.write_ready())
                .await
                .expect("write_ready does not trap");
            assert!(
                readiness == 1024 || readiness == 2048,
                "readiness should be 1024 or 2048, got {readiness}"
            );
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn closed_write_stream() {
        for n in 0..TEST_ITERATIONS {
            closed_write_stream_(n).await
        }
    }

    /// Writes into a stream whose reader is already gone. The failure must be
    /// reported exactly once as `LastOperationFailed` (by `write_ready`,
    /// `flush`, or the final `write_ready`), and only `Closed` may follow it.
    /// Even values of `n` add an extra `write_ready` right after the write.
    #[tracing::instrument]
    async fn closed_write_stream_(n: usize) {
        let (reader, writer) = simplex(1);
        let mut writer = AsyncWriteStream::new(1024, writer);

        // Drop the read end so that delivering any bytes must fail.
        drop(reader);

        // Becomes true once the failure has been reported; afterwards only
        // `Closed` is acceptable.
        let mut should_be_closed = false;

        // The first write is expected to be accepted even though nothing can
        // receive it.
        let chunk = Bytes::from_static(&[0; 1]);
        writer
            .write(chunk.clone())
            .expect("first write should succeed");

        let mut write_ready_res = None;
        if n % 2 == 0 {
            let r = resolves_immediately(writer.write_ready()).await;
            match r {
                // No failure reported yet.
                Ok(1023) => {}
                Err(StreamError::LastOperationFailed(_)) => {
                    tracing::debug!("discovered stream failure in first write_ready");
                    should_be_closed = true;
                }
                r => panic!("unexpected write_ready: {r:?}"),
            }
            write_ready_res = Some(r);
        }

        let flush_res = writer.flush();
        match flush_res {
            Err(StreamError::LastOperationFailed(_)) => {
                tracing::debug!("discovered stream failure trying to flush");
                assert!(!should_be_closed);
                should_be_closed = true;
            }
            Err(StreamError::Closed) => {
                assert!(
                    should_be_closed,
                    "expected a LastOperationFailed before we see Closed. {write_ready_res:?}"
                );
            }
            // The flush was accepted before the failure was observed.
            Ok(()) => {}
            Err(e) => panic!("unexpected flush error: {e:?} {write_ready_res:?}"),
        }

        // By now the failure must have surfaced one way or the other.
        match resolves_immediately(writer.write_ready()).await {
            Err(StreamError::LastOperationFailed(_)) => {
                tracing::debug!("discovered stream failure in final write_ready");
                assert!(!should_be_closed);
            }
            Err(StreamError::Closed) => {
                assert!(should_be_closed);
            }
            r => {
                panic!(
                    "stream should be reported closed by the end of write_ready after flush, got {r:?}. {write_ready_res:?} {flush_res:?}"
                )
            }
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn multiple_chunks_write_stream() {
        for n in 0..TEST_ITERATIONS {
            multiple_chunks_write_stream_aux(n).await
        }
    }

    /// Writes two single-byte chunks, reading each back from the pipe and
    /// checking the write budget after every step.
    #[tracing::instrument]
    async fn multiple_chunks_write_stream_aux(_: usize) {
        use std::ops::Deref;

        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(1024, writer);

        // First chunk.
        let chunk = Bytes::from_static(&[123; 1]);

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write does not trap");

        // The budget may or may not have been replenished yet.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert!(matches!(permit, 1023 | 1024));

        let mut read_buf = vec![0; chunk.len()];
        let read_len = reader.read_exact(&mut read_buf).await.unwrap();
        assert_eq!(read_len, chunk.len());
        assert_eq!(read_buf.as_slice(), chunk.deref());

        // Second chunk; the first is flushed before it is written.
        let chunk2 = Bytes::from_static(&[45; 1]);

        writer.flush().expect("channel is still alive");

        // After the flush the full budget is expected to be available.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk2.clone()).expect("write does not trap");

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert!(matches!(permit, 1023 | 1024));

        let mut read2_buf = vec![0; chunk2.len()];
        let read2_len = reader.read_exact(&mut read2_buf).await.unwrap();
        assert_eq!(read2_len, chunk2.len());
        assert_eq!(read2_buf.as_slice(), chunk2.deref());

        writer.flush().expect("channel is still alive");

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn backpressure_write_stream() {
        for n in 0..TEST_ITERATIONS {
            backpressure_write_stream_aux(n).await
        }
    }

    /// Fills both the 1024-byte pipe and the stream's 1024-byte budget. Checks
    /// that further writes are refused until the reader drains the pipe.
    #[tracing::instrument]
    async fn backpressure_write_stream_aux(_: usize) {
        use futures::future::poll_immediate;

        // Both the pipe and the write stream hold 1024 bytes.
        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(1024, writer);

        let chunk = Bytes::from_static(&[0; 1024]);

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write succeeds");

        // Either the chunk is still in flight (not ready yet) or the full
        // budget is already back.
        let permit = poll_immediate(writer.write_ready()).await;
        assert!(matches!(permit, None | Some(Ok(1024))));

        // Given a moment, the full budget comes back.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        // The pipe now holds the first chunk, so this one stays queued.
        writer.write(chunk.clone()).expect("write does not trap");

        // Writing beyond the permitted budget is an error.
        writer
            .write(chunk.clone())
            .err()
            .expect("unpermitted write does trap");

        // Nothing drains the pipe, so the stream never becomes ready.
        never_resolves(writer.write_ready()).await;

        // Draining both chunks from the pipe...
        let mut buf = [0; 2048];
        reader.read_exact(&mut buf).await.unwrap();

        // ...leaves nothing more to read...
        never_resolves(reader.read(&mut buf)).await;

        // ...and releases the backpressure.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("ready is ok");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write does not trap");
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn backpressure_write_stream_with_flush() {
        for n in 0..TEST_ITERATIONS {
            backpressure_write_stream_with_flush_aux(n).await;
        }
    }

    /// Same scenario as `backpressure_write_stream_aux`, but every write is
    /// followed by a flush.
    #[tracing::instrument]
    async fn backpressure_write_stream_with_flush_aux(_: usize) {
        // Both the pipe and the write stream hold 1024 bytes.
        let (mut reader, writer) = simplex(1024);
        let mut writer = AsyncWriteStream::new(1024, writer);

        let chunk = Bytes::from_static(&[0; 1024]);

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write should be ready");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write succeeds");

        writer.flush().expect("flush succeeds");

        // Once the flush completes the full budget is back.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("write_ready succeeds");
        assert_eq!(permit, 1024);

        // The pipe now holds the first chunk, so this one stays queued.
        writer.write(chunk.clone()).expect("write does not trap");

        writer.flush().expect("flush succeeds");

        // Writing beyond the permitted budget is an error.
        writer
            .write(chunk.clone())
            .err()
            .expect("unpermitted write does trap");

        // Nothing drains the pipe, so the stream never becomes ready.
        never_resolves(writer.write_ready()).await;

        // Draining both chunks from the pipe...
        let mut buf = [0; 2048];
        reader.read_exact(&mut buf).await.unwrap();

        // ...leaves nothing more to read...
        never_resolves(reader.read(&mut buf)).await;

        // ...and releases the backpressure.
        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("ready is ok");
        assert_eq!(permit, 1024);

        writer.write(chunk.clone()).expect("write does not trap");

        writer.flush().expect("flush succeeds");

        let permit = resolves_immediately(writer.write_ready())
            .await
            .expect("ready is ok");
        assert_eq!(permit, 1024);
    }
}