wasmtime_wasi_http/
body.rs

1//! Implementation of the `wasi:http/types` interface's various body types.
2
3use crate::{bindings::http::types, types::FieldMap};
4use anyhow::anyhow;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::combinators::BoxBody;
8use http_body_util::BodyExt;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
15use wasmtime_wasi::runtime::{poll_noop, AbortOnDropJoinHandle};
16
/// Common type for incoming bodies.
///
/// A [`BoxBody`] of [`Bytes`] frames whose errors are already expressed as
/// `wasi:http` [`types::ErrorCode`]s.
pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;

/// Common type for outgoing bodies.
///
/// Currently the same concrete type as [`HyperIncomingBody`].
pub type HyperOutgoingBody = BoxBody<Bytes, types::ErrorCode>;
22
/// The concrete type behind a `wasi:http/types/incoming-body` resource.
#[derive(Debug)]
pub struct HostIncomingBody {
    /// Where the body currently lives: still owned here (`Start`), or handed
    /// off to a [`HostIncomingBodyStream`] (`InBodyStream`).
    body: IncomingBodyState,
    /// An optional worker task to keep alive while this body is being read.
    /// This ensures that if the parent of this body is dropped before the body
    /// then the backing data behind this worker is kept alive.
    worker: Option<AbortOnDropJoinHandle<()>>,
}
32
33impl HostIncomingBody {
34    /// Create a new `HostIncomingBody` with the given `body` and a per-frame timeout
35    pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
36        let body = BodyWithTimeout::new(body, between_bytes_timeout);
37        HostIncomingBody {
38            body: IncomingBodyState::Start(body),
39            worker: None,
40        }
41    }
42
43    /// Retain a worker task that needs to be kept alive while this body is being read.
44    pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
45        assert!(self.worker.is_none());
46        self.worker = Some(worker);
47    }
48
49    /// Try taking the stream of this body, if it's available.
50    pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
51        match &mut self.body {
52            IncomingBodyState::Start(_) => {}
53            IncomingBodyState::InBodyStream(_) => return None,
54        }
55        let (tx, rx) = oneshot::channel();
56        let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
57            IncomingBodyState::Start(b) => b,
58            IncomingBodyState::InBodyStream(_) => unreachable!(),
59        };
60        Some(HostIncomingBodyStream {
61            state: IncomingBodyStreamState::Open { body, tx },
62            buffer: Bytes::new(),
63            error: None,
64        })
65    }
66
67    /// Convert this body into a `HostFutureTrailers` resource.
68    pub fn into_future_trailers(self) -> HostFutureTrailers {
69        HostFutureTrailers::Waiting(self)
70    }
71}
72
/// Internal state of a [`HostIncomingBody`].
///
/// `HostIncomingBody::take_stream` moves it from `Start` to `InBodyStream`.
/// `HostFutureTrailers` may move it back to `Start` when a stream returns an
/// unfinished body.
#[derive(Debug)]
enum IncomingBodyState {
    /// The body is stored here meaning that within `HostIncomingBody` the
    /// `take_stream` method can be called for example.
    Start(BodyWithTimeout),

