1use crate::p3::bindings::http::types::ErrorCode;
2use crate::p3::body::{Body, BodyExt as _, GuestBody};
3use crate::p3::{HttpError, HttpResult, WasiHttpCtxView, WasiHttpView};
4use crate::{FieldMap, get_content_length};
5use bytes::Bytes;
6use core::time::Duration;
7use http::header::HOST;
8use http::uri::{Authority, PathAndQuery, Scheme};
9use http::{HeaderValue, Method, Uri};
10use http_body_util::BodyExt as _;
11use http_body_util::combinators::UnsyncBoxBody;
12use std::sync::Arc;
13use tokio::sync::oneshot;
14use tracing::debug;
15use wasmtime::AsContextMut;
16
/// Optional per-request timeout configuration.
///
/// Each field is independent; `None` means that no explicit value was
/// configured for that particular timeout.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RequestOptions {
    /// Upper bound on the time spent establishing the connection.
    pub connect_timeout: Option<Duration>,
    /// Upper bound on the time spent waiting for the response to start
    /// arriving.
    pub first_byte_timeout: Option<Duration>,
    /// Upper bound on the time allowed to elapse between two consecutive
    /// frames of the response body.
    pub between_bytes_timeout: Option<Duration>,
}
27
/// An HTTP request, kept as individual components until it is converted
/// into an [`http::Request`] by [`Request::into_http`].
pub struct Request {
    /// The request method.
    pub method: Method,
    /// The URI scheme, if one was specified.
    pub scheme: Option<Scheme>,
    /// The URI authority, if one was specified.
    pub authority: Option<Authority>,
    /// The URI path and query, if specified.
    pub path_with_query: Option<PathAndQuery>,
    /// The request headers.
    pub headers: FieldMap,
    /// Shared request options, if any were attached to the request.
    pub options: Option<Arc<RequestOptions>>,
    /// The request body (only accessible from within this crate).
    pub(crate) body: Body,
}
45
46impl Request {
47 pub fn new(
54 method: Method,
55 scheme: Option<Scheme>,
56 authority: Option<Authority>,
57 path_with_query: Option<PathAndQuery>,
58 headers: impl Into<FieldMap>,
59 options: Option<Arc<RequestOptions>>,
60 body: impl Into<UnsyncBoxBody<Bytes, ErrorCode>>,
61 ) -> (
62 Self,
63 impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
64 ) {
65 let (tx, rx) = oneshot::channel();
66 (
67 Self {
68 method,
69 scheme,
70 authority,
71 path_with_query,
72 headers: headers.into(),
73 options,
74 body: Body::Host {
75 body: body.into(),
76 result_tx: tx,
77 },
78 },
79 async {
80 let Ok(fut) = rx.await else { return Ok(()) };
81 Box::into_pin(fut).await
82 },
83 )
84 }
85
86 pub fn from_http<T>(
93 req: http::Request<T>,
94 ) -> (
95 Self,
96 impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
97 )
98 where
99 T: http_body::Body<Data = Bytes> + Send + 'static,
100 T::Error: Into<ErrorCode>,
101 {
102 let (
103 http::request::Parts {
104 method,
105 uri,
106 headers,
107 ..
108 },
109 body,
110 ) = req.into_parts();
111 let http::uri::Parts {
112 scheme,
113 authority,
114 path_and_query,
115 ..
116 } = uri.into_parts();
117 Self::new(
118 method,
119 scheme,
120 authority,
121 path_and_query,
122 FieldMap::new_immutable(headers),
123 None,
124 body.map_err(Into::into).boxed_unsync(),
125 )
126 }
127
128 pub fn into_http<T: WasiHttpView + 'static>(
134 self,
135 store: impl AsContextMut<Data = T>,
136 fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
137 ) -> HttpResult<(
138 http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
139 Option<Arc<RequestOptions>>,
140 )> {
141 self.into_http_with_getter(store, fut, T::http)
142 }
143
144 pub fn into_http_with_getter<T: 'static>(
146 self,
147 mut store: impl AsContextMut<Data = T>,
148 fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
149 getter: fn(&mut T) -> WasiHttpCtxView<'_>,
150 ) -> HttpResult<(
151 http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
152 Option<Arc<RequestOptions>>,
153 )> {
154 let Request {
155 method,
156 scheme,
157 authority,
158 path_with_query,
159 mut headers,
160 options,
161 body,
162 } = self;
163 let content_length = match get_content_length(&headers) {
165 Ok(content_length) => content_length,
166 Err(err) => {
167 body.drop(&mut store).map_err(HttpError::trap)?;
168 return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into());
169 }
170 };
171 let body = match body {
177 Body::Guest {
178 contents_rx,
179 trailers_rx,
180 result_tx,
181 } => GuestBody::new(
182 &mut store,
183 contents_rx,
184 trailers_rx,
185 result_tx,
186 fut,
187 content_length,
188 ErrorCode::HttpRequestBodySize,
189 getter,
190 )
191 .map_err(HttpError::trap)?
192 .boxed_unsync(),
193 Body::Host { body, result_tx } => {
194 if let Some(limit) = content_length {
195 let (http_result_tx, http_result_rx) = oneshot::channel();
196 _ = result_tx.send(Box::new(async move {
197 if let Ok(err) = http_result_rx.await {
198 return Err(err);
199 };
200 fut.await
201 }));
202 body.with_content_length(limit, http_result_tx, ErrorCode::HttpRequestBodySize)
203 .boxed_unsync()
204 } else {
205 _ = result_tx.send(Box::new(fut));
206 body
207 }
208 }
209 };
210 let mut store = store.as_context_mut();
211 let WasiHttpCtxView { hooks, ctx, .. } = getter(store.data_mut());
212 headers.set_mutable(ctx.field_size_limit);
213 if hooks.set_host_header() {
214 let host = if let Some(authority) = authority.as_ref() {
215 HeaderValue::try_from(authority.as_str())
216 .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?
217 } else {
218 HeaderValue::from_static("")
219 };
220 headers.append(HOST, host).map_err(HttpError::trap)?;
221 }
222 let scheme = match scheme {
223 None => hooks.default_scheme().ok_or(ErrorCode::HttpProtocolError)?,
224 Some(scheme) if hooks.is_supported_scheme(&scheme) => scheme,
225 Some(..) => return Err(ErrorCode::HttpProtocolError.into()),
226 };
227 let mut uri = Uri::builder().scheme(scheme);
228 if let Some(authority) = authority {
229 uri = uri.authority(authority)
230 };
231 if let Some(path_with_query) = path_with_query {
232 uri = uri.path_and_query(path_with_query)
233 };
234 let uri = uri.build().map_err(|err| {
235 debug!(?err, "failed to build request URI");
236 ErrorCode::HttpRequestUriInvalid
237 })?;
238 let mut req = http::Request::builder();
239 *req.headers_mut().unwrap() = headers.into();
240 let req = req
241 .method(method)
242 .uri(uri)
243 .body(body)
244 .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
245 let (req, body) = req.into_parts();
246 Ok((http::Request::from_parts(req, body), options))
247 }
248}
249
/// Sends `req` over a newly established HTTP/1.1 connection and returns the
/// response.
///
/// The target address is taken from the authority of the request URI; when
/// it carries no port, `443` is used for `https` and `80` otherwise. For
/// `https` the TCP stream is wrapped in TLS, verified against the
/// `webpki_roots` trust anchors and without client authentication.
///
/// Timeouts come from `options`; every timeout that is not set defaults to
/// 600 seconds:
///
/// - `connect_timeout` bounds the TCP connect and, separately, the HTTP
///   handshake,
/// - `first_byte_timeout` bounds the wait for the response head,
/// - `between_bytes_timeout` bounds the wait for each response body frame.
///
/// On success the response is returned together with a future that drives
/// the underlying connection to completion (unless it has already finished)
/// and resolves to the connection's final result.
///
/// # Errors
///
/// - `ErrorCode::HttpRequestUriInvalid` if the URI has no authority.
/// - `ErrorCode::DnsError` if the address is not available / cannot be looked
///   up, or if the host is not a valid TLS server name.
/// - `ErrorCode::ConnectionRefused` for any other connect failure.
/// - `ErrorCode::ConnectionTimeout` if connecting or the handshake times out.
/// - `ErrorCode::TlsProtocolError` if the TLS handshake fails.
/// - `ErrorCode::ConnectionReadTimeout` if the response head does not arrive
///   in time.
/// - Errors converted from `hyper` failures while sending the request or
///   driving the connection.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request(
    mut req: http::Request<impl http_body::Body<Data = Bytes, Error = ErrorCode> + Send + 'static>,
    options: Option<RequestOptions>,
) -> Result<
    (
        http::Response<impl http_body::Body<Data = Bytes, Error = ErrorCode>>,
        impl Future<Output = Result<(), ErrorCode>> + Send,
    ),
    ErrorCode,
