Skip to main content

wasmtime_wasi_http/p2/
body.rs

1//! Implementation of the `wasi:http/types` interface's various body types.
2
3use crate::FieldMap;
4use crate::p2::bindings::http::types;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::BodyExt;
8use http_body_util::combinators::UnsyncBoxBody;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime::format_err;
15use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
16use wasmtime_wasi::runtime::{AbortOnDropJoinHandle, poll_noop};
17
18/// Common type for incoming bodies.
19pub type HyperIncomingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
20
21/// Common type for outgoing bodies.
22pub type HyperOutgoingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
23
24/// The concrete type behind a `was:http/types.incoming-body` resource.
25#[derive(Debug)]
26pub struct HostIncomingBody {
27    body: IncomingBodyState,
28    /// An optional worker task to keep alive while this body is being read.
29    /// This ensures that if the parent of this body is dropped before the body
30    /// then the backing data behind this worker is kept alive.
31    worker: Option<AbortOnDropJoinHandle<()>>,
32}
33
34impl HostIncomingBody {
35    /// Create a new `HostIncomingBody` with the given `body` and a per-frame timeout
36    pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
37        let body = BodyWithTimeout::new(body, between_bytes_timeout);
38        HostIncomingBody {
39            body: IncomingBodyState::Start(body),
40            worker: None,
41        }
42    }
43
44    /// Retain a worker task that needs to be kept alive while this body is being read.
45    pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
46        assert!(self.worker.is_none());
47        self.worker = Some(worker);
48    }
49
50    /// Try taking the stream of this body, if it's available.
51    pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
52        match &mut self.body {
53            IncomingBodyState::Start(_) => {}
54            IncomingBodyState::InBodyStream(_) => return None,
55        }
56        let (tx, rx) = oneshot::channel();
57        let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
58            IncomingBodyState::Start(b) => b,
59            IncomingBodyState::InBodyStream(_) => unreachable!(),
60        };
61        Some(HostIncomingBodyStream {
62            state: IncomingBodyStreamState::Open { body, tx },
63            buffer: Bytes::new(),
64            error: None,
65        })
66    }
67
68    /// Convert this body into a `HostFutureTrailers` resource.
69    pub fn into_future_trailers(self) -> HostFutureTrailers {
70        HostFutureTrailers::Waiting(self)
71    }
72}
73
74/// Internal state of a [`HostIncomingBody`].
75#[derive(Debug)]
76enum IncomingBodyState {
77    /// The body is stored here meaning that within `HostIncomingBody` the
78    /// `take_stream` method can be called for example.
79    Start(BodyWithTimeout),
80
81    /// The body is within a `HostIncomingBodyStream` meaning that it's not
82    /// currently owned here. The body will be sent back over this channel when
83    /// it's done, however.
84    InBodyStream(oneshot::Receiver<StreamEnd>),
85}
86
87/// Small wrapper around [`HyperIncomingBody`] which adds a timeout to every frame.
88#[derive(Debug)]
89struct BodyWithTimeout {
90    /// Underlying stream that frames are coming from.
91    inner: HyperIncomingBody,
92    /// Currently active timeout that's reset between frames.
93    timeout: Pin<Box<tokio::time::Sleep>>,
94    /// Whether or not `timeout` needs to be reset on the next call to
95    /// `poll_frame`.
96    reset_sleep: bool,
97    /// Maximal duration between when a frame is first requested and when it's
98    /// allowed to arrive.
99    between_bytes_timeout: Duration,
100}
101
102impl BodyWithTimeout {
103    fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
104        BodyWithTimeout {
105            inner,
106            between_bytes_timeout,
107            reset_sleep: true,
108            timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
109                tokio::time::sleep(Duration::new(0, 0))
110            })),
111        }
112    }
113}
114
115impl Body for BodyWithTimeout {
116    type Data = Bytes;
117    type Error = types::ErrorCode;
118
119    fn poll_frame(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122    ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
123        let me = Pin::into_inner(self);
124
125        // If the timeout timer needs to be reset, do that now relative to the
126        // current instant. Otherwise test the timeout timer and see if it's
127        // fired yet and if so we've timed out and return an error.
128        if me.reset_sleep {
129            me.timeout
130                .as_mut()
131                .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
132            me.reset_sleep = false;
133        }
134
135        // Register interest in this context on the sleep timer, and if the
136        // sleep elapsed that means that we've timed out.
137        if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
138            return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
139        }
140
141        // Without timeout business now handled check for the frame. If a frame
142        // arrives then the sleep timer will be reset on the next frame.
143        let result = Pin::new(&mut me.inner).poll_frame(cx);
144        me.reset_sleep = result.is_ready();
145        result
146    }
147}
148
149/// Message sent when a `HostIncomingBodyStream` is done to the
150/// `HostFutureTrailers` state.
151#[derive(Debug)]
152enum StreamEnd {
153    /// The body wasn't completely read and was dropped early. May still have
154    /// trailers, but requires reading more frames.
155    Remaining(BodyWithTimeout),
156
157    /// Body was completely read and trailers were read. Here are the trailers.
158    /// Note that `None` means that the body finished without trailers.
159    Trailers(Option<http::HeaderMap>),
160}
161
162/// The concrete type behind the `wasi:io/streams.input-stream` resource returned
163/// by `wasi:http/types.incoming-body`'s `stream` method.
164#[derive(Debug)]
165pub struct HostIncomingBodyStream {
166    state: IncomingBodyStreamState,
167    buffer: Bytes,
168    error: Option<wasmtime::Error>,
169}
170
171impl HostIncomingBodyStream {
172    fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
173        match frame {
174            Some(Ok(frame)) => match frame.into_data() {
175                // A data frame was received, so queue up the buffered data for
176                // the next `read` call.
177                Ok(bytes) => {
178                    assert!(self.buffer.is_empty());
179                    self.buffer = bytes;
180                }
181
182                // Trailers were received meaning that this was the final frame.
183                // Throw away the body and send the trailers along the
184                // `tx` channel to make them available.
185                Err(trailers) => {
186                    let trailers = trailers.into_trailers().unwrap();
187                    let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
188                        IncomingBodyStreamState::Open { body: _, tx } => tx,
189                        IncomingBodyStreamState::Closed => unreachable!(),
190                    };
191
192                    // NB: ignore send failures here because if this fails then
193                    // no one was interested in the trailers.
194                    let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
195                }
196            },
197
198            // An error was received meaning that the stream is now done.
199            // Destroy the body to terminate the stream while enqueueing the
200            // error to get returned from the next call to `read`.
201            Some(Err(e)) => {
202                self.error = Some(e.into());
203                self.state = IncomingBodyStreamState::Closed;
204            }
205
206            // No more frames are going to be received again, so drop the `body`
207            // and the `tx` channel we'd send the body back onto because it's
208            // not needed as frames are done.
209            None => {
210                self.state = IncomingBodyStreamState::Closed;
211            }
212        }
213    }
214}
215
216#[derive(Debug)]
217enum IncomingBodyStreamState {
218    /// The body is currently open for reading and present here.
219    ///
220    /// When trailers are read, or when this is dropped, the body is sent along
221    /// `tx`.
222    ///
223    /// This state is transitioned to `Closed` when an error happens, EOF
224    /// happens, or when trailers are read.
225    Open {
226        body: BodyWithTimeout,
227        tx: oneshot::Sender<StreamEnd>,
228    },
229
230    /// This body is closed and no longer available for reading, no more data
231    /// will come.
232    Closed,
233}
234
235#[async_trait::async_trait]
236impl InputStream for HostIncomingBodyStream {
237    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
238        loop {
239            // Handle buffered data/errors if any
240            if !self.buffer.is_empty() {
241                let len = size.min(self.buffer.len());
242                let chunk = self.buffer.split_to(len);
243                return Ok(chunk);
244            }
245
246            if let Some(e) = self.error.take() {
247                return Err(StreamError::LastOperationFailed(e));
248            }
249
250            // Extract the body that we're reading from. If present perform a
251            // non-blocking poll to see if a frame is already here. If it is
252            // then turn the loop again to operate on the results. If it's not
253            // here then return an empty buffer as no data is available at this
254            // time.
255            let body = match &mut self.state {
256                IncomingBodyStreamState::Open { body, .. } => body,
257                IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
258            };
259
260            let future = body.frame();
261            futures::pin_mut!(future);
262            match poll_noop(future) {
263                Some(result) => {
264                    self.record_frame(result);
265                }
266                None => return Ok(Bytes::new()),
267            }
268        }
269    }
270}
271
272#[async_trait::async_trait]
273impl Pollable for HostIncomingBodyStream {
274    async fn ready(&mut self) {
275        if !self.buffer.is_empty() || self.error.is_some() {
276            return;
277        }
278
279        if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
280            let frame = body.frame().await;
281            self.record_frame(frame);
282        }
283    }
284}
285
286impl Drop for HostIncomingBodyStream {
287    fn drop(&mut self) {
288        // When a body stream is dropped, for whatever reason, attempt to send
289        // the body back to the `tx` which will provide the trailers if desired.
290        // This isn't necessary if the state is already closed. Additionally,
291        // like `record_frame` above, `send` errors are ignored as they indicate
292        // that the body/trailers aren't actually needed.
293        let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
294        if let IncomingBodyStreamState::Open { body, tx } = prev {
295            let _ = tx.send(StreamEnd::Remaining(body));
296        }
297    }
298}
299
300/// The concrete type behind a `wasi:http/types.future-trailers` resource.
301#[derive(Debug)]
302pub enum HostFutureTrailers {
303    /// Trailers aren't here yet.
304    ///
305    /// This state represents two similar states:
306    ///
307    /// * The body is here and ready for reading and we're waiting to read
308    ///   trailers. This can happen for example when the actual body wasn't read
309    ///   or if the body was only partially read.
310    ///
311    /// * The body is being read by something else and we're waiting for that to
312    ///   send us the trailers (or the body itself). This state will get entered
313    ///   when the body stream is dropped for example. If the body stream reads
314    ///   the trailers itself it will also send a message over here with the
315    ///   trailers.
316    Waiting(HostIncomingBody),
317
318    /// Trailers are ready and here they are.
319    ///
320    /// Note that `Ok(None)` means that there were no trailers for this request
321    /// while `Ok(Some(_))` means that trailers were found in the request.
322    Done(Result<Option<http::HeaderMap>, types::ErrorCode>),
323
324    /// Trailers have been consumed by `future-trailers.get`.
325    Consumed,
326}
327
328#[async_trait::async_trait]
329impl Pollable for HostFutureTrailers {
330    async fn ready(&mut self) {
331        let body = match self {
332            HostFutureTrailers::Waiting(body) => body,
333            HostFutureTrailers::Done(_) => return,
334            HostFutureTrailers::Consumed => return,
335        };
336
337        // If the body is itself being read by a body stream then we need to
338        // wait for that to be done.
339        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
340            match rx.await {
341                // Trailers were read for us and here they are, so store the
342                // result.
343                Ok(StreamEnd::Trailers(Some(t))) => {
344                    *self = Self::Done(Ok(Some(t)));
345                }
346                // The body wasn't fully read and was dropped before trailers
347                // were reached. It's up to us now to complete the body.
348                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),
349
350                // This means there were no trailers present.
351                Ok(StreamEnd::Trailers(None)) | Err(_) => {
352                    *self = HostFutureTrailers::Done(Ok(None));
353                }
354            }
355        }
356
357        // Here it should be guaranteed that `InBodyStream` is now gone, so if
358        // we have the body ourselves then read frames until trailers are found.
359        let body = match self {
360            HostFutureTrailers::Waiting(body) => body,
361            HostFutureTrailers::Done(_) => return,
362            HostFutureTrailers::Consumed => return,
363        };
364        let hyper_body = match &mut body.body {
365            IncomingBodyState::Start(body) => body,
366            IncomingBodyState::InBodyStream(_) => unreachable!(),
367        };
368        let result = loop {
369            match hyper_body.frame().await {
370                None => break Ok(None),
371                Some(Err(e)) => break Err(e),
372                Some(Ok(frame)) => {
373                    // If this frame is a data frame ignore it as we're only
374                    // interested in trailers.
375                    if let Ok(header_map) = frame.into_trailers() {
376                        break Ok(Some(header_map));
377                    }
378                }
379            }
380        };
381        *self = HostFutureTrailers::Done(result);
382    }
383}
384
/// A byte counter shared between a body's writer and its owner. It tracks how
/// many bytes have been written against an expected total.
#[derive(Debug, Clone)]
struct WrittenState {
    expected: u64,
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    fn new(expected_size: u64) -> Self {
        let counter = std::sync::atomic::AtomicU64::new(0);
        WrittenState {
            expected: expected_size,
            written: Arc::new(counter),
        }
    }

