Skip to main content

wasmtime_wasi_http/p3/
request.rs

1use crate::p3::bindings::http::types::ErrorCode;
2use crate::p3::body::{Body, BodyExt as _, GuestBody};
3use crate::p3::{HttpError, HttpResult, WasiHttpCtxView, WasiHttpView};
4use crate::{FieldMap, get_content_length};
5use bytes::Bytes;
6use core::time::Duration;
7use http::header::HOST;
8use http::uri::{Authority, PathAndQuery, Scheme};
9use http::{HeaderValue, Method, Uri};
10use http_body_util::BodyExt as _;
11use http_body_util::combinators::UnsyncBoxBody;
12use std::sync::Arc;
13use tokio::sync::oneshot;
14use tracing::debug;
15use wasmtime::AsContextMut;
16
/// Per-request configuration backing a `wasi:http/types.request-options` resource.
///
/// Each field is optional; `None` means no value was configured for that option.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct RequestOptions {
    /// Upper bound on the time spent establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Upper bound on the wait for the first byte of the response body.
    pub first_byte_timeout: Option<Duration>,
    /// Upper bound on the wait between consecutive frames of the response body.
    pub between_bytes_timeout: Option<Duration>,
}
27
28/// The concrete type behind a `wasi:http/types.request` resource.
29pub struct Request {
30    /// The method of the request.
31    pub method: Method,
32    /// The scheme of the request.
33    pub scheme: Option<Scheme>,
34    /// The authority of the request.
35    pub authority: Option<Authority>,
36    /// The path and query of the request.
37    pub path_with_query: Option<PathAndQuery>,
38    /// The request headers.
39    pub headers: FieldMap,
40    /// Request options.
41    pub options: Option<Arc<RequestOptions>>,
42    /// Request body.
43    pub(crate) body: Body,
44}
45
46impl Request {
47    /// Construct a new [Request]
48    ///
49    /// This returns a [Future] that the will be used to communicate
50    /// a request processing error, if any.
51    ///
52    /// Requests constructed this way will not perform any `Content-Length` validation.
53    pub fn new(
54        method: Method,
55        scheme: Option<Scheme>,
56        authority: Option<Authority>,
57        path_with_query: Option<PathAndQuery>,
58        headers: impl Into<FieldMap>,
59        options: Option<Arc<RequestOptions>>,
60        body: impl Into<UnsyncBoxBody<Bytes, ErrorCode>>,
61    ) -> (
62        Self,
63        impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
64    ) {
65        let (tx, rx) = oneshot::channel();
66        (
67            Self {
68                method,
69                scheme,
70                authority,
71                path_with_query,
72                headers: headers.into(),
73                options,
74                body: Body::Host {
75                    body: body.into(),
76                    result_tx: tx,
77                },
78            },
79            async {
80                let Ok(fut) = rx.await else { return Ok(()) };
81                Box::into_pin(fut).await
82            },
83        )
84    }
85
86    /// Construct a new [Request] from [http::Request].
87    ///
88    /// This returns a [Future] that will be used to communicate
89    /// a request processing error, if any.
90    ///
91    /// Requests constructed this way will not perform any `Content-Length` validation.
92    pub fn from_http<T>(
93        req: http::Request<T>,
94    ) -> (
95        Self,
96        impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
97    )
98    where
99        T: http_body::Body<Data = Bytes> + Send + 'static,
100        T::Error: Into<ErrorCode>,
101    {
102        let (
103            http::request::Parts {
104                method,
105                uri,
106                headers,
107                ..
108            },
109            body,
110        ) = req.into_parts();
111        let http::uri::Parts {
112            scheme,
113            authority,
114            path_and_query,
115            ..
116        } = uri.into_parts();
117        Self::new(
118            method,
119            scheme,
120            authority,
121            path_and_query,
122            FieldMap::new_immutable(headers),
123            None,
124            body.map_err(Into::into).boxed_unsync(),
125        )
126    }
127
128    /// Convert this [`Request`] into an [`http::Request<UnsyncBoxBody<Bytes, ErrorCode>>`].
129    ///
130    /// The specified future `fut` can be used to communicate a request processing
131    /// error, if any, back to the caller (e.g., if this request was constructed
132    /// through `wasi:http/types.request#new`).
133    pub fn into_http<T: WasiHttpView + 'static>(
134        self,
135        store: impl AsContextMut<Data = T>,
136        fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
137    ) -> HttpResult<(
138        http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
139        Option<Arc<RequestOptions>>,
140    )> {
141        self.into_http_with_getter(store, fut, T::http)
142    }
143
144    /// Like [`Self::into_http`], but uses a custom getter for obtaining the [`WasiHttpCtxView`].
145    pub fn into_http_with_getter<T: 'static>(
146        self,
147        mut store: impl AsContextMut<Data = T>,
148        fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
149        getter: fn(&mut T) -> WasiHttpCtxView<'_>,
150    ) -> HttpResult<(
151        http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
152        Option<Arc<RequestOptions>>,
153    )> {
154        let Request {
155            method,
156            scheme,
157            authority,
158            path_with_query,
159            mut headers,
160            options,
161            body,
162        } = self;
163        // `Content-Length` header value is validated in `fields` implementation
164        let content_length = match get_content_length(&headers) {
165            Ok(content_length) => content_length,
166            Err(err) => {
167                body.drop(&mut store).map_err(HttpError::trap)?;
168                return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into());
169            }
170        };
171        // This match must appear before any potential errors handled with '?'
172        // (or errors have to explicitly be addressed and drop the body, as above),
173        // as otherwise the Body::Guest resources will not be cleaned up when dropped.
174        // see: https://github.com/bytecodealliance/wasmtime/pull/11440#discussion_r2326139381
175        // for additional context.
176        let body = match body {
177            Body::Guest {
178                contents_rx,
179                trailers_rx,
180                result_tx,
181            } => GuestBody::new(
182                &mut store,
183                contents_rx,
184                trailers_rx,
185                result_tx,
186                fut,
187                content_length,
188                ErrorCode::HttpRequestBodySize,
189                getter,
190            )
191            .map_err(HttpError::trap)?
192            .boxed_unsync(),
193            Body::Host { body, result_tx } => {
194                if let Some(limit) = content_length {
195                    let (http_result_tx, http_result_rx) = oneshot::channel();
196                    _ = result_tx.send(Box::new(async move {
197                        if let Ok(err) = http_result_rx.await {
198                            return Err(err);
199                        };
200                        fut.await
201                    }));
202                    body.with_content_length(limit, http_result_tx, ErrorCode::HttpRequestBodySize)
203                        .boxed_unsync()
204                } else {
205                    _ = result_tx.send(Box::new(fut));
206                    body
207                }
208            }
209        };
210        let mut store = store.as_context_mut();
211        let WasiHttpCtxView { hooks, ctx, .. } = getter(store.data_mut());
212        headers.set_mutable(ctx.field_size_limit);
213        if hooks.set_host_header() {
214            let host = if let Some(authority) = authority.as_ref() {
215                HeaderValue::try_from(authority.as_str())
216                    .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?
217            } else {
218                HeaderValue::from_static("")
219            };
220            headers.append(HOST, host).map_err(HttpError::trap)?;
221        }
222        let scheme = match scheme {
223            None => hooks.default_scheme().ok_or(ErrorCode::HttpProtocolError)?,
224            Some(scheme) if hooks.is_supported_scheme(&scheme) => scheme,
225            Some(..) => return Err(ErrorCode::HttpProtocolError.into()),
226        };
227        let mut uri = Uri::builder().scheme(scheme);
228        if let Some(authority) = authority {
229            uri = uri.authority(authority)
230        };
231        if let Some(path_with_query) = path_with_query {
232            uri = uri.path_and_query(path_with_query)
233        };
234        let uri = uri.build().map_err(|err| {
235            debug!(?err, "failed to build request URI");
236            ErrorCode::HttpRequestUriInvalid
237        })?;
238        let mut req = http::Request::builder();
239        *req.headers_mut().unwrap() = headers.into();
240        let req = req
241            .method(method)
242            .uri(uri)
243            .body(body)
244            .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
245        let (req, body) = req.into_parts();
246        Ok((http::Request::from_parts(req, body), options))
247    }
248}
249
/// The default implementation of how an outgoing request is sent.
///
/// This implementation is used by the `wasi:http/handler` interface
/// default implementation.
///
/// The returned [Future] can be used to communicate
/// a request processing error, if any, to the constructor of the request.
/// For example, if the request was constructed via `wasi:http/types.request#new`,
/// a result resolved from it will be forwarded to the guest on the future handle returned.
/// It also drives any connection I/O still outstanding once the response head
/// has been received.
///
/// The request is sent with hyper's HTTP/1 client over a new TCP connection,
/// wrapped in TLS (trusting the `webpki-roots` anchors) when the URI scheme is
/// `https`. Any timeout not set in `options` defaults to 600 seconds.
///
/// This function performs no `Content-Length` validation.
///
/// # Errors
///
/// Returns `HttpRequestUriInvalid` if the URI has no authority, a DNS error
/// for address-resolution failures or an invalid TLS server name,
/// `ConnectionRefused`/`ConnectionTimeout` for connect failures,
/// `TlsProtocolError` for TLS handshake failures, and `ConnectionReadTimeout`
/// when the first byte of the response does not arrive in time.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request(
    mut req: http::Request<impl http_body::Body<Data = Bytes, Error = ErrorCode> + Send + 'static>,
    options: Option<RequestOptions>,
) -> Result<
    (
        http::Response<impl http_body::Body<Data = Bytes, Error = ErrorCode>>,
        impl Future<Output = Result<(), ErrorCode>> + Send,
    ),
    ErrorCode,
