1use crate::{
5 bindings::http::types::{self, ErrorCode, Method, Scheme},
6 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
7};
8use bytes::Bytes;
9use http_body_util::BodyExt;
10use hyper::body::Body;
11use hyper::header::HeaderName;
12use std::any::Any;
13use std::time::Duration;
14use wasmtime::bail;
15use wasmtime::component::{Resource, ResourceTable};
16use wasmtime_wasi::p2::Pollable;
17use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
18
19#[cfg(feature = "default-send-request")]
20use {
21 crate::io::TokioIo,
22 crate::{error::dns_error, hyper_request_error},
23 tokio::net::TcpStream,
24 tokio::time::timeout,
25};
26
/// Context value handed out by [`WasiHttpView::ctx`].
///
/// The struct currently carries no data: its only field is a private unit
/// marker, so code outside this module cannot build one with a struct literal
/// and has to go through [`WasiHttpCtx::new`] or [`Default`].
#[derive(Debug, Default)]
pub struct WasiHttpCtx {
    // Private marker field; keeps construction behind `new`/`default`.
    _priv: (),
}

impl WasiHttpCtx {
    /// Creates a new, empty context.
    ///
    /// Equivalent to [`WasiHttpCtx::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
39
40pub trait WasiHttpView {
82 fn ctx(&mut self) -> &mut WasiHttpCtx;
84
85 fn table(&mut self) -> &mut ResourceTable;
87
88 fn new_incoming_request<B>(
90 &mut self,
91 scheme: Scheme,
92 req: hyper::Request<B>,
93 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
94 where
95 B: Body<Data = Bytes> + Send + 'static,
96 B::Error: Into<ErrorCode>,
97 Self: Sized,
98 {
99 let (parts, body) = req.into_parts();
100 let body = body.map_err(Into::into).boxed_unsync();
101 let body = HostIncomingBody::new(
102 body,
103 std::time::Duration::from_millis(600 * 1000),
105 );
106 let incoming_req = HostIncomingRequest::new(self, parts, scheme, Some(body))?;
107 Ok(self.table().push(incoming_req)?)
108 }
109
110 fn new_response_outparam(
112 &mut self,
113 result: tokio::sync::oneshot::Sender<
114 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
115 >,
116 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
117 let id = self.table().push(HostResponseOutparam { result })?;
118 Ok(id)
119 }
120
121 #[cfg(feature = "default-send-request")]
123 fn send_request(
124 &mut self,
125 request: hyper::Request<HyperOutgoingBody>,
126 config: OutgoingRequestConfig,
127 ) -> crate::HttpResult<HostFutureIncomingResponse> {
128 Ok(default_send_request(request, config))
129 }
130
131 #[cfg(not(feature = "default-send-request"))]
133 fn send_request(
134 &mut self,
135 request: hyper::Request<HyperOutgoingBody>,
136 config: OutgoingRequestConfig,
137 ) -> crate::HttpResult<HostFutureIncomingResponse>;
138
139 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
141 DEFAULT_FORBIDDEN_HEADERS.contains(name)
142 }
143
144 fn outgoing_body_buffer_chunks(&mut self) -> usize {
148 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
149 }
150
151 fn outgoing_body_chunk_size(&mut self) -> usize {
154 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
155 }
156}
157
/// Value returned by the default
/// [`WasiHttpView::outgoing_body_buffer_chunks`] implementation.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Value returned by the default [`WasiHttpView::outgoing_body_chunk_size`]
/// implementation: `1024 * 1024` (presumably bytes, i.e. 1 MiB — confirm at
/// the use site).
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1024 * 1024;
162
163impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
164 fn ctx(&mut self) -> &mut WasiHttpCtx {
165 T::ctx(self)
166 }
167
168 fn table(&mut self) -> &mut ResourceTable {
169 T::table(self)
170 }
171
172 fn new_response_outparam(
173 &mut self,
174 result: tokio::sync::oneshot::Sender<
175 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
176 >,
177 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
178 T::new_response_outparam(self, result)
179 }
180
181 fn send_request(
182 &mut self,
183 request: hyper::Request<HyperOutgoingBody>,
184 config: OutgoingRequestConfig,
185 ) -> crate::HttpResult<HostFutureIncomingResponse> {
186 T::send_request(self, request, config)
187 }
188
189 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
190 T::is_forbidden_header(self, name)
191 }
192
193 fn outgoing_body_buffer_chunks(&mut self) -> usize {
194 T::outgoing_body_buffer_chunks(self)
195 }
196
197 fn outgoing_body_chunk_size(&mut self) -> usize {
198 T::outgoing_body_chunk_size(self)
199 }
200}
201
202impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
203 fn ctx(&mut self) -> &mut WasiHttpCtx {
204 T::ctx(self)
205 }
206
207 fn table(&mut self) -> &mut ResourceTable {
208 T::table(self)
209 }
210
211 fn new_response_outparam(
212 &mut self,
213 result: tokio::sync::oneshot::Sender<
214 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
215 >,
216 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
217 T::new_response_outparam(self, result)
218 }
219
220 fn send_request(
221 &mut self,
222 request: hyper::Request<HyperOutgoingBody>,
223 config: OutgoingRequestConfig,
224 ) -> crate::HttpResult<HostFutureIncomingResponse> {
225 T::send_request(self, request, config)
226 }
227
228 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
229 T::is_forbidden_header(self, name)
230 }
231
232 fn outgoing_body_buffer_chunks(&mut self) -> usize {
233 T::outgoing_body_buffer_chunks(self)
234 }
235
236 fn outgoing_body_chunk_size(&mut self) -> usize {
237 T::outgoing_body_chunk_size(self)
238 }
239}
240
/// Transparent single-field wrapper around a `T`.
///
/// When `T: WasiHttpView` the wrapper implements [`WasiHttpView`] as well by
/// forwarding each call to the inner value.
// NOTE(review): presumably exists so this crate can implement generated host
// traits on a local type; confirm against the users of `WasiHttpImpl`.
#[repr(transparent)]
pub struct WasiHttpImpl<T>(pub T);
255
256impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
257 fn ctx(&mut self) -> &mut WasiHttpCtx {
258 self.0.ctx()
259 }
260
261 fn table(&mut self) -> &mut ResourceTable {
262 self.0.table()
263 }
264
265 fn new_response_outparam(
266 &mut self,
267 result: tokio::sync::oneshot::Sender<
268 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
269 >,
270 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
271 self.0.new_response_outparam(result)
272 }
273
274 fn send_request(
275 &mut self,
276 request: hyper::Request<HyperOutgoingBody>,
277 config: OutgoingRequestConfig,
278 ) -> crate::HttpResult<HostFutureIncomingResponse> {
279 self.0.send_request(request, config)
280 }
281
282 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
283 self.0.is_forbidden_header(name)
284 }
285
286 fn outgoing_body_buffer_chunks(&mut self) -> usize {
287 self.0.outgoing_body_buffer_chunks()
288 }
289
290 fn outgoing_body_chunk_size(&mut self) -> usize {
291 self.0.outgoing_body_chunk_size()
292 }
293}
294
/// Header names that the default [`WasiHttpView::is_forbidden_header`]
/// implementation reports as forbidden.
///
/// [`remove_forbidden_headers`] strips whichever headers the view reports as
/// forbidden, so with the default view these nine names are removed.
pub const DEFAULT_FORBIDDEN_HEADERS: [http::header::HeaderName; 9] = [
    hyper::header::CONNECTION,
    HeaderName::from_static("keep-alive"),
    hyper::header::PROXY_AUTHENTICATE,
    hyper::header::PROXY_AUTHORIZATION,
    HeaderName::from_static("proxy-connection"),
    hyper::header::TRANSFER_ENCODING,
    hyper::header::UPGRADE,
    hyper::header::HOST,
    HeaderName::from_static("http2-settings"),
];
308
309pub(crate) fn remove_forbidden_headers(
311 view: &mut dyn WasiHttpView,
312 headers: &mut hyper::HeaderMap,
313) {
314 let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
315 if view.is_forbidden_header(name) {
316 Some(name.clone())
317 } else {
318 None
319 }
320 }));
321
322 for name in forbidden_keys {
323 headers.remove(name);
324 }
325}
326
/// Settings passed alongside an outgoing request to
/// [`WasiHttpView::send_request`].
///
/// The per-field notes describe how `default_send_request_handler` uses each
/// value; a custom `send_request` implementation may interpret them
/// differently.
pub struct OutgoingRequestConfig {
    /// Whether to wrap the connection in TLS. Also selects the port used when
    /// the request authority has none (443 with TLS, 80 without).
    pub use_tls: bool,
    /// Limit applied to the TCP connect and, separately, to the hyper
    /// connection handshake.
    pub connect_timeout: Duration,
    /// Limit applied to waiting for the response after the request is sent.
    pub first_byte_timeout: Duration,
    /// Carried over unchanged into [`IncomingResponse::between_bytes_timeout`].
    pub between_bytes_timeout: Duration,
}
338
/// Default implementation behind [`WasiHttpView::send_request`].
///
/// Spawns [`default_send_request_handler`] as a background task and returns a
/// pending [`HostFutureIncomingResponse`] wrapping that task's handle.
#[cfg(feature = "default-send-request")]
pub fn default_send_request(
    request: hyper::Request<HyperOutgoingBody>,
    config: OutgoingRequestConfig,
) -> HostFutureIncomingResponse {
    let task = async move {
        // The handler's own `Result` is the payload; the outer layer is
        // always `Ok` here.
        let outcome = default_send_request_handler(request, config).await;
        Ok(outcome)
    };
    HostFutureIncomingResponse::pending(wasmtime_wasi::runtime::spawn(task))
}
353
/// Performs an outgoing HTTP/1 request over a fresh TCP (optionally TLS)
/// connection and returns the response head plus a streaming body.
///
/// Steps:
/// 1. Build a `host:port` string from the request URI's authority, using
///    port 443 (TLS) or 80 (plain) when the authority has none.
/// 2. Open a TCP connection, bounded by `connect_timeout`.
/// 3. If `use_tls` is set, perform a TLS handshake using the `webpki_roots`
///    trust anchors and no client authentication.
/// 4. Run the hyper HTTP/1 handshake (bounded by `connect_timeout`) and spawn
///    a task that drives the connection.
/// 5. Rewrite the request URI to just its path and query, send it, and wait
///    for the response for at most `first_byte_timeout`.
///
/// # Errors
///
/// * `HttpRequestUriInvalid` — the request URI has no authority.
/// * `ConnectionTimeout` — the TCP connect or hyper handshake timed out.
/// * A DNS error (via `dns_error`) — address lookup failed, the address was
///   not available, or the host is not a valid TLS server name.
/// * `ConnectionRefused` — any other TCP connect failure.
/// * `TlsProtocolError` — the TLS handshake failed.
/// * `ConnectionReadTimeout` — no response within `first_byte_timeout`.
/// * Whatever `hyper_request_error` maps hyper failures to.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request_handler(
    mut request: hyper::Request<HyperOutgoingBody>,
    OutgoingRequestConfig {
        use_tls,
        connect_timeout,
        first_byte_timeout,
        between_bytes_timeout,
    }: OutgoingRequestConfig,
) -> Result<IncomingResponse, types::ErrorCode> {
    // `host:port` used for the TCP connection. After this block the string
    // always ends in `:port`.
    let authority = if let Some(authority) = request.uri().authority() {
        if authority.port().is_some() {
            authority.to_string()
        } else {
            let port = if use_tls { 443 } else { 80 };
            format!("{authority}:{port}")
        }
    } else {
        return Err(types::ErrorCode::HttpRequestUriInvalid);
    };
    let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(|e| match e.kind() {
            std::io::ErrorKind::AddrNotAvailable => {
                dns_error("address not available".to_string(), 0)
            }

            _ => {
                // Lookup failures carry no dedicated `ErrorKind`, so they are
                // recognised by the error's message prefix.
                if e.to_string()
                    .starts_with("failed to lookup address information")
                {
                    dns_error("address not available".to_string(), 0)
                } else {
                    types::ErrorCode::ConnectionRefused
                }
            }
        })?;

    let (mut sender, worker) = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));

        // Derive the TLS server name from `host:port`. Split on the LAST
        // colon so a bracketed IPv6 literal such as `[::1]:443` keeps its
        // inner colons, then drop the brackets so it parses as an IP address.
        // Splitting on the first colon would yield `[` for such authorities.
        let host = authority
            .rsplit_once(':')
            .map_or(authority.as_str(), |(host, _port)| host);
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let domain = ServerName::try_from(host)
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        // NOTE(review): unlike the TCP connect and the hyper handshake, this
        // TLS handshake is not wrapped in `connect_timeout`.
        let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            types::ErrorCode::TlsProtocolError
        })?;
        let stream = TokioIo::new(stream);

        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(stream),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        // Drive the connection in the background; its errors are only logged.
        let worker = wasmtime_wasi::runtime::spawn(async move {
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::warn!("dropping error {e}"),
            }
        });

        (sender, worker)
    } else {
        let tcp_stream = TokioIo::new(tcp_stream);
        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(tcp_stream),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        // Drive the connection in the background; its errors are only logged.
        let worker = wasmtime_wasi::runtime::spawn(async move {
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::warn!("dropping error {e}"),
            }
        });

        (sender, worker)
    };

    // Reduce the URI to path and query (defaulting to `/`); scheme and
    // authority are dropped before the request is sent.
    *request.uri_mut() = http::Uri::builder()
        .path_and_query(
            request
                .uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let resp = timeout(first_byte_timeout, sender.send_request(request))
        .await
        .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
        .map_err(hyper_request_error)?
        .map(|body| body.map_err(hyper_request_error).boxed_unsync());

    Ok(IncomingResponse {
        resp,
        worker: Some(worker),
        between_bytes_timeout,
    })
}
488
489impl From<http::Method> for types::Method {
490 fn from(method: http::Method) -> Self {
491 if method == http::Method::GET {
492 types::Method::Get
493 } else if method == hyper::Method::HEAD {
494 types::Method::Head
495 } else if method == hyper::Method::POST {
496 types::Method::Post
497 } else if method == hyper::Method::PUT {
498 types::Method::Put
499 } else if method == hyper::Method::DELETE {
500 types::Method::Delete
501 } else if method == hyper::Method::CONNECT {
502 types::Method::Connect
503 } else if method == hyper::Method::OPTIONS {
504 types::Method::Options
505 } else if method == hyper::Method::TRACE {
506 types::Method::Trace
507 } else if method == hyper::Method::PATCH {
508 types::Method::Patch
509 } else {
510 types::Method::Other(method.to_string())
511 }
512 }
513}
514
515impl TryInto<http::Method> for types::Method {
516 type Error = http::method::InvalidMethod;
517
518 fn try_into(self) -> Result<http::Method, Self::Error> {
519 match self {
520 Method::Get => Ok(http::Method::GET),
521 Method::Head => Ok(http::Method::HEAD),
522 Method::Post => Ok(http::Method::POST),
523 Method::Put => Ok(http::Method::PUT),
524 Method::Delete => Ok(http::Method::DELETE),
525 Method::Connect => Ok(http::Method::CONNECT),
526 Method::Options => Ok(http::Method::OPTIONS),
527 Method::Trace => Ok(http::Method::TRACE),
528 Method::Patch => Ok(http::Method::PATCH),
529 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
530 }
531 }
532}
533
/// Host-side representation of an incoming request.
#[derive(Debug)]
pub struct HostIncomingRequest {
    /// Request head (method, URI, headers, ...) as produced by
    /// `hyper::Request::into_parts`.
    pub(crate) parts: http::request::Parts,
    /// Scheme supplied by the caller when the request was created.
    pub(crate) scheme: Scheme,
    /// Authority taken from the URI or, failing that, the `Host` header
    /// (see [`HostIncomingRequest::new`]).
    pub(crate) authority: String,
    /// The request body, if any.
    pub body: Option<HostIncomingBody>,
}
543
544impl HostIncomingRequest {
545 pub fn new(
547 view: &mut dyn WasiHttpView,
548 mut parts: http::request::Parts,
549 scheme: Scheme,
550 body: Option<HostIncomingBody>,
551 ) -> wasmtime::Result<Self> {
552 let authority = match parts.uri.authority() {
553 Some(authority) => authority.to_string(),
554 None => match parts.headers.get(http::header::HOST) {
555 Some(host) => host.to_str()?.to_string(),
556 None => bail!("invalid HTTP request missing authority in URI and host header"),
557 },
558 };
559
560 remove_forbidden_headers(view, &mut parts.headers);
561 Ok(Self {
562 parts,
563 authority,
564 scheme,
565 body,
566 })
567 }
568}
569
/// Resource stored by [`WasiHttpView::new_response_outparam`].
pub struct HostResponseOutparam {
    /// One-shot channel sender carrying either a response or an error code.
    pub result:
        tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}