    /// The number of bytes that have been written so far.
    fn written(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.written.load(Ordering::Relaxed)
    }

    /// Adds `len` to the running total. The addition is always recorded, but
    /// the return value is `false` when the new total goes past the expected
    /// number of bytes.
    fn update(&self, len: usize) -> bool {
        use std::sync::atomic::Ordering;
        let delta = len as u64;
        let before = self.written.fetch_add(delta, Ordering::Relaxed);
        let after = before + delta;
        after <= self.expected
    }
}
414
415/// The concrete type behind a `wasi:http/types.outgoing-body` resource.
416pub struct HostOutgoingBody {
417    /// The output stream that the body is written to.
418    body_output_stream: Option<Box<dyn OutputStream>>,
419    context: StreamContext,
420    written: Option<WrittenState>,
421    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
422}
423
424impl HostOutgoingBody {
425    /// Create a new `HostOutgoingBody`
426    pub fn new(
427        context: StreamContext,
428        size: Option<u64>,
429        buffer_chunks: usize,
430        chunk_size: usize,
431    ) -> (Self, HyperOutgoingBody) {
432        assert!(buffer_chunks >= 1);
433
434        let written = size.map(WrittenState::new);
435
436        use tokio::sync::oneshot::error::RecvError;
437        struct BodyImpl {
438            body_receiver: mpsc::Receiver<Bytes>,
439            finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
440        }
441        impl Body for BodyImpl {
442            type Data = Bytes;
443            type Error = types::ErrorCode;
444            fn poll_frame(
445                mut self: Pin<&mut Self>,
446                cx: &mut Context<'_>,
447            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
448                match self.as_mut().body_receiver.poll_recv(cx) {
449                    Poll::Pending => Poll::Pending,
450                    Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
451
452                    // This means that the `body_sender` end of the channel has been dropped.
453                    Poll::Ready(None) => {
454                        if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
455                            match Pin::new(&mut finish_receiver).poll(cx) {
456                                Poll::Pending => {
457                                    self.as_mut().finish_receiver = Some(finish_receiver);
458                                    Poll::Pending
459                                }
460                                Poll::Ready(Ok(message)) => match message {
461                                    FinishMessage::Finished => Poll::Ready(None),
462                                    FinishMessage::Trailers(trailers) => {
463                                        Poll::Ready(Some(Ok(Frame::trailers(trailers))))
464                                    }
465                                    FinishMessage::Abort => {
466                                        Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
467                                    }
468                                },
469                                Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
470                            }
471                        } else {
472                            Poll::Ready(None)
473                        }
474                    }
475                }
476            }
477        }
478
479        // always add 1 buffer here because one empty slot is required
480        let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
481        let (finish_sender, finish_receiver) = oneshot::channel();
482        let body_impl = BodyImpl {
483            body_receiver,
484            finish_receiver: Some(finish_receiver),
485        }
486        .boxed_unsync();
487
488        let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
489
490        (
491            Self {
492                body_output_stream: Some(Box::new(output_stream)),
493                context,
494                written,
495                finish_sender: Some(finish_sender),
496            },
497            body_impl,
498        )
499    }
500
501    /// Take the output stream, if it's available.
502    pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
503        self.body_output_stream.take()
504    }
505
506    /// Finish the body, optionally with trailers.
507    pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
508        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
509        // will immediately pick up the finish sender.
510        drop(self.body_output_stream);
511
512        let sender = self
513            .finish_sender
514            .take()
515            .expect("outgoing-body trailer_sender consumed by a non-owning function");
516
517        if let Some(w) = self.written {
518            let written = w.written();
519            if written != w.expected {
520                let _ = sender.send(FinishMessage::Abort);
521                return Err(self.context.as_body_size_error(written));
522            }
523        }
524
525        let message = if let Some(ts) = trailers {
526            FinishMessage::Trailers(ts.into())
527        } else {
528            FinishMessage::Finished
529        };
530
531        // Ignoring failure: receiver died sending body, but we can't report that here.
532        let _ = sender.send(message);
533
534        Ok(())
535    }
536
537    /// Abort the body.
538    pub fn abort(mut self) {
539        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
540        // will immediately pick up the finish sender.
541        drop(self.body_output_stream);
542
543        let sender = self
544            .finish_sender
545            .take()
546            .expect("outgoing-body trailer_sender consumed by a non-owning function");
547
548        let _ = sender.send(FinishMessage::Abort);
549    }
550}
551
552/// Message sent to end the `[HostOutgoingBody]` stream.
553#[derive(Debug)]
554enum FinishMessage {
555    Finished,
556    Trailers(hyper::HeaderMap),
557    Abort,
558}
559
560/// Whether the body is a request or response body.
561#[derive(Clone, Copy, Debug, Eq, PartialEq)]
562pub enum StreamContext {
563    /// The body is a request body.
564    Request,
565    /// The body is a response body.
566    Response,
567}
568
569impl StreamContext {
570    /// Construct the correct [`types::ErrorCode`] body size error.
571    pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
572        match self {
573            StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
574            StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
575        }
576    }
577}
578
579/// Provides a [`HostOutputStream`] impl from a [`tokio::sync::mpsc::Sender`].
580#[derive(Debug)]
581struct BodyWriteStream {
582    context: StreamContext,
583    writer: mpsc::Sender<Bytes>,
584    write_budget: usize,
585    written: Option<WrittenState>,
586}
587
588impl BodyWriteStream {
589    /// Create a [`BodyWriteStream`].
590    fn new(
591        context: StreamContext,
592        write_budget: usize,
593        writer: mpsc::Sender<Bytes>,
594        written: Option<WrittenState>,
595    ) -> Self {
596        // at least one capacity is required to send a message
597        assert!(writer.max_capacity() >= 1);
598        BodyWriteStream {
599            context,
600            writer,
601            write_budget,
602            written,
603        }
604    }
605}
606
607#[async_trait::async_trait]
608impl OutputStream for BodyWriteStream {
609    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
610        let len = bytes.len();
611        match self.writer.try_send(bytes) {
612            // If the message was sent then it's queued up now in hyper to get
613            // received.
614            Ok(()) => {
615                if let Some(written) = self.written.as_ref() {
616                    if !written.update(len) {
617                        let total = written.written();
618                        return Err(StreamError::LastOperationFailed(format_err!(
619                            self.context.as_body_size_error(total)
620                        )));
621                    }
622                }
623
624                Ok(())
625            }
626
627            // If this channel is full then that means `check_write` wasn't
628            // called. The call to `check_write` always guarantees that there's
629            // at least one capacity if a write is allowed.
630            Err(mpsc::error::TrySendError::Full(_)) => {
631                Err(StreamError::Trap(format_err!("write exceeded budget")))
632            }
633
634            // Hyper is gone so this stream is now closed.
635            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
636        }
637    }
638
639    fn flush(&mut self) -> Result<(), StreamError> {
640        // Flushing doesn't happen in this body stream since we're currently
641        // only tracking sending bytes over to hyper.
642        if self.writer.is_closed() {
643            Err(StreamError::Closed)
644        } else {
645            Ok(())
646        }
647    }
648
649    fn check_write(&mut self) -> Result<usize, StreamError> {
650        if self.writer.is_closed() {
651            Err(StreamError::Closed)
652        } else if self.writer.capacity() == 0 {
653            // If there is no more capacity in this sender channel then don't
654            // allow any more writes because the hyper task needs to catch up
655            // now.
656            //
657            // Note that this relies on this task being the only one sending
658            // data to ensure that no one else can steal a write into this
659            // channel.
660            Ok(0)
661        } else {
662            Ok(self.write_budget)
663        }
664    }
665}
666
667#[async_trait::async_trait]
668impl Pollable for BodyWriteStream {
669    async fn ready(&mut self) {
670        // Attempt to perform a reservation for a send. If there's capacity in
671        // the channel or it's already closed then this will return immediately.
672        // If the channel is full this will block until capacity opens up.
673        let _ = self.writer.reserve().await;
674    }
675}