1use crate::FieldMap;
4use crate::p2::bindings::http::types;
5use bytes::Bytes;
6use http_body::{Body, Frame};
7use http_body_util::BodyExt;
8use http_body_util::combinators::UnsyncBoxBody;
9use std::future::Future;
10use std::mem;
11use std::task::{Context, Poll};
12use std::{pin::Pin, sync::Arc, time::Duration};
13use tokio::sync::{mpsc, oneshot};
14use wasmtime::format_err;
15use wasmtime_wasi::p2::{InputStream, OutputStream, Pollable, StreamError};
16use wasmtime_wasi::runtime::{AbortOnDropJoinHandle, poll_noop};
17
/// Boxed, non-`Sync` HTTP body with `Bytes` data frames and
/// [`types::ErrorCode`] errors; this is the body type that
/// [`HostIncomingBody::new`] accepts.
pub type HyperIncomingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
20
/// Boxed, non-`Sync` HTTP body with `Bytes` data frames and
/// [`types::ErrorCode`] errors; [`HostOutgoingBody::new`] returns one of these
/// alongside the host-side handle that feeds it.
pub type HyperOutgoingBody = UnsyncBoxBody<Bytes, types::ErrorCode>;
23
/// Host-side state for an incoming (received) HTTP body.
///
/// The body can be lent out at most once to a [`HostIncomingBodyStream`] via
/// [`HostIncomingBody::take_stream`]. The whole value can be turned into a
/// [`HostFutureTrailers`] to wait for the body's trailers.
#[derive(Debug)]
pub struct HostIncomingBody {
    // Whether the body is still held here or has been lent to a body stream.
    body: IncomingBodyState,
    // Background task stored alongside this body and dropped together with it.
    // NOTE(review): the handle type's name suggests dropping it aborts the
    // task — confirm.
    worker: Option<AbortOnDropJoinHandle<()>>,
}
33
34impl HostIncomingBody {
35 pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody {
37 let body = BodyWithTimeout::new(body, between_bytes_timeout);
38 HostIncomingBody {
39 body: IncomingBodyState::Start(body),
40 worker: None,
41 }
42 }
43
44 pub fn retain_worker(&mut self, worker: AbortOnDropJoinHandle<()>) {
46 assert!(self.worker.is_none());
47 self.worker = Some(worker);
48 }
49
50 pub fn take_stream(&mut self) -> Option<HostIncomingBodyStream> {
52 match &mut self.body {
53 IncomingBodyState::Start(_) => {}
54 IncomingBodyState::InBodyStream(_) => return None,
55 }
56 let (tx, rx) = oneshot::channel();
57 let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) {
58 IncomingBodyState::Start(b) => b,
59 IncomingBodyState::InBodyStream(_) => unreachable!(),
60 };
61 Some(HostIncomingBodyStream {
62 state: IncomingBodyStreamState::Open { body, tx },
63 buffer: Bytes::new(),
64 error: None,
65 })
66 }
67
68 pub fn into_future_trailers(self) -> HostFutureTrailers {
70 HostFutureTrailers::Waiting(self)
71 }
72}
73
/// Where an incoming body currently lives.
#[derive(Debug)]
enum IncomingBodyState {
    /// The body has not been lent to a stream and is owned here.
    Start(BodyWithTimeout),

