wasmtime_wasi_http/p3/request.rs

1use crate::get_content_length;
2use crate::p3::bindings::http::types::ErrorCode;
3use crate::p3::body::{Body, BodyExt as _, GuestBody};
4use crate::p3::{WasiHttpCtxView, WasiHttpView};
5use bytes::Bytes;
6use core::time::Duration;
7use http::header::HOST;
8use http::uri::{Authority, PathAndQuery, Scheme};
9use http::{HeaderMap, HeaderValue, Method, Uri};
10use http_body_util::BodyExt as _;
11use http_body_util::combinators::UnsyncBoxBody;
12use std::sync::Arc;
13use tokio::sync::oneshot;
14use tracing::debug;
15use wasmtime::AsContextMut;
16
/// Timeout configuration attached to an outgoing request.
///
/// This is the concrete type behind a `wasi:http/types.request-options` resource.
/// Each field is optional; `None` leaves the choice of value to whatever code
/// ends up sending the request.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct RequestOptions {
    /// Upper bound on the time spent establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Upper bound on the wait for the first byte of the response body.
    pub first_byte_timeout: Option<Duration>,
    /// Upper bound on the gap between consecutive frames of the response body.
    pub between_bytes_timeout: Option<Duration>,
}
27
28/// The concrete type behind a `wasi:http/types.request` resource.
29pub struct Request {
30    /// The method of the request.
31    pub method: Method,
32    /// The scheme of the request.
33    pub scheme: Option<Scheme>,
34    /// The authority of the request.
35    pub authority: Option<Authority>,
36    /// The path and query of the request.
37    pub path_with_query: Option<PathAndQuery>,
38    /// The request headers.
39    pub headers: Arc<HeaderMap>,
40    /// Request options.
41    pub options: Option<Arc<RequestOptions>>,
42    /// Request body.
43    pub(crate) body: Body,
44}
45
46impl Request {
47    /// Construct a new [Request]
48    ///
49    /// This returns a [Future] that the will be used to communicate
50    /// a request processing error, if any.
51    ///
52    /// Requests constructed this way will not perform any `Content-Length` validation.
53    pub fn new(
54        method: Method,
55        scheme: Option<Scheme>,
56        authority: Option<Authority>,
57        path_with_query: Option<PathAndQuery>,
58        headers: impl Into<Arc<HeaderMap>>,
59        options: Option<Arc<RequestOptions>>,
60        body: impl Into<UnsyncBoxBody<Bytes, ErrorCode>>,
61    ) -> (
62        Self,
63        impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
64    ) {
65        let (tx, rx) = oneshot::channel();
66        (
67            Self {
68                method,
69                scheme,
70                authority,
71                path_with_query,
72                headers: headers.into(),
73                options,
74                body: Body::Host {
75                    body: body.into(),
76                    result_tx: tx,
77                },
78            },
79            async {
80                let Ok(fut) = rx.await else { return Ok(()) };
81                Box::into_pin(fut).await
82            },
83        )
84    }
85
86    /// Construct a new [Request] from [http::Request].
87    ///
88    /// This returns a [Future] that will be used to communicate
89    /// a request processing error, if any.
90    ///
91    /// Requests constructed this way will not perform any `Content-Length` validation.
92    pub fn from_http<T>(
93        req: http::Request<T>,
94    ) -> (
95        Self,
96        impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
97    )
98    where
99        T: http_body::Body<Data = Bytes> + Send + 'static,
100        T::Error: Into<ErrorCode>,
101    {
102        let (
103            http::request::Parts {
104                method,
105                uri,
106                headers,
107                ..
108            },
109            body,
110        ) = req.into_parts();
111        let http::uri::Parts {
112            scheme,
113            authority,
114            path_and_query,
115            ..
116        } = uri.into_parts();
117        Self::new(
118            method,
119            scheme,
120            authority,
121            path_and_query,
122            headers,
123            None,
124            body.map_err(Into::into).boxed_unsync(),
125        )
126    }
127
128    /// Convert this [`Request`] into an [`http::Request<UnsyncBoxBody<Bytes, ErrorCode>>`].
129    ///
130    /// The specified future `fut` can be used to communicate a request processing
131    /// error, if any, back to the caller (e.g., if this request was constructed
132    /// through `wasi:http/types.request#new`).
133    pub fn into_http<T: WasiHttpView + 'static>(
134        self,
135        store: impl AsContextMut<Data = T>,
136        fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
137    ) -> Result<
138        (
139            http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
140            Option<Arc<RequestOptions>>,
141        ),
142        ErrorCode,
143    > {
144        self.into_http_with_getter(store, fut, T::http)
145    }
146
147    /// Like [`Self::into_http`], but uses a custom getter for obtaining the [`WasiHttpCtxView`].
148    pub fn into_http_with_getter<T: 'static>(
149        self,
150        mut store: impl AsContextMut<Data = T>,
151        fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
152        getter: fn(&mut T) -> WasiHttpCtxView<'_>,
153    ) -> Result<
154        (
155            http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
156            Option<Arc<RequestOptions>>,
157        ),
158        ErrorCode,
159    > {
160        let Request {
161            method,
162            scheme,
163            authority,
164            path_with_query,
165            headers,
166            options,
167            body,
168        } = self;
169        // `Content-Length` header value is validated in `fields` implementation
170        let content_length = match get_content_length(&headers) {
171            Ok(content_length) => content_length,
172            Err(err) => {
173                body.drop(&mut store);
174                return Err(ErrorCode::InternalError(Some(format!("{err:#}"))));
175            }
176        };
177        // This match must appear before any potential errors handled with '?'
178        // (or errors have to explicitly be addressed and drop the body, as above),
179        // as otherwise the Body::Guest resources will not be cleaned up when dropped.
180        // see: https://github.com/bytecodealliance/wasmtime/pull/11440#discussion_r2326139381
181        // for additional context.
182        let body = match body {
183            Body::Guest {
184                contents_rx,
185                trailers_rx,
186                result_tx,
187            } => GuestBody::new(
188                &mut store,
189                contents_rx,
190                trailers_rx,
191                result_tx,
192                fut,
193                content_length,
194                ErrorCode::HttpRequestBodySize,
195                getter,
196            )
197            .boxed_unsync(),
198            Body::Host { body, result_tx } => {
199                if let Some(limit) = content_length {
200                    let (http_result_tx, http_result_rx) = oneshot::channel();
201                    _ = result_tx.send(Box::new(async move {
202                        if let Ok(err) = http_result_rx.await {
203                            return Err(err);
204                        };
205                        fut.await
206                    }));
207                    body.with_content_length(limit, http_result_tx, ErrorCode::HttpRequestBodySize)
208                        .boxed_unsync()
209                } else {
210                    _ = result_tx.send(Box::new(fut));
211                    body
212                }
213            }
214        };
215        let mut headers = Arc::unwrap_or_clone(headers);
216        let mut store = store.as_context_mut();
217        let WasiHttpCtxView { ctx, .. } = getter(store.data_mut());
218        if ctx.set_host_header() {
219            let host = if let Some(authority) = authority.as_ref() {
220                HeaderValue::try_from(authority.as_str())
221                    .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?
222            } else {
223                HeaderValue::from_static("")
224            };
225            headers.insert(HOST, host);
226        }
227        let scheme = match scheme {
228            None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?,
229            Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme,
230            Some(..) => return Err(ErrorCode::HttpProtocolError),
231        };
232        let mut uri = Uri::builder().scheme(scheme);
233        if let Some(authority) = authority {
234            uri = uri.authority(authority)
235        };
236        if let Some(path_with_query) = path_with_query {
237            uri = uri.path_and_query(path_with_query)
238        };
239        let uri = uri.build().map_err(|err| {
240            debug!(?err, "failed to build request URI");
241            ErrorCode::HttpRequestUriInvalid
242        })?;
243        let mut req = http::Request::builder();
244        *req.headers_mut().unwrap() = headers;
245        let req = req
246            .method(method)
247            .uri(uri)
248            .body(body)
249            .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
250        let (req, body) = req.into_parts();
251        Ok((http::Request::from_parts(req, body), options))
252    }
253}
254
/// The default implementation of how an outgoing request is sent.
///
/// This implementation is used by the `wasi:http/handler` interface
/// default implementation.
///
/// The returned [Future] can be used to communicate
/// a request processing error, if any, to the constructor of the request.
/// For example, if the request was constructed via `wasi:http/types.request#new`,
/// a result resolved from it will be forwarded to the guest on the future handle returned.
///
/// Any timeout not set in `options` defaults to 600 seconds. The connection is
/// made with `hyper`'s HTTP/1 client, wrapped in TLS (trusting the
/// `webpki-roots` certificates) when the URI scheme is `https`.
///
/// This function performs no `Content-Length` validation.
///
/// # Errors
///
/// - [`ErrorCode::HttpRequestUriInvalid`] if the request URI has no authority.
/// - [`ErrorCode::DnsError`] if the host cannot be resolved or is not a valid
///   TLS server name.
/// - [`ErrorCode::ConnectionRefused`] for other TCP connection failures.
/// - [`ErrorCode::ConnectionTimeout`] if connecting or the HTTP handshake
///   exceeds the connect timeout.
/// - [`ErrorCode::TlsProtocolError`] if the TLS handshake fails.
/// - [`ErrorCode::ConnectionReadTimeout`] if the response head does not arrive
///   within the first-byte timeout.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request(
    mut req: http::Request<impl http_body::Body<Data = Bytes, Error = ErrorCode> + Send + 'static>,
    options: Option<RequestOptions>,
) -> Result<
    (
        http::Response<impl http_body::Body<Data = Bytes, Error = ErrorCode>>,
        impl Future<Output = Result<(), ErrorCode>> + Send,
    ),
    ErrorCode,