    /// The body is within a `HostIncomingBodyStream` meaning that it's not
    /// currently owned here. When the stream is done it reports over this
    /// channel, sending either the trailers or the unfinished body back.
    InBodyStream(oneshot::Receiver<StreamEnd>),
}
85
/// Small wrapper around [`HyperIncomingBody`] which adds a timeout to every frame.
///
/// When the timeout elapses while a frame is awaited, `poll_frame` yields
/// `types::ErrorCode::ConnectionReadTimeout`.
#[derive(Debug)]
struct BodyWithTimeout {
    /// Underlying stream that frames are coming from.
    inner: HyperIncomingBody,
    /// Currently active timeout that's reset between frames.
    timeout: Pin<Box<tokio::time::Sleep>>,
    /// Whether or not `timeout` needs to be reset on the next call to
    /// `poll_frame`. Set after each frame (or end of stream) is returned.
    reset_sleep: bool,
    /// Maximal duration between when a frame is first requested and when it's
    /// allowed to arrive.
    between_bytes_timeout: Duration,
}
100
101impl BodyWithTimeout {
102    fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
103        BodyWithTimeout {
104            inner,
105            between_bytes_timeout,
106            reset_sleep: true,
107            timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
108                tokio::time::sleep(Duration::new(0, 0))
109            })),
110        }
111    }
112}
113
114impl Body for BodyWithTimeout {
115    type Data = Bytes;
116    type Error = types::ErrorCode;
117
118    fn poll_frame(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121    ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
122        let me = Pin::into_inner(self);
123
124        // If the timeout timer needs to be reset, do that now relative to the
125        // current instant. Otherwise test the timeout timer and see if it's
126        // fired yet and if so we've timed out and return an error.
127        if me.reset_sleep {
128            me.timeout
129                .as_mut()
130                .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
131            me.reset_sleep = false;
132        }
133
134        // Register interest in this context on the sleep timer, and if the
135        // sleep elapsed that means that we've timed out.
136        if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
137            return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
138        }
139
140        // Without timeout business now handled check for the frame. If a frame
141        // arrives then the sleep timer will be reset on the next frame.
142        let result = Pin::new(&mut me.inner).poll_frame(cx);
143        me.reset_sleep = result.is_ready();
144        result
145    }
146}
147
/// Message sent from a `HostIncomingBodyStream` to the `HostFutureTrailers`
/// state when the stream is done.
///
/// If the stream instead reaches the end of the body or hits a read error,
/// it sends nothing and simply drops its sender.
#[derive(Debug)]
enum StreamEnd {
    /// The body wasn't completely read and was dropped early. May still have
    /// trailers, but requires reading more frames.
    Remaining(BodyWithTimeout),

    /// Body was completely read and trailers were read. Here are the trailers.
    /// Note that `None` means that the body finished without trailers.
    Trailers(Option<FieldMap>),
}
160
/// The concrete type behind the `wasi:io/streams/input-stream` resource returned
/// by `wasi:http/types/incoming-body`'s `stream` method.
#[derive(Debug)]
pub struct HostIncomingBodyStream {
    /// Whether the underlying body is still open for reading.
    state: IncomingBodyStreamState,
    /// Data from the most recent data frame that `read` hasn't returned yet.
    buffer: Bytes,
    /// Error received from the body. `read` returns it once `buffer` has
    /// been drained.
    error: Option<anyhow::Error>,
}
169
170impl HostIncomingBodyStream {
171    fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
172        match frame {
173            Some(Ok(frame)) => match frame.into_data() {
174                // A data frame was received, so queue up the buffered data for
175                // the next `read` call.
176                Ok(bytes) => {
177                    assert!(self.buffer.is_empty());
178                    self.buffer = bytes;
179                }
180
181                // Trailers were received meaning that this was the final frame.
182                // Throw away the body and send the trailers along the
183                // `tx` channel to make them available.
184                Err(trailers) => {
185                    let trailers = trailers.into_trailers().unwrap();
186                    let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
187                        IncomingBodyStreamState::Open { body: _, tx } => tx,
188                        IncomingBodyStreamState::Closed => unreachable!(),
189                    };
190
191                    // NB: ignore send failures here because if this fails then
192                    // no one was interested in the trailers.
193                    let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
194                }
195            },
196
197            // An error was received meaning that the stream is now done.
198            // Destroy the body to terminate the stream while enqueueing the
199            // error to get returned from the next call to `read`.
200            Some(Err(e)) => {
201                self.error = Some(e.into());
202                self.state = IncomingBodyStreamState::Closed;
203            }
204
205            // No more frames are going to be received again, so drop the `body`
206            // and the `tx` channel we'd send the body back onto because it's
207            // not needed as frames are done.
208            None => {
209                self.state = IncomingBodyStreamState::Closed;
210            }
211        }
212    }
213}
214
/// Whether a [`HostIncomingBodyStream`] still owns its body.
#[derive(Debug)]
enum IncomingBodyStreamState {
    /// The body is currently open for reading and present here.
    ///
    /// When trailers are read, the trailers are sent along `tx`. When the
    /// stream is dropped while still in this state, the body itself is sent
    /// along `tx`.
    ///
    /// This state is transitioned to `Closed` when an error happens, EOF
    /// happens, or when trailers are read.
    Open {
        body: BodyWithTimeout,
        tx: oneshot::Sender<StreamEnd>,
    },