> {
    use core::future::poll_fn;
    use core::pin::{Pin, pin};
    use core::task::{Poll, ready};
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;

    /// Timeout applied whenever the corresponding option is not set.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);

    /// Object-safe alias used to unify plain TCP and TLS streams.
    trait TokioStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
        fn boxed(self) -> Box<dyn TokioStream>
        where
            Self: Sized,
        {
            Box::new(self)
        }
    }
    impl<T> TokioStream for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

    fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
        ErrorCode::DnsError(crate::p3::bindings::http::types::DnsErrorPayload {
            rcode: Some(rcode),
            info_code: Some(info_code),
        })
    }

    let uri = req.uri();
    let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
    let use_tls = uri.scheme() == Some(&Scheme::HTTPS);
    // From here on `authority` is a `host:port` string: the port is either
    // the one given in the URI or the scheme's default, appended.
    let authority = if authority.port().is_some() {
        authority.to_string()
    } else {
        let port = if use_tls { 443 } else { 80 };
        format!("{authority}:{port}")
    };

    let RequestOptions {
        connect_timeout,
        first_byte_timeout,
        between_bytes_timeout,
    } = options.unwrap_or_default();
    let connect_timeout = connect_timeout.unwrap_or(DEFAULT_TIMEOUT);
    let first_byte_timeout = first_byte_timeout.unwrap_or(DEFAULT_TIMEOUT);
    // `tokio::time::interval` panics when given a zero period, so a zero
    // timeout is clamped to the smallest non-zero duration; it then expires
    // at the next timer tick instead of panicking the host.
    let between_bytes_timeout = between_bytes_timeout
        .unwrap_or(DEFAULT_TIMEOUT)
        .max(Duration::from_nanos(1));

    let stream = match tokio::time::timeout(connect_timeout, TcpStream::connect(&authority)).await {
        Ok(Ok(stream)) => stream,
        Ok(Err(err)) if err.kind() == std::io::ErrorKind::AddrNotAvailable => {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err))
            if err
                .to_string()
                .starts_with("failed to lookup address information") =>
        {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err)) => {
            tracing::warn!(?err, "connection refused");
            return Err(ErrorCode::ConnectionRefused);
        }
        Err(..) => return Err(ErrorCode::ConnectionTimeout),
    };
    let stream = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        // Split the port off at the *last* colon: splitting at the first one
        // would truncate an IPv6 literal such as `[::1]:443` to `[`.
        let host = authority
            .rsplit_once(':')
            .map_or(authority.as_str(), |(host, _port)| host);
        // IP server names are parsed without the URI brackets.
        let host = host
            .strip_prefix('[')
            .and_then(|host| host.strip_suffix(']'))
            .unwrap_or(host);
        let domain = ServerName::try_from(host)
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        let stream = connector.connect(domain, stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            ErrorCode::TlsProtocolError
        })?;
        stream.boxed()
    } else {
        stream.boxed()
    };
    let (mut sender, conn) = tokio::time::timeout(
        connect_timeout,
        hyper::client::conn::http1::Builder::new().handshake(crate::io::TokioIo::new(stream)),
    )
    .await
    .map_err(|_| ErrorCode::ConnectionTimeout)?
    .map_err(ErrorCode::from_hyper_request_error)?;

    // The request is sent in origin form: replace the URI with just its path
    // and query (or `/`), dropping scheme and authority.
    *req.uri_mut() = http::Uri::builder()
        .path_and_query(
            req.uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let send = async move {
        use core::task::Context;

        /// Response body wrapper enforcing the between-bytes timeout on top
        /// of [`hyper::body::Incoming`].
        struct IncomingResponseBody {
            incoming: hyper::body::Incoming,
            timeout: tokio::time::Interval,
        }
        impl http_body::Body for IncomingResponseBody {
            type Data = <hyper::body::Incoming as http_body::Body>::Data;
            type Error = ErrorCode;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) {
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Ready(Some(Err(err))) => {
                        Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err))))
                    }
                    Poll::Ready(Some(Ok(frame))) => {
                        // A frame arrived: restart the between-bytes clock.
                        self.timeout.reset();
                        Poll::Ready(Some(Ok(frame)))
                    }
                    Poll::Pending => {
                        // No frame yet: stay pending until the timer fires,
                        // then report a read timeout.
                        ready!(self.timeout.poll_tick(cx));
                        Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout)))
                    }
                }
            }
            fn is_end_stream(&self) -> bool {
                self.incoming.is_end_stream()
            }
            fn size_hint(&self) -> http_body::SizeHint {
                self.incoming.size_hint()
            }
        }

        let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req))
            .await
            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
            .map_err(ErrorCode::from_hyper_request_error)?;
        let mut timeout = tokio::time::interval(between_bytes_timeout);
        // Reset so that the first tick happens one full period from now.
        timeout.reset();
        Ok(res.map(|incoming| IncomingResponseBody { incoming, timeout }))
    };
    let mut send = pin!(send);
    let mut conn = Some(conn);
    // Wait for the response while driving the connection I/O.
    let res = poll_fn(|cx| match send.as_mut().poll(cx) {
        Poll::Ready(Ok(res)) => Poll::Ready(Ok(res)),
        Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
        Poll::Pending => {
            // The response is not ready yet: poll the connection, unless it
            // has already completed.
            let Some(fut) = conn.as_mut() else {
                return Poll::Pending;
            };
            let res = ready!(Pin::new(fut).poll(cx));
            // The connection completed; make sure it is not polled again.
            conn = None;
            match res {
                // Completed successfully: poll once more for the response.
                Ok(()) => send.as_mut().poll(cx),
                // The connection failed: report the error.
                Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))),
            }
        }
    })
    .await?;
    Ok((res, async move {
        let Some(conn) = conn.take() else {
            // The connection has already completed.
            return Ok(());
        };
        conn.await.map_err(ErrorCode::from_hyper_response_error)
    }))
}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WasiHttpCtx;
    use core::future::Future;
    use core::pin::pin;
    use core::str::FromStr;
    use core::task::{Context, Poll, Waker};
    use http_body_util::{BodyExt, Empty, Full};
    use wasmtime::Result;
    use wasmtime::{Engine, Store};
    use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};

    /// Minimal store data providing the WASI and WASI HTTP views.
    struct TestCtx {
        table: ResourceTable,
        wasi: WasiCtx,
        http: WasiHttpCtx,
    }

    impl TestCtx {
        fn new() -> Self {
            Self {
                table: ResourceTable::default(),
                wasi: WasiCtxBuilder::new().build(),
                http: Default::default(),
            }
        }
    }

    impl WasiView for TestCtx {
        fn ctx(&mut self) -> WasiCtxView<'_> {
            WasiCtxView {
                ctx: &mut self.wasi,
                table: &mut self.table,
            }
        }
    }

    impl WasiHttpView for TestCtx {
        fn http(&mut self) -> WasiHttpCtxView<'_> {
            WasiHttpCtxView {
                ctx: &mut self.http,
                table: &mut self.table,
                hooks: crate::p3::default_hooks(),
            }
        }
    }

    /// Converting a request succeeds for `http`, `https` and a missing
    /// scheme, the latter being expected to resolve to `https`.
    #[tokio::test]
    async fn test_request_into_http_schemes() -> Result<()> {
        let engine = Engine::default();

        for scheme in [Some(Scheme::HTTP), Some(Scheme::HTTPS), None] {
            let payload = Full::new(Bytes::from_static(b"body"))
                .map_err(|x| match x {})
                .boxed_unsync();
            let (req, fut) = Request::new(
                Method::POST,
                scheme.clone(),
                Some(Authority::from_static("example.com")),
                Some(PathAndQuery::from_static("/path?query=1")),
                FieldMap::default(),
                None,
                payload,
            );
            let mut store = Store::new(&engine, TestCtx::new());
            let (http_req, options) = req.into_http(&mut store, async { Ok(()) }).unwrap();
            assert_eq!(options, None);
            assert_eq!(http_req.method(), Method::POST);

            // A missing scheme falls back to the hooks' default scheme.
            let expected_scheme = scheme.unwrap_or(Scheme::HTTPS);
            let expected_uri = http::Uri::from_str(&format!(
                "{}://example.com/path?query=1",
                expected_scheme.as_str()
            ))
            .unwrap();
            assert_eq!(http_req.uri(), &expected_uri);

            let collected = http_req.into_body().collect().await?;
            assert_eq!(collected.to_bytes(), b"body".as_slice());

            // The processing result passed to `into_http` must come out of
            // the future returned by `Request::new`.
            let mut cx = Context::from_waker(Waker::noop());
            let polled = pin!(fut).poll(&mut cx);
            assert!(matches!(polled, Poll::Ready(Ok(()))));
        }

        Ok(())
    }

    /// A request with an authority but no path cannot be turned into a valid
    /// URI; the processing result must still be forwarded.
    #[tokio::test]
    async fn test_request_into_http_uri_error() -> Result<()> {
        let (req, fut) = Request::new(
            Method::GET,
            Some(Scheme::HTTP),
            Some(Authority::from_static("example.com")),
            // No path: this is what makes building the URI fail.
            None,
            FieldMap::default(),
            None,
            Empty::new().map_err(|x| match x {}).boxed_unsync(),
        );
        let mut store = Store::new(&Engine::default(), TestCtx::new());
        let failure = req
            .into_http(&mut store, async {
                Err(ErrorCode::InternalError(Some("uh oh".to_string())))
            })
            .unwrap_err();
        assert!(matches!(
            failure.downcast()?,
            ErrorCode::HttpRequestUriInvalid,
        ));

        let mut cx = Context::from_waker(Waker::noop());
        let polled = pin!(fut).poll(&mut cx);
        assert!(matches!(
            polled,
            Poll::Ready(Err(ErrorCode::InternalError(Some(_))))
        ));

        Ok(())
    }
}