1use crate::{bindings::http::types, types::FieldMap};
4use anyhow::anyhow;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::combinators::BoxBody;
8use http_body_util::BodyExt;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
15use wasmtime_wasi::runtime::{poll_noop, AbortOnDropJoinHandle};
16
/// Body type used for incoming HTTP messages: a boxed body yielding `Bytes`
/// frames and reporting failures as `types::ErrorCode`.
pub type HyperIncomingBody = BoxBody<Bytes, types::ErrorCode>;

/// Body type used for outgoing HTTP messages; currently the same boxed body
/// type as [`HyperIncomingBody`].
pub type HyperOutgoingBody = BoxBody<Bytes, types::ErrorCode>;
22
/// Host-side state behind an incoming HTTP body.
///
/// The body is either held here directly or lent out to a
/// [`HostIncomingBodyStream`] via [`HostIncomingBody::take_stream`].
#[derive(Debug)]
pub struct HostIncomingBody {
    /// Where the body currently lives (here, or inside a body stream).
    body: IncomingBodyState,
    /// Optional worker task handle owned by this body, stored by
    /// [`HostIncomingBody::retain_worker`] and dropped together with it.
    worker: Option<AbortOnDropJoinHandle<()>>,
}
32
33impl HostIncomingBody {
34 pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
36 let body = BodyWithTimeout::new(body, between_bytes_timeout);
37 HostIncomingBody {
38 body: IncomingBodyState::Start(body),
39 worker: None,
40 }
41 }
42
43 pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
45 assert!(self.worker.is_none());
46 self.worker = Some(worker);
47 }
48
49 pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
51 match &mut self.body {
52 IncomingBodyState::Start(_) => {}
53 IncomingBodyState::InBodyStream(_) => return None,
54 }
55 let (tx, rx) = oneshot::channel();
56 let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
57 IncomingBodyState::Start(b) => b,
58 IncomingBodyState::InBodyStream(_) => unreachable!(),
59 };
60 Some(HostIncomingBodyStream {
61 state: IncomingBodyStreamState::Open { body, tx },
62 buffer: Bytes::new(),
63 error: None,
64 })
65 }
66
67 pub fn into_future_trailers(self) -> HostFutureTrailers {
69 HostFutureTrailers::Waiting(self)
70 }
71}
72
/// Where the body of a [`HostIncomingBody`] currently lives.
#[derive(Debug)]
enum IncomingBodyState {
    /// The body is owned here; `take_stream` can still hand it out, and the
    /// trailers future reads it directly from this state.
    Start(BodyWithTimeout),

    /// The body has been moved into a `HostIncomingBodyStream`. That stream
    /// reports back over this channel with either the unread remainder or
    /// the trailers. If the channel closes without a message, the trailers
    /// future resolves to "no trailers".
    InBodyStream(oneshot::Receiver<StreamEnd>),
}
85
/// Wrapper around a [`HyperIncomingBody`] that fails a frame request with
/// `ConnectionReadTimeout` if the frame takes longer than
/// `between_bytes_timeout` to arrive.
#[derive(Debug)]
struct BodyWithTimeout {
    /// The wrapped body that frames are read from.
    inner: HyperIncomingBody,
    /// Timer for the frame currently being waited on; re-armed per frame.
    timeout: Pin<Box<tokio::time::Sleep>>,
    /// Whether `timeout` must be re-armed on the next `poll_frame` call,
    /// i.e. whether that call starts waiting for a new frame.
    reset_sleep: bool,
    /// Maximum time allowed between requesting a frame and receiving it.
    between_bytes_timeout: Duration,
}
100
101impl BodyWithTimeout {
102 fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
103 BodyWithTimeout {
104 inner,
105 between_bytes_timeout,
106 reset_sleep: true,
107 timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
108 tokio::time::sleep(Duration::new(0, 0))
109 })),
110 }
111 }
112}
113
114impl Body for BodyWithTimeout {
115 type Data = Bytes;
116 type Error = types::ErrorCode;
117
118 fn poll_frame(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
122 let me = Pin::into_inner(self);
123
124 if me.reset_sleep {
128 me.timeout
129 .as_mut()
130 .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
131 me.reset_sleep = false;
132 }
133
134 if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
137 return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
138 }
139
140 let result = Pin::new(&mut me.inner).poll_frame(cx);
143 me.reset_sleep = result.is_ready();
144 result
145 }
146}
147
/// Message a [`HostIncomingBodyStream`] sends back to the owning
/// [`HostIncomingBody`] when it is done with the body.
#[derive(Debug)]
enum StreamEnd {
    /// The stream was dropped before the body ended. The remainder of the
    /// body, which may still contain trailers, has to be read by the receiver.
    Remaining(BodyWithTimeout),