    /// The body was moved into a [`HostIncomingBodyStream`]. This receiver
    /// gets a [`StreamEnd`] when that stream reaches trailers or is dropped
    /// while still open. It sees its sender dropped when the stream instead
    /// ends at end-of-body or on an error.
    InBodyStream(oneshot::Receiver<StreamEnd>),
}
86
/// Incoming-body wrapper that yields `ErrorCode::ConnectionReadTimeout` when
/// the wrapped body does not produce its next frame within
/// `between_bytes_timeout` of the timer being (re)armed.
#[derive(Debug)]
struct BodyWithTimeout {
    /// The wrapped body.
    inner: HyperIncomingBody,
    /// Timer whose expiry means the read has timed out.
    timeout: Pin<Box<tokio::time::Sleep>>,
    /// When `true`, the next poll re-arms `timeout` relative to the current
    /// instant. It is set after every poll where `inner` returned `Ready`.
    reset_sleep: bool,
    /// Maximum wait allowed for each frame.
    between_bytes_timeout: Duration,
}
101
102impl BodyWithTimeout {
103 fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout {
104 BodyWithTimeout {
105 inner,
106 between_bytes_timeout,
107 reset_sleep: true,
108 timeout: Box::pin(wasmtime_wasi::runtime::with_ambient_tokio_runtime(|| {
109 tokio::time::sleep(Duration::new(0, 0))
110 })),
111 }
112 }
113}
114
115impl Body for BodyWithTimeout {
116 type Data = Bytes;
117 type Error = types::ErrorCode;
118
119 fn poll_frame(
120 self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 ) -> Poll<Option<Result<Frame<Bytes>, types::ErrorCode>>> {
123 let me = Pin::into_inner(self);
124
125 if me.reset_sleep {
129 me.timeout
130 .as_mut()
131 .reset(tokio::time::Instant::now() + me.between_bytes_timeout);
132 me.reset_sleep = false;
133 }
134
135 if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) {
138 return Poll::Ready(Some(Err(types::ErrorCode::ConnectionReadTimeout)));
139 }
140
141 let result = Pin::new(&mut me.inner).poll_frame(cx);
144 me.reset_sleep = result.is_ready();
145 result
146 }
147}
148
/// Message a [`HostIncomingBodyStream`] sends back to its owning
/// [`HostIncomingBody`] when it stops reading.
#[derive(Debug)]
enum StreamEnd {
    /// The stream was dropped before the body finished. The unread remainder
    /// is handed back so trailers can still be read from it.
    Remaining(BodyWithTimeout),

    /// The stream read through to the trailers frame. The code visible here
    /// only ever sends `Some`; `None` is still handled by the receiver.
    Trailers(Option<http::HeaderMap>),
}
161
/// An [`InputStream`] over the data frames of an incoming body.
#[derive(Debug)]
pub struct HostIncomingBodyStream {
    // Open (holding the body and the channel back to its owner) or closed.
    state: IncomingBodyStreamState,
    // Data from the most recent frame that `read` has not returned yet.
    buffer: Bytes,
    // Body error reported by the next `read` once `buffer` is drained.
    error: Option<wasmtime::Error>,
}
170
171impl HostIncomingBodyStream {
172 fn record_frame(&mut self, frame: Option<Result<Frame<Bytes>, types::ErrorCode>>) {
173 match frame {
174 Some(Ok(frame)) => match frame.into_data() {
175 Ok(bytes) => {
178 assert!(self.buffer.is_empty());
179 self.buffer = bytes;
180 }
181
182 Err(trailers) => {
186 let trailers = trailers.into_trailers().unwrap();
187 let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) {
188 IncomingBodyStreamState::Open { body: _, tx } => tx,
189 IncomingBodyStreamState::Closed => unreachable!(),
190 };
191
192 let _ = tx.send(StreamEnd::Trailers(Some(trailers)));
195 }
196 },
197
198 Some(Err(e)) => {
202 self.error = Some(e.into());
203 self.state = IncomingBodyStreamState::Closed;
204 }
205
206 None => {
210 self.state = IncomingBodyStreamState::Closed;
211 }
212 }
213 }
214}
215
/// Whether a [`HostIncomingBodyStream`] can still read frames.
#[derive(Debug)]
enum IncomingBodyStreamState {
    /// Still reading.
    Open {
        /// The body being read.
        body: BodyWithTimeout,
        /// Channel back to the owning `HostIncomingBody`. It carries the
        /// trailers when they are reached, or the unread body if the stream
        /// is dropped while open.
        tx: oneshot::Sender<StreamEnd>,
    },