> {
    use core::future::poll_fn;
    use core::pin::{Pin, pin};
    use core::task::{Poll, ready};
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;

    /// Fallback for any timeout not specified in `options`.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);

    trait TokioStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
        fn boxed(self) -> Box<dyn TokioStream>
        where
            Self: Sized,
        {
            Box::new(self)
        }
    }
    impl<T> TokioStream for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

    fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
        ErrorCode::DnsError(crate::p3::bindings::http::types::DnsErrorPayload {
            rcode: Some(rcode),
            info_code: Some(info_code),
        })
    }

    let uri = req.uri();
    let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
    let use_tls = uri.scheme() == Some(&Scheme::HTTPS);
    // Name used for TLS server verification. `Authority::host` drops any
    // `userinfo@` prefix and `:port` suffix (naively splitting on `:` breaks
    // IPv6 literals such as `[::1]:443`), but keeps the brackets around IPv6
    // addresses, which `ServerName` does not accept.
    let tls_host = {
        let host = authority.host();
        host.strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host)
            .to_owned()
    };
    // Address for `TcpStream::connect`, adding the scheme's default port if none is given.
    let connect_addr = if authority.port().is_some() {
        authority.to_string()
    } else {
        let port = if use_tls { 443 } else { 80 };
        format!("{authority}:{port}")
    };

    let connect_timeout = options
        .and_then(|opts| opts.connect_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);
    let first_byte_timeout = options
        .and_then(|opts| opts.first_byte_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);
    let between_bytes_timeout = options
        .and_then(|opts| opts.between_bytes_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);

    let stream = match tokio::time::timeout(connect_timeout, TcpStream::connect(&connect_addr)).await
    {
        Ok(Ok(stream)) => stream,
        Ok(Err(err)) if err.kind() == std::io::ErrorKind::AddrNotAvailable => {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err))
            if err
                .to_string()
                .starts_with("failed to lookup address information") =>
        {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err)) => {
            tracing::warn!(?err, "connection refused");
            return Err(ErrorCode::ConnectionRefused);
        }
        Err(..) => return Err(ErrorCode::ConnectionTimeout),
    };
    let stream = if use_tls {
        use rustls::pki_types::ServerName;

        // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs
        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        // Accepts either a DNS name or a bare IP address.
        let domain = ServerName::try_from(tls_host.as_str())
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        let stream = connector.connect(domain, stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            ErrorCode::TlsProtocolError
        })?;
        stream.boxed()
    } else {
        stream.boxed()
    };
    let (mut sender, conn) = tokio::time::timeout(
        connect_timeout,
        // TODO: we should plumb the builder through the http context, and use it here
        hyper::client::conn::http1::Builder::new().handshake(crate::io::TokioIo::new(stream)),
    )
    .await
    .map_err(|_| ErrorCode::ConnectionTimeout)?
    .map_err(ErrorCode::from_hyper_request_error)?;

    // at this point, the request contains the scheme and the authority, but
    // the http packet should only include those if addressing a proxy, so
    // remove them here, since SendRequest::send_request does not do it for us
    *req.uri_mut() = http::Uri::builder()
        .path_and_query(
            req.uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let send = async move {
        use core::task::Context;

        /// Wrapper around [hyper::body::Incoming] used to
        /// account for request option timeout configuration
        struct IncomingResponseBody {
            incoming: hyper::body::Incoming,
            timeout: tokio::time::Interval,
        }
        impl http_body::Body for IncomingResponseBody {
            type Data = <hyper::body::Incoming as http_body::Body>::Data;
            type Error = ErrorCode;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) {
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Ready(Some(Err(err))) => {
                        Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err))))
                    }
                    Poll::Ready(Some(Ok(frame))) => {
                        // A frame arrived: restart the between-bytes window.
                        self.timeout.reset();
                        Poll::Ready(Some(Ok(frame)))
                    }
                    Poll::Pending => {
                        ready!(self.timeout.poll_tick(cx));
                        Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout)))
                    }
                }
            }
            fn is_end_stream(&self) -> bool {
                self.incoming.is_end_stream()
            }
            fn size_hint(&self) -> http_body::SizeHint {
                self.incoming.size_hint()
            }
        }

        let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req))
            .await
            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
            .map_err(ErrorCode::from_hyper_request_error)?;
        // `tokio::time::interval` panics on a zero period, so clamp it: a zero
        // between-bytes timeout then means "time out as soon as no frame is ready".
        let mut timeout =
            tokio::time::interval(between_bytes_timeout.max(Duration::from_nanos(1)));
        // The first tick of a fresh interval completes immediately; push it out
        // by one full period.
        timeout.reset();
        Ok(res.map(|incoming| IncomingResponseBody { incoming, timeout }))
    };
    let mut send = pin!(send);
    let mut conn = Some(conn);
    // Wait for response while driving connection I/O
    let res = poll_fn(|cx| match send.as_mut().poll(cx) {
        Poll::Ready(Ok(res)) => Poll::Ready(Ok(res)),
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Pending => {
            // Response is not ready, poll `hyper` connection to drive I/O if it has not completed yet
            let Some(fut) = conn.as_mut() else {
                // `hyper` connection already completed
                return Poll::Pending;
            };
            let res = ready!(Pin::new(fut).poll(cx));
            // `hyper` connection completed, record that to prevent repeated poll
            conn = None;
            match res {
                // `hyper` connection has successfully completed, optimistically poll for response
                Ok(()) => send.as_mut().poll(cx),
                // `hyper` connection has failed, return the error
                Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))),
            }
        }
    })
    .await?;
    Ok((res, async move {
        let Some(conn) = conn.take() else {
            // `hyper` connection has already completed
            return Ok(());
        };
        conn.await.map_err(ErrorCode::from_hyper_response_error)
    }))
}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::p3::DefaultWasiHttpCtx;
    use anyhow::Result;
    use core::future::Future;
    use core::pin::pin;
    use core::str::FromStr;
    use core::task::{Context, Poll, Waker};
    use http_body_util::{BodyExt, Empty, Full};
    use wasmtime::{Engine, Store};
    use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};

    /// Store data exposing both the WASI and the WASI-HTTP views.
    struct TestCtx {
        table: ResourceTable,
        wasi: WasiCtx,
        http: DefaultWasiHttpCtx,
    }

    impl TestCtx {
        fn new() -> Self {
            Self {
                table: ResourceTable::default(),
                wasi: WasiCtxBuilder::new().build(),
                http: DefaultWasiHttpCtx,
            }
        }
    }

    impl WasiView for TestCtx {
        fn ctx(&mut self) -> WasiCtxView<'_> {
            WasiCtxView {
                ctx: &mut self.wasi,
                table: &mut self.table,
            }
        }
    }

    impl WasiHttpView for TestCtx {
        fn http(&mut self) -> WasiHttpCtxView<'_> {
            WasiHttpCtxView {
                ctx: &mut self.http,
                table: &mut self.table,
            }
        }
    }

    /// Polls `fut` exactly once with a no-op waker and returns the outcome.
    fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
        let mut cx = Context::from_waker(Waker::noop());
        let fut = pin!(fut);
        fut.poll(&mut cx)
    }

    #[tokio::test]
    async fn test_request_into_http_schemes() -> Result<()> {
        let engine = Engine::default();

        for scheme in [Some(Scheme::HTTP), Some(Scheme::HTTPS), None] {
            let (request, result_fut) = Request::new(
                Method::POST,
                scheme.clone(),
                Some(Authority::from_static("example.com")),
                Some(PathAndQuery::from_static("/path?query=1")),
                HeaderMap::new(),
                None,
                Full::new(Bytes::from_static(b"body"))
                    .map_err(|never| match never {})
                    .boxed_unsync(),
            );
            let mut store = Store::new(&engine, TestCtx::new());
            let (http_req, options) = request.into_http(&mut store, async { Ok(()) }).unwrap();
            assert_eq!(options, None);
            assert_eq!(http_req.method(), Method::POST);

            // A missing scheme falls back to the context's default scheme (HTTPS).
            let want_scheme = scheme.unwrap_or(Scheme::HTTPS);
            let want_uri = http::Uri::from_str(&format!(
                "{}://example.com/path?query=1",
                want_scheme.as_str()
            ))
            .unwrap();
            assert_eq!(http_req.uri(), &want_uri);

            let collected = http_req.into_body().collect().await?;
            assert_eq!(collected.to_bytes(), b"body".as_slice());
            assert!(matches!(poll_once(result_fut), Poll::Ready(Ok(()))));
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_request_into_http_uri_error() -> Result<()> {
        // An authority without a path cannot be assembled into a URI.
        let (request, result_fut) = Request::new(
            Method::GET,
            Some(Scheme::HTTP),
            Some(Authority::from_static("example.com")),
            None, // <-- should fail, must be Some(_) when authority is set
            HeaderMap::new(),
            None,
            Empty::new().map_err(|never| match never {}).boxed_unsync(),
        );
        let mut store = Store::new(&Engine::default(), TestCtx::new());
        let outcome = request.into_http(&mut store, async {
            Err(ErrorCode::InternalError(Some("uh oh".to_string())))
        });
        assert!(matches!(outcome, Err(ErrorCode::HttpRequestUriInvalid)));

        // The future passed to `into_http` still reaches the caller of `Request::new`.
        assert!(matches!(
            poll_once(result_fut),
            Poll::Ready(Err(ErrorCode::InternalError(Some(_))))
        ));

        Ok(())
    }
}