1use crate::get_content_length;
2use crate::p3::bindings::http::types::ErrorCode;
3use crate::p3::body::{Body, BodyExt as _, GuestBody};
4use crate::p3::{WasiHttpCtxView, WasiHttpView};
5use bytes::Bytes;
6use core::time::Duration;
7use http::header::HOST;
8use http::uri::{Authority, PathAndQuery, Scheme};
9use http::{HeaderMap, HeaderValue, Method, Uri};
10use http_body_util::BodyExt as _;
11use http_body_util::combinators::UnsyncBoxBody;
12use std::sync::Arc;
13use tokio::sync::oneshot;
14use tracing::debug;
15use wasmtime::AsContextMut;
16
/// Per-request timeout configuration.
///
/// Every field is optional: `None` means "no explicit value", which leaves the
/// choice to whatever ends up sending the request (for example,
/// `default_send_request` substitutes 600 seconds for each missing value).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RequestOptions {
    /// Upper bound on the time allowed to establish the connection.
    pub connect_timeout: Option<Duration>,
    /// Upper bound on the time spent waiting for the start of the response.
    pub first_byte_timeout: Option<Duration>,
    /// Upper bound on the idle time between consecutive frames of the
    /// response body.
    pub between_bytes_timeout: Option<Duration>,
}
27
28pub struct Request {
30 pub method: Method,
32 pub scheme: Option<Scheme>,
34 pub authority: Option<Authority>,
36 pub path_with_query: Option<PathAndQuery>,
38 pub headers: Arc<HeaderMap>,
40 pub options: Option<Arc<RequestOptions>>,
42 pub(crate) body: Body,
44}
45
46impl Request {
47 pub fn new(
54 method: Method,
55 scheme: Option<Scheme>,
56 authority: Option<Authority>,
57 path_with_query: Option<PathAndQuery>,
58 headers: impl Into<Arc<HeaderMap>>,
59 options: Option<Arc<RequestOptions>>,
60 body: impl Into<UnsyncBoxBody<Bytes, ErrorCode>>,
61 ) -> (
62 Self,
63 impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
64 ) {
65 let (tx, rx) = oneshot::channel();
66 (
67 Self {
68 method,
69 scheme,
70 authority,
71 path_with_query,
72 headers: headers.into(),
73 options,
74 body: Body::Host {
75 body: body.into(),
76 result_tx: tx,
77 },
78 },
79 async {
80 let Ok(fut) = rx.await else { return Ok(()) };
81 Box::into_pin(fut).await
82 },
83 )
84 }
85
86 pub fn from_http<T>(
93 req: http::Request<T>,
94 ) -> (
95 Self,
96 impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
97 )
98 where
99 T: http_body::Body<Data = Bytes> + Send + 'static,
100 T::Error: Into<ErrorCode>,
101 {
102 let (
103 http::request::Parts {
104 method,
105 uri,
106 headers,
107 ..
108 },
109 body,
110 ) = req.into_parts();
111 let http::uri::Parts {
112 scheme,
113 authority,
114 path_and_query,
115 ..
116 } = uri.into_parts();
117 Self::new(
118 method,
119 scheme,
120 authority,
121 path_and_query,
122 headers,
123 None,
124 body.map_err(Into::into).boxed_unsync(),
125 )
126 }
127
128 pub fn into_http<T: WasiHttpView + 'static>(
134 self,
135 store: impl AsContextMut<Data = T>,
136 fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
137 ) -> Result<
138 (
139 http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
140 Option<Arc<RequestOptions>>,
141 ),
142 ErrorCode,
143 > {
144 self.into_http_with_getter(store, fut, T::http)
145 }
146
147 pub fn into_http_with_getter<T: 'static>(
149 self,
150 mut store: impl AsContextMut<Data = T>,
151 fut: impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
152 getter: fn(&mut T) -> WasiHttpCtxView<'_>,
153 ) -> Result<
154 (
155 http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
156 Option<Arc<RequestOptions>>,
157 ),
158 ErrorCode,
159 > {
160 let Request {
161 method,
162 scheme,
163 authority,
164 path_with_query,
165 headers,
166 options,
167 body,
168 } = self;
169 let content_length = match get_content_length(&headers) {
171 Ok(content_length) => content_length,
172 Err(err) => {
173 body.drop(&mut store);
174 return Err(ErrorCode::InternalError(Some(format!("{err:#}"))));
175 }
176 };
177 let body = match body {
183 Body::Guest {
184 contents_rx,
185 trailers_rx,
186 result_tx,
187 } => GuestBody::new(
188 &mut store,
189 contents_rx,
190 trailers_rx,
191 result_tx,
192 fut,
193 content_length,
194 ErrorCode::HttpRequestBodySize,
195 getter,
196 )
197 .boxed_unsync(),
198 Body::Host { body, result_tx } => {
199 if let Some(limit) = content_length {
200 let (http_result_tx, http_result_rx) = oneshot::channel();
201 _ = result_tx.send(Box::new(async move {
202 if let Ok(err) = http_result_rx.await {
203 return Err(err);
204 };
205 fut.await
206 }));
207 body.with_content_length(limit, http_result_tx, ErrorCode::HttpRequestBodySize)
208 .boxed_unsync()
209 } else {
210 _ = result_tx.send(Box::new(fut));
211 body
212 }
213 }
214 };
215 let mut headers = Arc::unwrap_or_clone(headers);
216 let mut store = store.as_context_mut();
217 let WasiHttpCtxView { ctx, .. } = getter(store.data_mut());
218 if ctx.set_host_header() {
219 let host = if let Some(authority) = authority.as_ref() {
220 HeaderValue::try_from(authority.as_str())
221 .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?
222 } else {
223 HeaderValue::from_static("")
224 };
225 headers.insert(HOST, host);
226 }
227 let scheme = match scheme {
228 None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?,
229 Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme,
230 Some(..) => return Err(ErrorCode::HttpProtocolError),
231 };
232 let mut uri = Uri::builder().scheme(scheme);
233 if let Some(authority) = authority {
234 uri = uri.authority(authority)
235 };
236 if let Some(path_with_query) = path_with_query {
237 uri = uri.path_and_query(path_with_query)
238 };
239 let uri = uri.build().map_err(|err| {
240 debug!(?err, "failed to build request URI");
241 ErrorCode::HttpRequestUriInvalid
242 })?;
243 let mut req = http::Request::builder();
244 *req.headers_mut().unwrap() = headers;
245 let req = req
246 .method(method)
247 .uri(uri)
248 .body(body)
249 .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
250 let (req, body) = req.into_parts();
251 Ok((http::Request::from_parts(req, body), options))
252 }
253}
254
/// Sends `req` over a freshly opened HTTP/1 connection using `tokio` and `hyper`.
///
/// The connection target is the host and port of the request URI. The port
/// defaults to 443 for `https` and 80 otherwise. For `https`, the TCP stream
/// is wrapped in `rustls` TLS, verified against the `webpki_roots` trust
/// anchors. Before sending, the request URI is rewritten to origin-form (path
/// and query only, `/` if absent).
///
/// Timeouts come from `options`; each missing value defaults to 600 seconds:
/// - `connect_timeout` separately bounds the TCP connect, the TLS handshake
///   and the HTTP/1 handshake.
/// - `first_byte_timeout` bounds the wait for the response head.
/// - `between_bytes_timeout` bounds the idle time between response body
///   frames. When it expires, the body yields `ConnectionReadTimeout`.
///
/// Returns the response together with a future that drives the underlying
/// connection to completion. That future resolves immediately if the
/// connection already finished while the response was awaited.
///
/// # Errors
///
/// - `HttpRequestUriInvalid` if the URI has no authority.
/// - `DnsError` if the address cannot be resolved, or the host is not a
///   valid TLS server name.
/// - `ConnectionTimeout` / `ConnectionRefused` for connect failures.
/// - `TlsProtocolError` if the TLS handshake fails.
/// - `ConnectionReadTimeout` if the response head does not arrive in time.
/// - Otherwise, errors converted from `hyper` request errors.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request(
    mut req: http::Request<impl http_body::Body<Data = Bytes, Error = ErrorCode> + Send + 'static>,
    options: Option<RequestOptions>,
) -> Result<
    (
        http::Response<impl http_body::Body<Data = Bytes, Error = ErrorCode>>,
        impl Future<Output = Result<(), ErrorCode>> + Send,
    ),
    ErrorCode,