    /// No more frames will be read: the body ended, errored, reached its
    /// trailers, or the stream is being dropped.
    Closed,
}
234
235#[async_trait::async_trait]
236impl InputStream for HostIncomingBodyStream {
237 fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
238 loop {
239 if !self.buffer.is_empty() {
241 let len = size.min(self.buffer.len());
242 let chunk = self.buffer.split_to(len);
243 return Ok(chunk);
244 }
245
246 if let Some(e) = self.error.take() {
247 return Err(StreamError::LastOperationFailed(e));
248 }
249
250 let body = match &mut self.state {
256 IncomingBodyStreamState::Open { body, .. } => body,
257 IncomingBodyStreamState::Closed => return Err(StreamError::Closed),
258 };
259
260 let future = body.frame();
261 futures::pin_mut!(future);
262 match poll_noop(future) {
263 Some(result) => {
264 self.record_frame(result);
265 }
266 None => return Ok(Bytes::new()),
267 }
268 }
269 }
270}
271
272#[async_trait::async_trait]
273impl Pollable for HostIncomingBodyStream {
274 async fn ready(&mut self) {
275 if !self.buffer.is_empty() || self.error.is_some() {
276 return;
277 }
278
279 if let IncomingBodyStreamState::Open { body, .. } = &mut self.state {
280 let frame = body.frame().await;
281 self.record_frame(frame);
282 }
283 }
284}
285
286impl Drop for HostIncomingBodyStream {
287 fn drop(&mut self) {
288 let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed);
294 if let IncomingBodyStreamState::Open { body, tx } = prev {
295 let _ = tx.send(StreamEnd::Remaining(body));
296 }
297 }
298}
299
/// Resolution state of an incoming body's trailers.
#[derive(Debug)]
pub enum HostFutureTrailers {
    /// Not resolved yet. Owns the body, which may currently be lent to a
    /// `HostIncomingBodyStream`.
    Waiting(HostIncomingBody),

    /// Resolved. Holds `Ok(Some(_))` with the trailers, `Ok(None)` when there
    /// were none, or the error that ended the body.
    Done(Result<Option<http::HeaderMap>, types::ErrorCode>),

    /// Presumably the `Done` result has already been taken by a consumer.
    /// NOTE(review): nothing visible here sets this state — confirm.
    Consumed,
}
327
#[async_trait::async_trait]
impl Pollable for HostFutureTrailers {
    /// Resolves the trailers, leaving `self` as `Done` unless it was already
    /// `Done` or `Consumed`.
    ///
    /// If the body is lent to a `HostIncomingBodyStream`, this first waits for
    /// that stream to report back. If the stream returned the unread body, the
    /// remaining frames are read here. Data frames are discarded, and the
    /// first trailers frame, end of body, or error becomes the result.
    async fn ready(&mut self) {
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };

        // The body is currently owned by a body stream; wait for it to finish.
        if let IncomingBodyState::InBodyStream(rx) = &mut body.body {
            match rx.await {
                // The stream read the trailers itself.
                Ok(StreamEnd::Trailers(Some(t))) => {
                    *self = Self::Done(Ok(Some(t)));
                }
                // The stream was dropped early and returned the rest of the
                // body, which is finished below.
                Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b),

                // Either explicitly no trailers, or the stream dropped its
                // sender without a message. The latter happens at end of body,
                // or on an error the stream keeps for its own `read`. Both
                // resolve to "no trailers".
                Ok(StreamEnd::Trailers(None)) | Err(_) => {
                    *self = HostFutureTrailers::Done(Ok(None));
                }
            }
        }

        // Re-check: the wait above may already have resolved `self`.
        let body = match self {
            HostFutureTrailers::Waiting(body) => body,
            HostFutureTrailers::Done(_) => return,
            HostFutureTrailers::Consumed => return,
        };
        // Any `InBodyStream` state was replaced above, so the body is here.
        let hyper_body = match &mut body.body {
            IncomingBodyState::Start(body) => body,
            IncomingBodyState::InBodyStream(_) => unreachable!(),
        };
        let result = loop {
            match hyper_body.frame().await {
                None => break Ok(None),
                Some(Err(e)) => break Err(e),
                Some(Ok(frame)) => {
                    // Data frames are skipped; only the trailers matter here.
                    if let Ok(header_map) = frame.into_trailers() {
                        break Ok(Some(header_map));
                    }
                }
            }
        };
        *self = HostFutureTrailers::Done(result);
    }
}
384
/// Tracks how many bytes of an outgoing body have been written against the
/// size declared for it.
///
/// Clones share one counter, so every holder observes the same total.
#[derive(Debug, Clone)]
struct WrittenState {
    /// Number of bytes the body is expected to contain.
    expected: u64,
    /// Running total of bytes written; saturates at `u64::MAX`.
    written: Arc<std::sync::atomic::AtomicU64>,
}