> {
    use core::future::poll_fn;
    use core::pin::{Pin, pin};
    use core::task::{Poll, ready};
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;

    // Timeout used for every phase whose limit is not configured in `options`.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);

    trait TokioStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
        fn boxed(self) -> Box<dyn TokioStream>
        where
            Self: Sized,
        {
            Box::new(self)
        }
    }
    impl<T> TokioStream for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

    fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
        ErrorCode::DnsError(crate::p3::bindings::http::types::DnsErrorPayload {
            rcode: Some(rcode),
            info_code: Some(info_code),
        })
    }

    let uri = req.uri();
    let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
    let use_tls = uri.scheme() == Some(&Scheme::HTTPS);
    // `Authority::host` drops any `userinfo@` prefix and the port while keeping
    // IPv6 literals bracketed (`[::1]`), which is the form a `host:port`
    // connect string requires. Splitting the raw authority on `:` instead
    // would pick up the userinfo or a fragment of an IPv6 address.
    let host = authority.host().to_owned();
    let port = authority
        .port_u16()
        .unwrap_or(if use_tls { 443 } else { 80 });
    let addr = format!("{host}:{port}");

    let connect_timeout = options
        .and_then(|opts| opts.connect_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);
    let first_byte_timeout = options
        .and_then(|opts| opts.first_byte_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);
    // `tokio::time::interval` panics on a zero period, so a zero
    // between-bytes timeout is clamped to the smallest non-zero duration.
    let between_bytes_timeout = options
        .and_then(|opts| opts.between_bytes_timeout)
        .unwrap_or(DEFAULT_TIMEOUT)
        .max(Duration::from_nanos(1));

    let stream = match tokio::time::timeout(connect_timeout, TcpStream::connect(&addr)).await {
        Ok(Ok(stream)) => stream,
        Ok(Err(err)) if err.kind() == std::io::ErrorKind::AddrNotAvailable => {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err))
            if err
                .to_string()
                .starts_with("failed to lookup address information") =>
        {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err)) => {
            tracing::warn!(?err, "connection refused");
            return Err(ErrorCode::ConnectionRefused);
        }
        Err(..) => return Err(ErrorCode::ConnectionTimeout),
    };
    let stream = if use_tls {
        use rustls::pki_types::ServerName;

        // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs
        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        // `ServerName` accepts a DNS name or a bare IP address, so strip the
        // brackets that surround IPv6 literals in the authority.
        let server_name = host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host.as_str());
        let domain = ServerName::try_from(server_name)
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        let stream = connector.connect(domain, stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            ErrorCode::TlsProtocolError
        })?;
        stream.boxed()
    } else {
        stream.boxed()
    };
    let (mut sender, conn) = tokio::time::timeout(
        connect_timeout,
        // TODO: we should plumb the builder through the http context, and use it here
        hyper::client::conn::http1::Builder::new().handshake(crate::io::TokioIo::new(stream)),
    )
    .await
    .map_err(|_| ErrorCode::ConnectionTimeout)?
    .map_err(ErrorCode::from_hyper_request_error)?;

    // at this point, the request contains the scheme and the authority, but
    // the http packet should only include those if addressing a proxy, so
    // remove them here, since SendRequest::send_request does not do it for us
    *req.uri_mut() = http::Uri::builder()
        .path_and_query(
            req.uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let send = async move {
        use core::task::Context;

        /// Wrapper around [hyper::body::Incoming] used to
        /// account for request option timeout configuration
        struct IncomingResponseBody {
            incoming: hyper::body::Incoming,
            timeout: tokio::time::Interval,
        }
        impl http_body::Body for IncomingResponseBody {
            type Data = <hyper::body::Incoming as http_body::Body>::Data;
            type Error = ErrorCode;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) {
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Ready(Some(Err(err))) => {
                        Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err))))
                    }
                    Poll::Ready(Some(Ok(frame))) => {
                        // A frame arrived: restart the between-bytes timer.
                        self.timeout.reset();
                        Poll::Ready(Some(Ok(frame)))
                    }
                    Poll::Pending => {
                        ready!(self.timeout.poll_tick(cx));
                        Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout)))
                    }
                }
            }
            fn is_end_stream(&self) -> bool {
                self.incoming.is_end_stream()
            }
            fn size_hint(&self) -> http_body::SizeHint {
                self.incoming.size_hint()
            }
        }

        let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req))
            .await
            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
            .map_err(ErrorCode::from_hyper_request_error)?;
        let mut timeout = tokio::time::interval(between_bytes_timeout);
        // `interval` ticks immediately on first poll; reset so the first tick
        // only fires after a full period.
        timeout.reset();
        Ok(res.map(|incoming| IncomingResponseBody { incoming, timeout }))
    };
    let mut send = pin!(send);
    let mut conn = Some(conn);
    // Wait for response while driving connection I/O
    let res = poll_fn(|cx| match send.as_mut().poll(cx) {
        Poll::Ready(result) => Poll::Ready(result),
        Poll::Pending => {
            // Response is not ready, poll `hyper` connection to drive I/O if it has not completed yet
            let Some(fut) = conn.as_mut() else {
                // `hyper` connection already completed
                return Poll::Pending;
            };
            let conn_result = ready!(Pin::new(fut).poll(cx));
            // `hyper` connection completed, record that to prevent repeated poll
            conn = None;
            match conn_result {
                // `hyper` connection has successfully completed, optimistically poll for response
                Ok(()) => send.as_mut().poll(cx),
                // `hyper` connection has failed, return the error
                Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))),
            }
        }
    })
    .await?;
    Ok((res, async move {
        let Some(conn) = conn.take() else {
            // `hyper` connection has already completed
            return Ok(());
        };
        conn.await.map_err(ErrorCode::from_hyper_response_error)
    }))
}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WasiHttpCtx;
    use core::future::Future;
    use core::pin::pin;
    use core::str::FromStr;
    use core::task::{Context, Poll, Waker};
    use http_body_util::{BodyExt, Empty, Full};
    use wasmtime::Result;
    use wasmtime::{Engine, Store};
    use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};

    /// Store data implementing both [`WasiView`] and [`WasiHttpView`] for tests.
    struct TestCtx {
        table: ResourceTable,
        wasi: WasiCtx,
        http: WasiHttpCtx,
    }

    impl TestCtx {
        fn new() -> Self {
            let wasi = WasiCtxBuilder::new().build();
            Self {
                wasi,
                table: ResourceTable::default(),
                http: Default::default(),
            }
        }
    }

    impl WasiView for TestCtx {
        fn ctx(&mut self) -> WasiCtxView<'_> {
            let Self { wasi, table, .. } = self;
            WasiCtxView { ctx: wasi, table }
        }
    }

    impl WasiHttpView for TestCtx {
        fn http(&mut self) -> WasiHttpCtxView<'_> {
            let Self { http, table, .. } = self;
            WasiHttpCtxView {
                ctx: http,
                table,
                hooks: crate::p3::default_hooks(),
            }
        }
    }

    /// Polls `fut` a single time using a no-op waker.
    fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
        let mut cx = Context::from_waker(Waker::noop());
        let mut pinned = pin!(fut);
        pinned.as_mut().poll(&mut cx)
    }

    #[tokio::test]
    async fn test_request_into_http_schemes() -> Result<()> {
        let engine = Engine::default();

        for scheme in [Some(Scheme::HTTP), Some(Scheme::HTTPS), None] {
            let (request, result_fut) = Request::new(
                Method::POST,
                scheme.clone(),
                Some(Authority::from_static("example.com")),
                Some(PathAndQuery::from_static("/path?query=1")),
                FieldMap::default(),
                None,
                Full::new(Bytes::from_static(b"body"))
                    .map_err(|x| match x {})
                    .boxed_unsync(),
            );
            let mut store = Store::new(&engine, TestCtx::new());
            let (converted, options) = request
                .into_http(&mut store, async { Ok(()) })
                .unwrap();
            assert_eq!(options, None);
            assert_eq!(converted.method(), Method::POST);

            // With no scheme set, the default hooks are expected to pick HTTPS.
            let scheme_str = scheme.unwrap_or(Scheme::HTTPS);
            let expected_uri = http::Uri::from_str(&format!(
                "{}://example.com/path?query=1",
                scheme_str.as_str()
            ))
            .unwrap();
            assert_eq!(converted.uri(), &expected_uri);

            let collected = converted.into_body().collect().await?;
            assert_eq!(collected.to_bytes(), b"body".as_slice());
            assert!(matches!(poll_once(result_fut), Poll::Ready(Ok(()))));
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_request_into_http_uri_error() -> Result<()> {
        // An authority without a path-and-query cannot be assembled into a URI.
        let (request, result_fut) = Request::new(
            Method::GET,
            Some(Scheme::HTTP),
            Some(Authority::from_static("example.com")),
            None,
            FieldMap::default(),
            None,
            Empty::new().map_err(|x| match x {}).boxed_unsync(),
        );
        let mut store = Store::new(&Engine::default(), TestCtx::new());
        let failure = request
            .into_http(&mut store, async {
                Err(ErrorCode::InternalError(Some("uh oh".to_string())))
            })
            .unwrap_err();
        assert!(matches!(
            failure.downcast()?,
            ErrorCode::HttpRequestUriInvalid,
        ));

        // The processing-error future handed to `into_http` was already
        // forwarded to the future returned by `Request::new`.
        assert!(matches!(
            poll_once(result_fut),
            Poll::Ready(Err(ErrorCode::InternalError(Some(_))))
        ));

        Ok(())
    }
}