> {
    use core::future::poll_fn;
    use core::pin::{Pin, pin};
    use core::task::{Poll, ready};
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;

    /// Timeout used for any phase whose value is absent from `options`.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);

    trait TokioStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
        fn boxed(self) -> Box<dyn TokioStream>
        where
            Self: Sized,
        {
            Box::new(self)
        }
    }
    impl<T> TokioStream for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

    fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
        ErrorCode::DnsError(crate::p3::bindings::http::types::DnsErrorPayload {
            rcode: Some(rcode),
            info_code: Some(info_code),
        })
    }

    // Extract everything needed from the URI into owned values inside this
    // scope, so no borrow of `req` lives across the `.await`s below.
    let (use_tls, tls_host, connect_addr) = {
        let uri = req.uri();
        // `Uri::host` skips any `userinfo@` prefix and keeps IPv6 brackets.
        let host = uri.host().ok_or(ErrorCode::HttpRequestUriInvalid)?;
        let use_tls = uri.scheme() == Some(&Scheme::HTTPS);
        let port = uri.port_u16().unwrap_or(if use_tls { 443 } else { 80 });
        // Brackets are required in `host:port` socket-address syntax but must
        // not be part of a TLS server name (`[::1]` -> `::1`).
        let tls_host = host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host)
            .to_owned();
        (use_tls, tls_host, format!("{host}:{port}"))
    };

    let RequestOptions {
        connect_timeout,
        first_byte_timeout,
        between_bytes_timeout,
    } = options.unwrap_or_default();
    let connect_timeout = connect_timeout.unwrap_or(DEFAULT_TIMEOUT);
    let first_byte_timeout = first_byte_timeout.unwrap_or(DEFAULT_TIMEOUT);
    // `tokio::time::interval` panics on a zero period; clamp to the smallest
    // non-zero duration, which still times out as soon as the body stalls.
    let between_bytes_timeout = between_bytes_timeout
        .unwrap_or(DEFAULT_TIMEOUT)
        .max(Duration::from_nanos(1));

    let stream = match tokio::time::timeout(connect_timeout, TcpStream::connect(&connect_addr)).await
    {
        Ok(Ok(stream)) => stream,
        Ok(Err(err))
            if err.kind() == std::io::ErrorKind::AddrNotAvailable
                || err
                    .to_string()
                    .starts_with("failed to lookup address information") =>
        {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err)) => {
            tracing::warn!(?err, "connection refused");
            return Err(ErrorCode::ConnectionRefused);
        }
        Err(..) => return Err(ErrorCode::ConnectionTimeout),
    };
    let stream = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        let domain = ServerName::try_from(tls_host.as_str())
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        // Without a bound here, a peer that accepts TCP but never finishes
        // the handshake would stall the request indefinitely.
        let stream = tokio::time::timeout(connect_timeout, connector.connect(domain, stream))
            .await
            .map_err(|_| ErrorCode::ConnectionTimeout)?
            .map_err(|e| {
                tracing::warn!("tls protocol error: {e:?}");
                ErrorCode::TlsProtocolError
            })?;
        stream.boxed()
    } else {
        stream.boxed()
    };
    let (mut sender, conn) = tokio::time::timeout(
        connect_timeout,
        hyper::client::conn::http1::Builder::new().handshake(crate::io::TokioIo::new(stream)),
    )
    .await
    .map_err(|_| ErrorCode::ConnectionTimeout)?
    .map_err(ErrorCode::from_hyper_request_error)?;

    // The request line must be in origin-form for a direct (non-proxy)
    // connection, so drop scheme and authority from the URI.
    *req.uri_mut() = http::Uri::builder()
        .path_and_query(
            req.uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let send = async move {
        use core::task::Context;

        /// Response body that fails with `ConnectionReadTimeout` when no frame
        /// arrives within `between_bytes_timeout`.
        struct IncomingResponseBody {
            incoming: hyper::body::Incoming,
            timeout: tokio::time::Interval,
        }
        impl http_body::Body for IncomingResponseBody {
            type Data = <hyper::body::Incoming as http_body::Body>::Data;
            type Error = ErrorCode;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) {
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Ready(Some(Err(err))) => {
                        Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err))))
                    }
                    Poll::Ready(Some(Ok(frame))) => {
                        // Progress was made: restart the idle timer.
                        self.timeout.reset();
                        Poll::Ready(Some(Ok(frame)))
                    }
                    Poll::Pending => {
                        ready!(self.timeout.poll_tick(cx));
                        Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout)))
                    }
                }
            }
            fn is_end_stream(&self) -> bool {
                self.incoming.is_end_stream()
            }
            fn size_hint(&self) -> http_body::SizeHint {
                self.incoming.size_hint()
            }
        }

        let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req))
            .await
            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
            .map_err(ErrorCode::from_hyper_request_error)?;
        let mut timeout = tokio::time::interval(between_bytes_timeout);
        // `interval` ticks immediately on first poll; reset so the first
        // deadline is a full period from now.
        timeout.reset();
        Ok(res.map(|incoming| IncomingResponseBody { incoming, timeout }))
    };
    let mut send = pin!(send);
    let mut conn = Some(conn);
    // Wait for the response while driving the connection's I/O.
    let res = poll_fn(|cx| {
        if let Poll::Ready(res) = send.as_mut().poll(cx) {
            return Poll::Ready(res);
        }
        // Response not ready yet: drive the connection unless it already finished.
        let Some(fut) = conn.as_mut() else {
            return Poll::Pending;
        };
        let res = ready!(Pin::new(fut).poll(cx));
        // Record completion so the connection is never polled again.
        conn = None;
        match res {
            // Connection finished cleanly: optimistically poll for the response.
            Ok(()) => send.as_mut().poll(cx),
            Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))),
        }
    })
    .await?;
    Ok((res, async move {
        let Some(conn) = conn.take() else {
            // The connection already completed while awaiting the response.
            return Ok(());
        };
        conn.await.map_err(ErrorCode::from_hyper_response_error)
    }))
}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::p3::DefaultWasiHttpCtx;
    use anyhow::Result;
    use core::future::Future;
    use core::pin::pin;
    use core::str::FromStr;
    use core::task::{Context, Poll, Waker};
    use http_body_util::{BodyExt, Empty, Full};
    use wasmtime::{Engine, Store};
    use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};

    /// Store data implementing both `WasiView` and `WasiHttpView` over a
    /// shared resource table.
    struct TestCtx {
        table: ResourceTable,
        wasi: WasiCtx,
        http: DefaultWasiHttpCtx,
    }

    impl TestCtx {
        fn new() -> Self {
            let table = ResourceTable::default();
            let wasi = WasiCtxBuilder::new().build();
            Self {
                table,
                wasi,
                http: DefaultWasiHttpCtx,
            }
        }
    }

    impl WasiView for TestCtx {
        fn ctx(&mut self) -> WasiCtxView<'_> {
            let Self { wasi, table, .. } = self;
            WasiCtxView { ctx: wasi, table }
        }
    }

    impl WasiHttpView for TestCtx {
        fn http(&mut self) -> WasiHttpCtxView<'_> {
            let Self { http, table, .. } = self;
            WasiHttpCtxView { ctx: http, table }
        }
    }

    /// Polls `fut` exactly once with a no-op waker and returns the outcome.
    fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
        let mut cx = Context::from_waker(Waker::noop());
        let mut fut = pin!(fut);
        fut.as_mut().poll(&mut cx)
    }

    #[tokio::test]
    async fn test_request_into_http_schemes() -> Result<()> {
        let engine = Engine::default();

        for scheme in [Some(Scheme::HTTP), Some(Scheme::HTTPS), None] {
            // Without an explicit scheme, the test expects the HTTPS default.
            let expected_scheme = scheme.clone().unwrap_or(Scheme::HTTPS);
            let (req, completion) = Request::new(
                Method::POST,
                scheme,
                Some(Authority::from_static("example.com")),
                Some(PathAndQuery::from_static("/path?query=1")),
                HeaderMap::new(),
                None,
                Full::new(Bytes::from_static(b"body"))
                    .map_err(|x| match x {})
                    .boxed_unsync(),
            );
            let mut store = Store::new(&engine, TestCtx::new());
            let (http_req, options) = req.into_http(&mut store, async { Ok(()) }).unwrap();

            assert_eq!(options, None);
            assert_eq!(http_req.method(), Method::POST);
            let expected_uri = http::Uri::from_str(&format!(
                "{}://example.com/path?query=1",
                expected_scheme.as_str()
            ))
            .unwrap();
            assert_eq!(http_req.uri(), &expected_uri);

            let collected = http_req.into_body().collect().await?;
            assert_eq!(collected.to_bytes(), b"body".as_slice());
            assert!(matches!(poll_once(completion), Poll::Ready(Ok(()))));
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_request_into_http_uri_error() -> Result<()> {
        // Scheme and authority without a path: expected to be rejected.
        let (req, completion) = Request::new(
            Method::GET,
            Some(Scheme::HTTP),
            Some(Authority::from_static("example.com")),
            None,
            HeaderMap::new(),
            None,
            Empty::new().map_err(|x| match x {}).boxed_unsync(),
        );
        let engine = Engine::default();
        let mut store = Store::new(&engine, TestCtx::new());
        let outcome = req.into_http(&mut store, async {
            Err(ErrorCode::InternalError(Some("uh oh".to_string())))
        });
        assert!(matches!(outcome, Err(ErrorCode::HttpRequestUriInvalid)));
        // The caller's future was still handed over, so its error surfaces here.
        assert!(matches!(
            poll_once(completion),
            Poll::Ready(Err(ErrorCode::InternalError(Some(_))))
        ));

        Ok(())
    }
}
595}