impl WrittenState {
    /// Creates a tracker expecting `expected_size` bytes, with nothing written.
    fn new(expected_size: u64) -> Self {
        Self {
            expected: expected_size,
            written: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Returns the number of bytes recorded so far.
    fn written(&self) -> u64 {
        self.written.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Adds `len` to the running total.
    ///
    /// Returns `false` if the new total exceeds the expected size.
    ///
    /// The arithmetic cannot overflow. The stored total saturates at
    /// `u64::MAX` instead of wrapping, because wrapping would let later writes
    /// pass the check again. A total that would overflow counts as exceeding
    /// the expected size instead of panicking in debug builds.
    fn update(&self, len: usize) -> bool {
        use std::sync::atomic::Ordering::Relaxed;

        // `usize` fits in 64 bits on every supported target; saturate regardless.
        let len = u64::try_from(len).unwrap_or(u64::MAX);
        let old = self
            .written
            .fetch_update(Relaxed, Relaxed, |total| Some(total.saturating_add(len)))
            // `Err` is impossible: the closure always returns `Some`.
            .unwrap_or_else(|old| old);
        old.checked_add(len)
            .is_some_and(|total| total <= self.expected)
    }
}
414
/// Host-side handle for an outgoing HTTP body.
pub struct HostOutgoingBody {
    /// Stream that body bytes are written to; `None` once taken.
    body_output_stream: Option<Box<dyn OutputStream>>,
    /// Request or response; selects which body-size error is reported.
    context: StreamContext,
    /// Expected-vs-written byte counts, present when a size was declared.
    written: Option<WrittenState>,
    /// Delivers the final `FinishMessage` to the hyper-side body; `None` once used.
    finish_sender: Option<tokio::sync::oneshot::Sender<FinishMessage>>,
}
423
impl HostOutgoingBody {
    /// Creates an outgoing body handle together with the hyper body that
    /// yields whatever is written to it.
    ///
    /// * `context`: request or response; selects the body-size error variant.
    /// * `size`: declared body length, if any. When set, writes that push the
    ///   total past it report an error, and [`HostOutgoingBody::finish`] fails
    ///   unless exactly this many bytes were written.
    /// * `buffer_chunks`: sizes the data channel, whose capacity is
    ///   `buffer_chunks + 1`.
    /// * `chunk_size`: the per-write budget reported by the output stream's
    ///   `check_write` while the channel has room.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_chunks` is 0.
    pub fn new(
        context: StreamContext,
        size: Option<u64>,
        buffer_chunks: usize,
        chunk_size: usize,
    ) -> (Self, HyperOutgoingBody) {
        assert!(buffer_chunks >= 1);

        // Shared byte counter, present only when a size was declared.
        let written = size.map(WrittenState::new);

        use tokio::sync::oneshot::error::RecvError;
        // The hyper-facing body. Data chunks arrive on `body_receiver`. Once
        // that channel is closed, the message on `finish_receiver` decides
        // how the body ends.
        struct BodyImpl {
            body_receiver: mpsc::Receiver<Bytes>,
            // `None` once the finish message has been received.
            finish_receiver: Option<oneshot::Receiver<FinishMessage>>,
        }
        impl Body for BodyImpl {
            type Data = Bytes;
            type Error = types::ErrorCode;
            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                match self.as_mut().body_receiver.poll_recv(cx) {
                    Poll::Pending => Poll::Pending,
                    Poll::Ready(Some(frame)) => Poll::Ready(Some(Ok(Frame::data(frame)))),

                    // Every data sender is gone and the channel is drained:
                    // end the body according to the finish message.
                    Poll::Ready(None) => {
                        if let Some(mut finish_receiver) = self.as_mut().finish_receiver.take() {
                            match Pin::new(&mut finish_receiver).poll(cx) {
                                // Not sent yet: put the receiver back and wait.
                                Poll::Pending => {
                                    self.as_mut().finish_receiver = Some(finish_receiver);
                                    Poll::Pending
                                }
                                Poll::Ready(Ok(message)) => match message {
                                    FinishMessage::Finished => Poll::Ready(None),
                                    FinishMessage::Trailers(trailers) => {
                                        Poll::Ready(Some(Ok(Frame::trailers(trailers))))
                                    }
                                    // An aborted body surfaces to hyper as a protocol error.
                                    FinishMessage::Abort => {
                                        Poll::Ready(Some(Err(types::ErrorCode::HttpProtocolError)))
                                    }
                                },
                                // The finish sender was dropped without a
                                // message; end the body normally.
                                Poll::Ready(Err(RecvError { .. })) => Poll::Ready(None),
                            }
                        } else {
                            // The finish message was already handled.
                            Poll::Ready(None)
                        }
                    }
                }
            }
        }

        // Capacity is one more than `buffer_chunks`.
        // NOTE(review): the reason for the extra slot isn't stated here — confirm.
        let (body_sender, body_receiver) = mpsc::channel(buffer_chunks + 1);
        let (finish_sender, finish_receiver) = oneshot::channel();
        let body_impl = BodyImpl {
            body_receiver,
            finish_receiver: Some(finish_receiver),
        }
        .boxed_unsync();

        // The output stream owns the only data sender and shares the byte counter.
        let output_stream = BodyWriteStream::new(context, chunk_size, body_sender, written.clone());

        (
            Self {
                body_output_stream: Some(Box::new(output_stream)),
                context,
                written,
                finish_sender: Some(finish_sender),
            },
            body_impl,
        )
    }

    /// Takes the output stream used to write the body; `None` if already taken.
    pub fn take_output_stream(&mut self) -> Option<Box<dyn OutputStream>> {
        self.body_output_stream.take()
    }

    /// Completes the body, optionally with trailers.
    ///
    /// # Errors
    ///
    /// If a size was declared and a different number of bytes was written,
    /// the hyper body is aborted and the matching body-size error (carrying
    /// the written count) is returned.
    ///
    /// # Panics
    ///
    /// Panics if the finish sender was already taken.
    pub fn finish(mut self, trailers: Option<FieldMap>) -> Result<(), types::ErrorCode> {
        // Drop the untaken output stream (if any) so its data sender goes away.
        // `BodyImpl` only looks at the finish message after the data channel closes.
        drop(self.body_output_stream);

        let sender = self
            .finish_sender
            .take()
            .expect("outgoing-body trailer_sender consumed by a non-owning function");

        if let Some(w) = self.written {
            let written = w.written();
            if written != w.expected {
                let _ = sender.send(FinishMessage::Abort);
                return Err(self.context.as_body_size_error(written));
            }
        }

        let message = if let Some(ts) = trailers {
            FinishMessage::Trailers(ts.into())
        } else {
            FinishMessage::Finished
        };

        // A send error means the hyper body is already gone; nothing to report here.
        let _ = sender.send(message);

        Ok(())
    }

    /// Aborts the body so the hyper side ends with an error instead of completing.
    ///
    /// # Panics
    ///
    /// Panics if the finish sender was already taken.
    pub fn abort(mut self) {
        // As in `finish`: release the untaken output stream before signalling.
        drop(self.body_output_stream);

        let sender = self
            .finish_sender
            .take()
            .expect("outgoing-body trailer_sender consumed by a non-owning function");

        let _ = sender.send(FinishMessage::Abort);
    }
}
551
/// Final message delivered to the hyper-side outgoing body once its data
/// channel has closed.
#[derive(Debug)]
enum FinishMessage {
    /// End the body normally with no trailers.
    Finished,
    /// End the body with this trailers frame.
    Trailers(hyper::HeaderMap),
    /// Fail the body; the hyper side sees `ErrorCode::HttpProtocolError`.
    Abort,
}
559
/// Which side of an HTTP exchange a body belongs to; used to pick the
/// body-size error variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamContext {
    /// The body of a request.
    Request,
    /// The body of a response.
    Response,
}
568
569impl StreamContext {
570 pub fn as_body_size_error(&self, size: u64) -> types::ErrorCode {
572 match self {
573 StreamContext::Request => types::ErrorCode::HttpRequestBodySize(Some(size)),
574 StreamContext::Response => types::ErrorCode::HttpResponseBodySize(Some(size)),
575 }
576 }
577}
578
/// [`OutputStream`] that forwards each write as one chunk over a bounded
/// channel to the hyper-side outgoing body.
#[derive(Debug)]
struct BodyWriteStream {
    // Request or response; selects the body-size error.
    context: StreamContext,
    // Sending half of the data channel.
    writer: mpsc::Sender<Bytes>,
    // Bytes permitted per write, reported by `check_write` while the channel has room.
    write_budget: usize,
    // Shared byte counter, present when a body size was declared.
    written: Option<WrittenState>,
}
587
588impl BodyWriteStream {
589 fn new(
591 context: StreamContext,
592 write_budget: usize,
593 writer: mpsc::Sender<Bytes>,
594 written: Option<WrittenState>,
595 ) -> Self {
596 assert!(writer.max_capacity() >= 1);
598 BodyWriteStream {
599 context,
600 writer,
601 write_budget,
602 written,
603 }
604 }
605}
606
607#[async_trait::async_trait]
608impl OutputStream for BodyWriteStream {
609 fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
610 let len = bytes.len();
611 match self.writer.try_send(bytes) {
612 Ok(()) => {
615 if let Some(written) = self.written.as_ref() {
616 if !written.update(len) {
617 let total = written.written();
618 return Err(StreamError::LastOperationFailed(format_err!(
619 self.context.as_body_size_error(total)
620 )));
621 }
622 }
623
624 Ok(())
625 }
626
627 Err(mpsc::error::TrySendError::Full(_)) => {
631 Err(StreamError::Trap(format_err!("write exceeded budget")))
632 }
633
634 Err(mpsc::error::TrySendError::Closed(_)) => Err(StreamError::Closed),
636 }
637 }
638
639 fn flush(&mut self) -> Result<(), StreamError> {
640 if self.writer.is_closed() {
643 Err(StreamError::Closed)
644 } else {
645 Ok(())
646 }
647 }
648
649 fn check_write(&mut self) -> Result<usize, StreamError> {
650 if self.writer.is_closed() {
651 Err(StreamError::Closed)
652 } else if self.writer.capacity() == 0 {
653 Ok(0)
661 } else {
662 Ok(self.write_budget)
663 }
664 }
665}
666
667#[async_trait::async_trait]
668impl Pollable for BodyWriteStream {
669 async fn ready(&mut self) {
670 let _ = self.writer.reserve().await;
674 }
675}