wasmtime_wasi_http/p3/
request.rs

1use crate::p3::bindings::http::types::ErrorCode;
2use crate::p3::body::Body;
3use bytes::Bytes;
4use core::time::Duration;
5use http::uri::{Authority, PathAndQuery, Scheme};
6use http::{HeaderMap, Method};
7use http_body_util::BodyExt as _;
8use http_body_util::combinators::BoxBody;
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
/// Host-side representation of a `wasi:http/types/request-options` resource.
///
/// Every timeout is optional; `None` means the option has not been set.
#[derive(Clone, Copy, Debug, Default)]
pub struct RequestOptions {
    /// Maximum time to wait for the connection to be established.
    pub connect_timeout: Option<Duration>,
    /// Maximum time to wait for the first byte of the response body.
    pub first_byte_timeout: Option<Duration>,
    /// Maximum time to wait between consecutive frames of the response body.
    pub between_bytes_timeout: Option<Duration>,
}
22
23/// The concrete type behind a `wasi:http/types/request` resource.
24pub struct Request {
25    /// The method of the request.
26    pub method: Method,
27    /// The scheme of the request.
28    pub scheme: Option<Scheme>,
29    /// The authority of the request.
30    pub authority: Option<Authority>,
31    /// The path and query of the request.
32    pub path_with_query: Option<PathAndQuery>,
33    /// The request headers.
34    pub headers: Arc<HeaderMap>,
35    /// Request options.
36    pub options: Option<Arc<RequestOptions>>,
37    /// Request body.
38    pub(crate) body: Body,
39}
40
41impl Request {
42    /// Construct a new [Request]
43    ///
44    /// This returns a [Future] that the will be used to communicate
45    /// a request processing error, if any.
46    ///
47    /// Requests constructed this way will not perform any `Content-Length` validation.
48    pub fn new(
49        method: Method,
50        scheme: Option<Scheme>,
51        authority: Option<Authority>,
52        path_with_query: Option<PathAndQuery>,
53        headers: impl Into<Arc<HeaderMap>>,
54        options: Option<Arc<RequestOptions>>,
55        body: impl Into<BoxBody<Bytes, ErrorCode>>,
56    ) -> (
57        Self,
58        impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
59    ) {
60        let (tx, rx) = oneshot::channel();
61        (
62            Self {
63                method,
64                scheme,
65                authority,
66                path_with_query,
67                headers: headers.into(),
68                options,
69                body: Body::Host {
70                    body: body.into(),
71                    result_tx: tx,
72                },
73            },
74            async {
75                let Ok(fut) = rx.await else { return Ok(()) };
76                Box::into_pin(fut).await
77            },
78        )
79    }
80
81    /// Construct a new [Request] from [http::Request].
82    ///
83    /// This returns a [Future] that will be used to communicate
84    /// a request processing error, if any.
85    ///
86    /// Requests constructed this way will not perform any `Content-Length` validation.
87    pub fn from_http<T>(
88        req: http::Request<T>,
89    ) -> (
90        Self,
91        impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
92    )
93    where
94        T: http_body::Body<Data = Bytes> + Send + Sync + 'static,
95        T::Error: Into<ErrorCode>,
96    {
97        let (
98            http::request::Parts {
99                method,
100                uri,
101                headers,
102                ..
103            },
104            body,
105        ) = req.into_parts();
106        let http::uri::Parts {
107            scheme,
108            authority,
109            path_and_query,
110            ..
111        } = uri.into_parts();
112        Self::new(
113            method,
114            scheme,
115            authority,
116            path_and_query,
117            headers,
118            None,
119            body.map_err(Into::into).boxed(),
120        )
121    }
122}
123
/// The default implementation of how an outgoing request is sent.
///
/// This implementation is used by the `wasi:http/handler` interface
/// default implementation.
///
/// The request is sent using HTTP/1 over a new TCP connection to the
/// authority of the request URI. If the authority does not name a port,
/// port 443 is used for the `https` scheme and port 80 otherwise. For `https`
/// the connection is wrapped in TLS, verified against the `webpki_roots`
/// trust anchors. Scheme and authority are stripped from the request URI
/// before it is written to the connection.
///
/// Timeouts are taken from `options`; each timeout that is not set defaults
/// to 600 seconds. `connect_timeout` bounds the TCP connect and the HTTP
/// handshake (each separately), `first_byte_timeout` bounds the wait for the
/// response head and `between_bytes_timeout` bounds the wait for each frame of
/// the response body.
///
/// The returned [Future] can be used to communicate
/// a request processing error, if any, to the constructor of the request.
/// For example, if the request was constructed via `wasi:http/types.request#new`,
/// a result resolved from it will be forwarded to the guest on the future handle returned.
/// It resolves once the underlying `hyper` connection has completed.
///
/// This function performs no `Content-Length` validation.
///
/// # Errors
///
/// - [`ErrorCode::HttpRequestUriInvalid`] if the request URI has no authority.
/// - [`ErrorCode::DnsError`] if the address is not available, the address
///   lookup fails or the host is not a valid TLS server name.
/// - [`ErrorCode::ConnectionRefused`] for any other failure to connect.
/// - [`ErrorCode::ConnectionTimeout`] if connecting or the HTTP handshake
///   exceed `connect_timeout`.
/// - [`ErrorCode::TlsProtocolError`] if the TLS handshake fails.
/// - [`ErrorCode::ConnectionReadTimeout`] if the response head is not received
///   within `first_byte_timeout`.
/// - Errors reported by `hyper`, converted using
///   `ErrorCode::from_hyper_request_error`.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request(
    mut req: http::Request<impl http_body::Body<Data = Bytes, Error = ErrorCode> + Send + 'static>,
    options: Option<RequestOptions>,
) -> Result<
    (
        http::Response<impl http_body::Body<Data = Bytes, Error = ErrorCode>>,
        impl Future<Output = Result<(), ErrorCode>> + Send,
    ),
    ErrorCode,