    /// This body is closed and no longer available for reading, no more data
    /// will come.
    Closed,
}
233
234#[async_trait::async_trait]
235impl InputStream for HostIncomingBodyStream {
236    fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
237        loop {
238            // Handle buffered data/errors if any
239            if !self.buffer.is_empty() {
240                let len = size.min(self.buffer.len());
241                let chunk = self.buffer.split_to(len);
242                return Ok(chunk);
243            }
244
245            if let Some(e) = self.error.take() {
246                return Err(StreamError::LastOperationFailed(e));
247            }
248
249            // Extract the body that we're reading from. If present perform a
250            // non-blocking poll to see if a frame is already here. If it is
251            // then turn the loop again to operate on the results. If it's not
252            // here then return an empty buffer as no data is available at this
253            // time.
254            let body = match &mut self.state {
255                IncomingBodyStreamState::Open { body, .. } => body,
256                IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
257            };
258
259            let future = body.frame();
260            futures::pin_mut!(future);
261            match poll_noop(future) {
262                Some(result) => {
263                    self.record_frame(result);
264                }
265                None => return Ok(Bytes::new()),
266            }
267        }
268    }
269}
270
271#[async_trait::async_trait]
272impl Pollable for HostIncomingBodyStream {
273    async fn ready(&mut self) {
274        if !self.buffer.is_empty() || self.error.is_some() {
275            return;
276        }
277
278        if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
279            let frame = body.frame().await;
280            self.record_frame(frame);
281        }
282    }
283}
284
285impl Drop for HostIncomingBodyStream {
286    fn drop(&mut self) {
287        // When a body stream is dropped, for whatever reason, attempt to send
288        // the body back to the `tx` which will provide the trailers if desired.
289        // This isn't necessary if the state is already closed. Additionally,
290        // like `record_frame` above, `send` errors are ignored as they indicate
291        // that the body/trailers aren't actually needed.
292        let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
293        if let IncomingBodyStreamState::Open { body, tx } = prev {
294            let _ = tx.send(StreamEnd::Remaining(body));
295        }
296    }
297}
298
/// The concrete type behind a `wasi:http/types/future-trailers` resource.
#[derive(Debug)]
pub enum HostFutureTrailers {
    /// Trailers aren't here yet.
    ///
    /// This state represents two similar states:
    ///
    /// * The body is here and ready for reading and we're waiting to read
    ///   trailers. This can happen for example when the actual body wasn't read
    ///   or if the body was only partially read.
    ///
    /// * The body is being read by something else and we're waiting for that to
    ///   send us the trailers (or the body itself). This state will get entered
    ///   when the body stream is dropped for example. If the body stream reads
    ///   the trailers itself it will also send a message over here with the
    ///   trailers.
    Waiting(HostIncomingBody),

    /// Trailers are ready and here they are.
    ///
    /// Note that `Ok(None)` means that there were no trailers for this request
    /// while `Ok(Some(_))` means that trailers were found in the request.
    /// `Err(_)` means reading the remaining body frames failed.
    Done(Result<Option<FieldMap>, types::ErrorCode>),