576
/// Host-side representation of an outgoing response.
///
/// Convertible into a `hyper::Response<HyperOutgoingBody>` via `TryFrom`.
pub struct HostOutgoingResponse {
    /// Response status code.
    pub status: http::StatusCode,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body; `None` is converted to an empty body.
    pub body: Option<HyperOutgoingBody>,
}
586
587impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
588 type Error = http::Error;
589
590 fn try_from(
591 resp: HostOutgoingResponse,
592 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
593 use http_body_util::Empty;
594
595 let mut builder = hyper::Response::builder().status(resp.status);
596
597 *builder.headers_mut().unwrap() = resp.headers;
598
599 match resp.body {
600 Some(body) => builder.body(body),
601 None => builder.body(
602 Empty::<bytes::Bytes>::new()
603 .map_err(|_| unreachable!("Infallible error"))
604 .boxed_unsync(),
605 ),
606 }
607 }
608}
609
/// Host-side representation of an outgoing request.
#[derive(Debug)]
pub struct HostOutgoingRequest {
    /// Request method.
    pub method: Method,
    /// Request scheme, if set.
    pub scheme: Option<Scheme>,
    /// Request authority, if set.
    pub authority: Option<String>,
    /// Request path and query string, if set.
    pub path_with_query: Option<String>,
    /// Request headers.
    pub headers: FieldMap,
    /// Request body, if any.
    pub body: Option<HyperOutgoingBody>,
}
626
/// Optional per-request timeouts. The derived `Default` leaves all three unset.
// NOTE(review): presumably these feed the matching fields of
// `OutgoingRequestConfig`; confirm at the call site.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Connect timeout, if set.
    pub connect_timeout: Option<std::time::Duration>,
    /// First-byte timeout, if set.
    pub first_byte_timeout: Option<std::time::Duration>,
    /// Between-bytes timeout, if set.
    pub between_bytes_timeout: Option<std::time::Duration>,
}
637
/// Host-side representation of an incoming response.
#[derive(Debug)]
pub struct HostIncomingResponse {
    /// Response status as a raw `u16`.
    pub status: u16,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body, if any.
    pub body: Option<HostIncomingBody>,
}
648
/// Host-side representation of a set of header fields, either borrowed from
/// another value or owned outright.
#[derive(Debug)]
pub enum HostFields {
    /// Fields that live inside another value.
    Ref {
        /// Identifier of the value that holds the fields.
        // NOTE(review): presumably a resource-table index; confirm at the
        // sites that construct this variant.
        parent: u32,

        /// Projects the parent, given as `&mut dyn Any`, to its `FieldMap`.
        get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
    },
    /// Fields stored directly in this value.
    Owned {
        /// The header map itself.
        fields: FieldMap,
    },
}
670
/// Header map type used throughout this module; an alias for `hyper::HeaderMap`.
pub type FieldMap = hyper::HeaderMap;
673
/// Handle to a spawned task that resolves to the outcome of an outgoing
/// request: the outer `wasmtime::Result` wraps an inner result that is either
/// an [`IncomingResponse`] or an HTTP error code.
pub type FutureIncomingResponseHandle =
    AbortOnDropJoinHandle<wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>>;