> {
    use core::future::poll_fn;
    use core::pin::{Pin, pin};
    use core::task::{Poll, ready};
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;

    /// Timeout used for every phase that has no explicit timeout configured.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);

    /// Object-safe alias for the stream types used below, which lets the
    /// plain TCP and the TLS stream share a single boxed type.
    trait TokioStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
        fn boxed(self) -> Box<dyn TokioStream>
        where
            Self: Sized,
        {
            Box::new(self)
        }
    }
    impl<T> TokioStream for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

    /// Constructs an [`ErrorCode::DnsError`] with the given payload.
    fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
        ErrorCode::DnsError(crate::p3::bindings::http::types::DnsErrorPayload {
            rcode: Some(rcode),
            info_code: Some(info_code),
        })
    }

    let uri = req.uri();
    let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
    let use_tls = uri.scheme() == Some(&Scheme::HTTPS);

    // Host to verify the TLS certificate against. It is taken from the parsed
    // authority rather than by splitting the authority string on `:`, because
    // splitting yields `[` for an IPv6 literal such as `[::1]:443` and the
    // userinfo for `user:pw@host`. IPv6 literals are bracketed in the
    // authority, but a server name is parsed from the bare address.
    let host = authority.host();
    let tls_host = host
        .strip_prefix('[')
        .and_then(|host| host.strip_suffix(']'))
        .unwrap_or(host)
        .to_owned();

    // Address to connect to, with the scheme's default port if none is given.
    let authority = if authority.port().is_some() {
        authority.to_string()
    } else {
        let port = if use_tls { 443 } else { 80 };
        format!("{authority}:{port}")
    };

    let connect_timeout = options
        .and_then(|options| options.connect_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);
    let first_byte_timeout = options
        .and_then(|options| options.first_byte_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);
    let between_bytes_timeout = options
        .and_then(|options| options.between_bytes_timeout)
        .unwrap_or(DEFAULT_TIMEOUT);

    let stream = match tokio::time::timeout(connect_timeout, TcpStream::connect(&authority)).await {
        Ok(Ok(stream)) => stream,
        Ok(Err(err)) if err.kind() == std::io::ErrorKind::AddrNotAvailable => {
            return Err(dns_error("address not available".to_string(), 0));
        }
        // Address lookup failures are only recognizable by their message.
        Ok(Err(err))
            if err
                .to_string()
                .starts_with("failed to lookup address information") =>
        {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err)) => {
            tracing::warn!(?err, "connection refused");
            return Err(ErrorCode::ConnectionRefused);
        }
        // `connect_timeout` elapsed before the connection was established
        Err(..) => return Err(ErrorCode::ConnectionTimeout),
    };
    let stream = if use_tls {
        use rustls::pki_types::ServerName;

        // derived from https://github.com/rustls/rustls/blob/main/examples/src/bin/simpleclient.rs
        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        let domain = ServerName::try_from(tls_host.as_str())
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        let stream = connector.connect(domain, stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            ErrorCode::TlsProtocolError
        })?;
        stream.boxed()
    } else {
        stream.boxed()
    };
    let (mut sender, conn) = tokio::time::timeout(
        connect_timeout,
        // TODO: we should plumb the builder through the http context, and use it here
        hyper::client::conn::http1::Builder::new().handshake(crate::io::TokioIo::new(stream)),
    )
    .await
    .map_err(|_| ErrorCode::ConnectionTimeout)?
    .map_err(ErrorCode::from_hyper_request_error)?;

    // at this point, the request contains the scheme and the authority, but
    // the http packet should only include those if addressing a proxy, so
    // remove them here, since SendRequest::send_request does not do it for us
    *req.uri_mut() = http::Uri::builder()
        .path_and_query(
            req.uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let send = async move {
        use core::task::Context;

        /// Wrapper around [hyper::body::Incoming] used to
        /// account for request option timeout configuration
        struct IncomingResponseBody {
            incoming: hyper::body::Incoming,
            timeout: tokio::time::Interval,
        }
        impl http_body::Body for IncomingResponseBody {
            type Data = <hyper::body::Incoming as http_body::Body>::Data;
            type Error = ErrorCode;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) {
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Ready(Some(Err(err))) => {
                        Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err))))
                    }
                    Poll::Ready(Some(Ok(frame))) => {
                        // A frame arrived, restart the between-bytes timer.
                        self.timeout.reset();
                        Poll::Ready(Some(Ok(frame)))
                    }
                    Poll::Pending => {
                        // No frame yet, fail once the between-bytes timer fires.
                        ready!(self.timeout.poll_tick(cx));
                        Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout)))
                    }
                }
            }
            fn is_end_stream(&self) -> bool {
                self.incoming.is_end_stream()
            }
            fn size_hint(&self) -> http_body::SizeHint {
                self.incoming.size_hint()
            }
        }

        let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req))
            .await
            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
            .map_err(ErrorCode::from_hyper_request_error)?;
        // `tokio::time::interval` panics if the period is zero, so a zero
        // between-bytes timeout is raised to the smallest non-zero duration.
        // It still times out as soon as no frame is immediately available.
        let mut timeout = tokio::time::interval(between_bytes_timeout.max(Duration::from_nanos(1)));
        // The first tick of an interval completes immediately, push it back
        // by a full period.
        timeout.reset();
        Ok(res.map(|incoming| IncomingResponseBody { incoming, timeout }))
    };
    let mut send = pin!(send);
    let mut conn = Some(conn);
    // Wait for response while driving connection I/O
    let res = poll_fn(|cx| match send.as_mut().poll(cx) {
        Poll::Ready(Ok(res)) => Poll::Ready(Ok(res)),
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Pending => {
            // Response is not ready, poll `hyper` connection to drive I/O if it has not completed yet
            let Some(fut) = conn.as_mut() else {
                // `hyper` connection already completed
                return Poll::Pending;
            };
            let res = ready!(Pin::new(fut).poll(cx));
            // `hyper` connection completed, record that to prevent repeated poll
            conn = None;
            match res {
                // `hyper` connection has successfully completed, optimistically poll for response
                Ok(()) => send.as_mut().poll(cx),
                // `hyper` connection has failed, return the error
                Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))),
            }
        }
    })
    .await?;
    Ok((res, async move {
        let Some(conn) = conn.take() else {
            // `hyper` connection has already completed
            return Ok(());
        };
        // Keep driving connection I/O until the connection completes
        conn.await.map_err(ErrorCode::from_hyper_response_error)
    }))
}