    /// Trailers have been consumed by `future-trailers.get`.
    Consumed,
}
326
#[async_trait::async_trait]
impl Pollable for HostFutureTrailers {
    /// Drive this future until the trailers are resolved, leaving `self` in
    /// `HostFutureTrailers::Done`.
    ///
    /// Returns immediately for `Done` and `Consumed`. Otherwise it first waits
    /// for any [`HostIncomingBodyStream`] that currently owns the body. If that
    /// stream handed back an unfinished body, it then reads (and discards) the
    /// remaining data frames until a trailers frame, the end of the body, or
    /// an error.
    async fn ready(&mut self) {
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };

        // If the body is itself being read by a body stream then we need to
        // wait for that to be done. `rx` is awaited through `&mut`, so the
        // receiver stays in place if this future is dropped mid-wait.
        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
            match rx.await {
                // Trailers were read for us and here they are, so store the
                // result.
                Ok(StreamEnd::Trailers(t)) => *self = Self::Done(Ok(t)),

                // The body wasn't fully read and was dropped before trailers
                // were reached. It's up to us now to complete the body.
                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),

                // This means there were no trailers present.
                //
                // NOTE(review): the stream also drops its sender without a
                // message when it hits a read error (`record_frame` closes the
                // stream on `Some(Err(_))`). Such an error therefore surfaces
                // here as "no trailers" rather than as `Done(Err(_))`; confirm
                // this is intended.
                Err(_) => {
                    *self = HostFutureTrailers::Done(Ok(None));
                }
            }
        }

        // Here it should be guaranteed that `InBodyStream` is now gone, so if
        // we have the body ourselves then read frames until trailers are found.
        // Every arm above either resolved `self` to `Done` or restored `Start`.
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };
        let hyper_body = match &mut body.body {
            IncomingBodyState::Start(body) => body,
            IncomingBodyState::InBodyStream(_) => unreachable!(),
        };
        let result = loop {
            match hyper_body.frame().await {
                None => break Ok(None),
                Some(Err(e)) => break Err(e),
                Some(Ok(frame)) => {
                    // If this frame is a data frame ignore it as we're only
                    // interested in trailers.
                    if let Ok(headers) = frame.into_trailers() {
                        break Ok(Some(headers));
                    }
                }
            }
        };
        *self = HostFutureTrailers::Done(result);
    }
}
382
/// Running count of bytes written to a body whose size was declared up front.
///
/// Clones share the same counter because it lives behind an `Arc`.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain in total.
    expected: u64,
    /// Bytes written so far, shared between all clones.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Start a counter at zero for a body of `expected_size` bytes.
    fn new(expected_size: u64) -> Self {
        let counter = std::sync::atomic::AtomicU64::new(0);
        WrittenState {
            expected: expected_size,
            written: Arc::new(counter),
        }
    }

    /// Total number of bytes recorded so far.
    fn written(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.written.load(Ordering::Relaxed)
    }

    /// Record `len` more bytes.
    ///
    /// The counter is always incremented. The return value says whether the
    /// new total is still within the expected size (`false` once it is
    /// exceeded).
    fn update(&self, len: usize) -> bool {
        use std::sync::atomic::Ordering;
        let len = len as u64;
        let before = self.written.fetch_add(len, Ordering::Relaxed);
        let after = before + len;
        after <= self.expected
    }
}
412
/// The concrete type behind a `wasi:http/types/outgoing-body` resource.
pub struct HostOutgoingBody {
    /// The output stream that the body is written to.
    ///
    /// `None` once it has been handed out by `take_output_stream`.
    body_output_stream: Option<Box<dyn OutputStream>>,
    /// Whether this is a request or a response body; selects the body-size
    /// error variant.
    context: StreamContext,
    /// Byte counter shared with the output stream. Present only when the
    /// body size was declared up front.
    written: Option<WrittenState>,
    /// Channel used by `finish`/`abort` to tell the hyper side of the body
    /// how it ends.
    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
}
421
422impl HostOutgoingBody {
423    /// Create a new `HostOutgoingBody`
424    pub fn new(
425        context: StreamContext,
426        size: Option<u64>,
427        buffer_chunks: usize,
428        chunk_size: usize,
429    ) -> (Self, HyperOutgoingBody) {
430        assert!(buffer_chunks >= 1);
431
432        let written = size.map(WrittenState::new);
433
434        use tokio::sync::oneshot::error::RecvError;
435        struct BodyImpl {
436            body_receiver: mpsc::Receiver<Bytes>,
437            finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
438        }
439        impl Body for BodyImpl {
440            type Data = Bytes;
441            type Error = types::ErrorCode;
442            fn poll_frame(
443                mut self: Pin<&mut Self>,
444                cx: &mut Context<'_>,
445            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
446                match self.as_mut().body_receiver.poll_recv(cx) {
447                    Poll::Pending => Poll::Pending,
448                    Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
449
450                    // This means that the `body_sender` end of the channel has been dropped.
451                    Poll::Ready(None) => {
452                        if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
453                            match Pin::new(&mut finish_receiver).poll(cx) {
454                                Poll::Pending => {
455                                    self.as_mut().finish_receiver = Some(finish_receiver);
456                                    Poll::Pending
457                                }
458                                Poll::Ready(Ok(message)) => match message {
459                                    FinishMessage::Finished => Poll::Ready(None),
460                                    FinishMessage::Trailers(trailers) => {
461                                        Poll::Ready(Some(Ok(Frame::trailers(trailers))))
462                                    }
463                                    FinishMessage::Abort => {
464                                        Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
465                                    }
466                                },
467                                Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
468                            }
469                        } else {
470                            Poll::Ready(None)
471                        }
472                    }
473                }
474            }
475        }
476
477        // always add 1 buffer here because one empty slot is required
478        let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
479        let (finish_sender, finish_receiver) = oneshot::channel();
480        let body_impl = BodyImpl {
481            body_receiver,
482            finish_receiver: Some(finish_receiver),
483        }
484        .boxed();
485
486        let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
487
488        (
489            Self {
490                body_output_stream: Some(Box::new(output_stream)),
491                context,
492                written,
493                finish_sender: Some(finish_sender),
494            },
495            body_impl,
496        )
497    }
498
499    /// Take the output stream, if it's available.
500    pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
501        self.body_output_stream.take()
502    }
503
504    /// Finish the body, optionally with trailers.
505    pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
506        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
507        // will immediately pick up the finish sender.
508        drop(self.body_output_stream);
509
510        let sender = self
511            .finish_sender
512            .take()
513            .expect("outgoing-body trailer_sender consumed by a non-owning function");
514
515        if let Some(w) = self.written {
516            let written = w.written();
517            if written != w.expected {
518                let _ = sender.send(FinishMessage::Abort);
519                return Err(self.context.as_body_size_error(written));
520            }
521        }
522
523        let message = if let Some(ts) = trailers {
524            FinishMessage::Trailers(ts)
525        } else {
526            FinishMessage::Finished
527        };
528
529        // Ignoring failure: receiver died sending body, but we can't report that here.
530        let _ = sender.send(message.into());
531
532        Ok(())
533    }
534
535    /// Abort the body.
536    pub fn abort(mut self) {
537        // Make sure that the output stream has been dropped, so that the BodyImpl poll function
538        // will immediately pick up the finish sender.
539        drop(self.body_output_stream);
540
541        let sender = self
542            .finish_sender
543            .take()
544            .expect("outgoing-body trailer_sender consumed by a non-owning function");
545
546        let _ = sender.send(FinishMessage::Abort);
547    }
548}
549
/// Message sent to end the [`HostOutgoingBody`] stream.
#[derive(Debug)]
enum FinishMessage {
    /// End the body normally, without trailers.
    Finished,
    /// End the body with a final trailers frame.
    Trailers(hyper::HeaderMap),
    /// End the body with an error (`HttpProtocolError` on the hyper side).
    Abort,
}
557
558/// Whether the body is a request or response body.
559#[derive(Clone, Copy, Debug, Eq, PartialEq)]
560pub enum StreamContext {
561    /// The body is a request body.
562    Request,
563    /// The body is a response body.
564    Response,
565}
566
567impl StreamContext {
568    /// Construct the correct [`types::ErrorCode`] body size error.
569    pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
570        match self {
571            StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
572            StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
573        }
574    }
575}
576
/// Provides an [`OutputStream`] impl from a [`tokio::sync::mpsc::Sender`].
#[derive(Debug)]
struct BodyWriteStream {
    /// Request or response; selects the body-size error variant.
    context: StreamContext,
    /// Sending half of the channel whose receiver feeds the hyper body.
    writer: mpsc::Sender<Bytes>,
    /// Number of bytes `check_write` permits whenever the channel has room.
    write_budget: usize,
    /// Shared byte counter. Present only when the body size is known.
    written: Option<WrittenState>,
}
585
586impl BodyWriteStream {
587    /// Create a [`BodyWriteStream`].
588    fn new(
589        context: StreamContext,
590        write_budget: usize,
591        writer: mpsc::Sender<Bytes>,
592        written: Option<WrittenState>,
593    ) -> Self {
594        // at least one capacity is required to send a message
595        assert!(writer.max_capacity() >= 1);
596        BodyWriteStream {
597            context,
598            writer,
599            write_budget,
600            written,
601        }
602    }
603}
604
605#[async_trait::async_trait]
606impl OutputStream for BodyWriteStream {
607    fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
608        let len = bytes.len();
609        match self.writer.try_send(bytes) {
610            // If the message was sent then it's queued up now in hyper to get
611            // received.
612            Ok(()) => {
613                if let Some(written) = self.written.as_ref() {
614                    if !written.update(len) {
615                        let total = written.written();
616                        return Err(StreamError::LastOperationFailed(anyhow!(self
617                            .context
618                            .as_body_size_error(total))));
619                    }
620                }
621
622                Ok(())
623            }
624
625            // If this channel is full then that means `check_write` wasn't
626            // called. The call to `check_write` always guarantees that there's
627            // at least one capacity if a write is allowed.
628            Err(mpsc::error::TrySendError::Full(_)) => {
629                Err(StreamError::Trap(anyhow!("write exceeded budget")))
630            }
631
632            // Hyper is gone so this stream is now closed.
633            Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
634        }
635    }
636
637    fn flush(&mut self) -> Result<(), StreamError> {
638        // Flushing doesn't happen in this body stream since we're currently
639        // only tracking sending bytes over to hyper.
640        if self.writer.is_closed() {
641            Err(StreamError::Closed)
642        } else {
643            Ok(())
644        }
645    }
646
647    fn check_write(&mut self) -> Result<usize, StreamError> {
648        if self.writer.is_closed() {
649            Err(StreamError::Closed)
650        } else if self.writer.capacity() == 0 {
651            // If there is no more capacity in this sender channel then don't
652            // allow any more writes because the hyper task needs to catch up
653            // now.
654            //
655            // Note that this relies on this task being the only one sending
656            // data to ensure that no one else can steal a write into this
657            // channel.
658            Ok(0)
659        } else {
660            Ok(self.write_budget)
661        }
662    }
663}
664
665#[async_trait::async_trait]
666impl Pollable for BodyWriteStream {
667    async fn ready(&mut self) {
668        // Attempt to perform a reservation for a send. If there's capacity in
669        // the channel or it's already closed then this will return immediately.
670        // If the channel is full this will block until capacity opens up.
671        let _ = self.writer.reserve().await;
672    }
673}