    /// The stream read up to the trailers frame. The value is stored as-is
    /// as the trailers result (`None` meaning no trailers).
    Trailers(Option<FieldMap>),
}
160
/// Input stream over the data frames of an incoming body, created by
/// [`HostIncomingBody::take_stream`].
#[derive(Debug)]
pub struct HostIncomingBodyStream {
    /// Whether the body is still being read, plus the body and return channel.
    state: IncomingBodyStreamState,
    /// Data from the most recent data frame not yet returned by `read`.
    buffer: Bytes,
    /// Body error to be returned by `read` once `buffer` is drained.
    error: Option<anyhow::Error>,
}
169
170impl HostIncomingBodyStream {
171 fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
172 match frame {
173 Some(Ok(frame)) => match frame.into_data() {
174 Ok(bytes) => {
177 assert!(self.buffer.is_empty());
178 self.buffer = bytes;
179 }
180
181 Err(trailers) => {
185 let trailers = trailers.into_trailers().unwrap();
186 let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
187 IncomingBodyStreamState::Open { body: _, tx } => tx,
188 IncomingBodyStreamState::Closed => unreachable!(),
189 };
190
191 let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
194 }
195 },
196
197 Some(Err(e)) => {
201 self.error = Some(e.into());
202 self.state = IncomingBodyStreamState::Closed;
203 }
204
205 None => {
209 self.state = IncomingBodyStreamState::Closed;
210 }
211 }
212 }
213}
214
/// Read state of a [`HostIncomingBodyStream`].
#[derive(Debug)]
enum IncomingBodyStreamState {
    /// The body is still being read.
    Open {
        /// The body frames are read from.
        body: BodyWithTimeout,
        /// Channel back to the owning `HostIncomingBody`, used to send the
        /// trailers or, when the stream is dropped early, the unread body.
        tx: oneshot::Sender<StreamEnd>,
    },

    /// No more frames will be read: the body ended, errored, produced its
    /// trailers, or the stream is being dropped.
    Closed,
}
233
234#[async_trait::async_trait]
235impl InputStream for HostIncomingBodyStream {
236 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
237 loop {
238 if !self.buffer.is_empty() {
240 let len = size.min(self.buffer.len());
241 let chunk = self.buffer.split_to(len);
242 return Ok(chunk);
243 }
244
245 if let Some(e) = self.error.take() {
246 return Err(StreamError::LastOperationFailed(e));
247 }
248
249 let body = match &mut self.state {
255 IncomingBodyStreamState::Open { body, .. } => body,
256 IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
257 };
258
259 let future = body.frame();
260 futures::pin_mut!(future);
261 match poll_noop(future) {
262 Some(result) => {
263 self.record_frame(result);
264 }
265 None => return Ok(Bytes::new()),
266 }
267 }
268 }
269}
270
271#[async_trait::async_trait]
272impl Pollable for HostIncomingBodyStream {
273 async fn ready(&mut self) {
274 if !self.buffer.is_empty() || self.error.is_some() {
275 return;
276 }
277
278 if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
279 let frame = body.frame().await;
280 self.record_frame(frame);
281 }
282 }
283}
284
285impl Drop for HostIncomingBodyStream {
286 fn drop(&mut self) {
287 let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
293 if let IncomingBodyStreamState::Open { body, tx } = prev {
294 let _ = tx.send(StreamEnd::Remaining(body));
295 }
296 }
297}
298
/// Host-side state for waiting on the trailers of an incoming body.
#[derive(Debug)]
pub enum HostFutureTrailers {
    /// Trailers are not known yet. The body is either held directly, in which
    /// case `ready` reads it to the end, or lent to a body stream, whose
    /// report is awaited first.
    Waiting(HostIncomingBody),

    /// Trailers are resolved.
    ///
    /// * `Ok(None)`: the body ended without trailers.
    /// * `Ok(Some(_))`: holds the trailers.
    /// * `Err(_)`: reading the body failed.
    Done(Result<Option<FieldMap>, types::ErrorCode>),

    /// NOTE(review): not constructed in this file; presumably set once the
    /// `Done` result has been taken by a caller elsewhere. Confirm.
    Consumed,
}
326
#[async_trait::async_trait]
impl Pollable for HostFutureTrailers {
    /// Drives the trailers to completion.
    ///
    /// On return, `self` is `Done`, unless it was already `Done` or
    /// `Consumed` (those are left untouched).
    async fn ready(&mut self) {
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };

        // If a body stream currently owns the body, wait for it to report back.
        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
            match rx.await {
                // The stream reached the trailers itself.
                Ok(StreamEnd::Trailers(t)) => *self = Self::Done(Ok(t)),

                // The stream was dropped early: take the body back and finish
                // reading it below.
                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),

                // The stream dropped its sender without a message (the body
                // ended or errored while being streamed): report no trailers.
                Err(_) => {
                    *self = HostFutureTrailers::Done(Ok(None));
                }
            }
        }

        // Re-inspect `self`, since it may have just become `Done` above.
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };
        let hyper_body = match &mut body.body {
            IncomingBodyState::Start(body) => body,
            // Any `InBodyStream` state was resolved by the block above.
            IncomingBodyState::InBodyStream(_) => unreachable!(),
        };
        // Read the remaining frames, discarding data frames, until the
        // trailers, the end of the body, or an error.
        let result = loop {
            match hyper_body.frame().await {
                None => break Ok(None),
                Some(Err(e)) => break Err(e),
                Some(Ok(frame)) => {
                    if let Ok(headers) = frame.into_trailers() {
                        break Ok(Some(headers));
                    }
                }
            }
        };
        *self = HostFutureTrailers::Done(result);
    }
}
382
/// Byte counter for an outgoing body whose total size was declared up front
/// (the `size` passed to `HostOutgoingBody::new`).
///
/// Clones share one counter through an `Arc`, so every holder observes the
/// same running total.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain.
    expected: u64,
    /// Running total of bytes recorded so far; shared between clones.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Creates a counter that starts at zero and expects `expected_size` bytes.
    fn new(expected_size: u64) -> Self {
        let counter = std::sync::atomic::AtomicU64::new(0);
        WrittenState {
            expected: expected_size,
            written: Arc::new(counter),
        }
    }

    /// Returns the number of bytes recorded so far.
    fn written(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.written.load(Ordering::Relaxed)
    }

    /// Adds `len` bytes to the running total.
    ///
    /// Returns `true` while the new total is within `expected` and `false`
    /// once it exceeds it. The counter is incremented in both cases.
    fn update(&self, len: usize) -> bool {
        use std::sync::atomic::Ordering;
        let added = len as u64;
        let previous = self.written.fetch_add(added, Ordering::Relaxed);
        let total = previous + added;
        total <= self.expected
    }
}
412
/// Host-side state behind an outgoing HTTP body, created by
/// [`HostOutgoingBody::new`].
pub struct HostOutgoingBody {
    /// Output stream writing into the body, until taken by `take_output_stream`.
    body_output_stream: Option<Box<dyn OutputStream>>,
    /// Whether this is a request or a response body.
    context: StreamContext,
    /// Shared byte counter, present when the body size was declared up front.
    written: Option<WrittenState>,
    /// Channel used by `finish`/`abort` to tell the hyper body how it ends.
    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
}
421
422impl HostOutgoingBody {
423 pub fn new(
425 context: StreamContext,
426 size: Option<u64>,
427 buffer_chunks: usize,
428 chunk_size: usize,
429 ) -> (Self, HyperOutgoingBody) {
430 assert!(buffer_chunks >= 1);
431
432 let written = size.map(WrittenState::new);
433
434 use tokio::sync::oneshot::error::RecvError;
435 struct BodyImpl {
436 body_receiver: mpsc::Receiver<Bytes>,
437 finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
438 }
439 impl Body for BodyImpl {
440 type Data = Bytes;
441 type Error = types::ErrorCode;
442 fn poll_frame(
443 mut self: Pin<&mut Self>,
444 cx: &mut Context<'_>,
445 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
446 match self.as_mut().body_receiver.poll_recv(cx) {
447 Poll::Pending => Poll::Pending,
448 Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),
449
450 Poll::Ready(None) => {
452 if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
453 match Pin::new(&mut finish_receiver).poll(cx) {
454 Poll::Pending => {
455 self.as_mut().finish_receiver = Some(finish_receiver);
456 Poll::Pending
457 }
458 Poll::Ready(Ok(message)) => match message {
459 FinishMessage::Finished => Poll::Ready(None),
460 FinishMessage::Trailers(trailers) => {
461 Poll::Ready(Some(Ok(Frame::trailers(trailers))))
462 }
463 FinishMessage::Abort => {
464 Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
465 }
466 },
467 Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
468 }
469 } else {
470 Poll::Ready(None)
471 }
472 }
473 }
474 }
475 }
476
477 let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
479 let (finish_sender, finish_receiver) = oneshot::channel();
480 let body_impl = BodyImpl {
481 body_receiver,
482 finish_receiver: Some(finish_receiver),
483 }
484 .boxed();
485
486 let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());
487
488 (
489 Self {
490 body_output_stream: Some(Box::new(output_stream)),
491 context,
492 written,
493 finish_sender: Some(finish_sender),
494 },
495 body_impl,
496 )
497 }
498
499 pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
501 self.body_output_stream.take()
502 }
503
504 pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
506 drop(self.body_output_stream);
509
510 let sender = self
511 .finish_sender
512 .take()
513 .expect("outgoing-body trailer_sender consumed by a non-owning function");
514
515 if let Some(w) = self.written {
516 let written = w.written();
517 if written != w.expected {
518 let _ = sender.send(FinishMessage::Abort);
519 return Err(self.context.as_body_size_error(written));
520 }
521 }
522
523 let message = if let Some(ts) = trailers {
524 FinishMessage::Trailers(ts)
525 } else {
526 FinishMessage::Finished
527 };
528
529 let _ = sender.send(message.into());
531
532 Ok(())
533 }
534
535 pub fn abort(mut self) {
537 drop(self.body_output_stream);
540
541 let sender = self
542 .finish_sender
543 .take()
544 .expect("outgoing-body trailer_sender consumed by a non-owning function");
545
546 let _ = sender.send(FinishMessage::Abort);
547 }
548}
549
/// How an outgoing body ends. Sent by `finish`/`abort` and read by the hyper
/// body after all data has been received.
#[derive(Debug)]
enum FinishMessage {
    /// End normally without trailers.
    Finished,
    /// End with this trailers frame.
    Trailers(hyper::HeaderMap),
    /// End with an `HttpProtocolError`.
    Abort,
}
557
/// Which kind of HTTP message a body belongs to. This selects the error
/// reported for body-size mismatches.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamContext {
    /// The body of a request.
    Request,
    /// The body of a response.
    Response,
}
566
567impl StreamContext {
568 pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
570 match self {
571 StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
572 StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
573 }
574 }
575}
576
/// Output stream that forwards written chunks to an outgoing hyper body over
/// a bounded channel.
#[derive(Debug)]
struct BodyWriteStream {
    /// Request or response; selects the size-mismatch error.
    context: StreamContext,
    /// Sending half of the data channel read by the hyper body.
    writer: mpsc::Sender<Bytes>,
    /// Bytes reported as writable by `check_write` when the channel has capacity.
    write_budget: usize,
    /// Shared byte counter, present when the body size was declared up front.
    written: Option<WrittenState>,
}
585
586impl BodyWriteStream {
587 fn new(
589 context: StreamContext,
590 write_budget: usize,
591 writer: mpsc::Sender<Bytes>,
592 written: Option<WrittenState>,
593 ) -> Self {
594 assert!(writer.max_capacity() >= 1);
596 BodyWriteStream {
597 context,
598 writer,
599 write_budget,
600 written,
601 }
602 }
603}
604
605#[async_trait::async_trait]
606impl OutputStream for BodyWriteStream {
607 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
608 let len = bytes.len();
609 match self.writer.try_send(bytes) {
610 Ok(()) => {
613 if let Some(written) = self.written.as_ref() {
614 if !written.update(len) {
615 let total = written.written();
616 return Err(StreamError::LastOperationFailed(anyhow!(self
617 .context
618 .as_body_size_error(total))));
619 }
620 }
621
622 Ok(())
623 }
624
625 Err(mpsc::error::TrySendError::Full(_)) => {
629 Err(StreamError::Trap(anyhow!("write exceeded budget")))
630 }
631
632 Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
634 }
635 }
636
637 fn flush(&mut self) -> Result<(), StreamError> {
638 if self.writer.is_closed() {
641 Err(StreamError::Closed)
642 } else {
643 Ok(())
644 }
645 }
646
647 fn check_write(&mut self) -> Result<usize, StreamError> {
648 if self.writer.is_closed() {
649 Err(StreamError::Closed)
650 } else if self.writer.capacity() == 0 {
651 Ok(0)
659 } else {
660 Ok(self.write_budget)
661 }
662 }
663}
664
665#[async_trait::async_trait]
666impl Pollable for BodyWriteStream {
667 async fn ready(&mut self) {
668 let _ = self.writer.reserve().await;
672 }
673}