677
/// A response received for an outgoing request, together with what is needed
/// to keep reading its body.
#[derive(Debug)]
pub struct IncomingResponse {
    /// The hyper response, with its body still to be read.
    pub resp: hyper::Response<HyperIncomingBody>,
    /// Optional background task tied to the response. The default request
    /// handler stores the task that drives the hyper connection here.
    pub worker: Option<AbortOnDropJoinHandle<()>>,
    /// Between-bytes timeout copied from the request configuration.
    pub between_bytes_timeout: std::time::Duration,
}
688
/// State of a response that may not have arrived yet.
#[derive(Debug)]
pub enum HostFutureIncomingResponse {
    /// The request task is still running; holds its handle.
    Pending(FutureIncomingResponseHandle),
    /// The task has finished; holds its result.
    Ready(wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>),
    /// No result is held.
    // NOTE(review): presumably set once the ready value has been taken; this
    // variant is not constructed anywhere in this file — confirm at call sites.
    Consumed,
}
701
702impl HostFutureIncomingResponse {
703 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
705 Self::Pending(handle)
706 }
707
708 pub fn ready(result: wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
710 Self::Ready(result)
711 }
712
713 pub fn is_ready(&self) -> bool {
715 matches!(self, Self::Ready(_))
716 }
717
718 pub fn unwrap_ready(self) -> wasmtime::Result<Result<IncomingResponse, types::ErrorCode>> {
720 match self {
721 Self::Ready(res) => res,
722 Self::Pending(_) | Self::Consumed => {
723 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
724 }
725 }
726 }
727}
728
729#[async_trait::async_trait]
730impl Pollable for HostFutureIncomingResponse {
731 async fn ready(&mut self) {
732 if let Self::Pending(handle) = self {
733 *self = Self::Ready(handle.await);
734 }
735 }
736}