1use crate::{
5 bindings::http::types::{self, Method, Scheme},
6 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
7};
8use anyhow::bail;
9use bytes::Bytes;
10use http_body_util::BodyExt;
11use hyper::body::Body;
12use hyper::header::HeaderName;
13use std::any::Any;
14use std::time::Duration;
15use wasmtime::component::{Resource, ResourceTable};
16use wasmtime_wasi::p2::Pollable;
17use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
18
19#[cfg(feature = "default-send-request")]
20use {
21 crate::io::TokioIo,
22 crate::{error::dns_error, hyper_request_error},
23 tokio::net::TcpStream,
24 tokio::time::timeout,
25};
26
/// State required by the wasi-http implementation.
///
/// It currently carries no data. The private `_priv` field prevents
/// construction via a struct literal outside this module (presumably so
/// fields can be added later without breaking callers).
#[derive(Debug)]
pub struct WasiHttpCtx {
    _priv: (),
}

impl WasiHttpCtx {
    /// Creates a new context.
    pub fn new() -> Self {
        Self { _priv: () }
    }
}

// `new()` takes no arguments, so `Default` is the idiomatic companion
// (clippy `new_without_default`); it simply delegates to `new`.
impl Default for WasiHttpCtx {
    fn default() -> Self {
        Self::new()
    }
}
39
40pub trait WasiHttpView {
82 fn ctx(&mut self) -> &mut WasiHttpCtx;
84
85 fn table(&mut self) -> &mut ResourceTable;
87
88 fn new_incoming_request<B>(
90 &mut self,
91 scheme: Scheme,
92 req: hyper::Request<B>,
93 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
94 where
95 B: Body<Data = Bytes, Error = hyper::Error> + Send + Sync + 'static,
96 Self: Sized,
97 {
98 let (parts, body) = req.into_parts();
99 let body = body.map_err(crate::hyper_response_error).boxed();
100 let body = HostIncomingBody::new(
101 body,
102 std::time::Duration::from_millis(600 * 1000),
104 );
105 let incoming_req = HostIncomingRequest::new(self, parts, scheme, Some(body))?;
106 Ok(self.table().push(incoming_req)?)
107 }
108
109 fn new_response_outparam(
111 &mut self,
112 result: tokio::sync::oneshot::Sender<
113 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
114 >,
115 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
116 let id = self.table().push(HostResponseOutparam { result })?;
117 Ok(id)
118 }
119
120 #[cfg(feature = "default-send-request")]
122 fn send_request(
123 &mut self,
124 request: hyper::Request<HyperOutgoingBody>,
125 config: OutgoingRequestConfig,
126 ) -> crate::HttpResult<HostFutureIncomingResponse> {
127 Ok(default_send_request(request, config))
128 }
129
130 #[cfg(not(feature = "default-send-request"))]
132 fn send_request(
133 &mut self,
134 request: hyper::Request<HyperOutgoingBody>,
135 config: OutgoingRequestConfig,
136 ) -> crate::HttpResult<HostFutureIncomingResponse>;
137
138 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
140 DEFAULT_FORBIDDEN_HEADERS.contains(name)
141 }
142
143 fn outgoing_body_buffer_chunks(&mut self) -> usize {
147 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
148 }
149
150 fn outgoing_body_chunk_size(&mut self) -> usize {
153 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
154 }
155}
156
/// Value returned by the default `WasiHttpView::outgoing_body_buffer_chunks`.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Value returned by the default `WasiHttpView::outgoing_body_chunk_size`:
/// 1 MiB.
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1 << 20;
161
162impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
163 fn ctx(&mut self) -> &mut WasiHttpCtx {
164 T::ctx(self)
165 }
166
167 fn table(&mut self) -> &mut ResourceTable {
168 T::table(self)
169 }
170
171 fn new_response_outparam(
172 &mut self,
173 result: tokio::sync::oneshot::Sender<
174 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
175 >,
176 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
177 T::new_response_outparam(self, result)
178 }
179
180 fn send_request(
181 &mut self,
182 request: hyper::Request<HyperOutgoingBody>,
183 config: OutgoingRequestConfig,
184 ) -> crate::HttpResult<HostFutureIncomingResponse> {
185 T::send_request(self, request, config)
186 }
187
188 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
189 T::is_forbidden_header(self, name)
190 }
191
192 fn outgoing_body_buffer_chunks(&mut self) -> usize {
193 T::outgoing_body_buffer_chunks(self)
194 }
195
196 fn outgoing_body_chunk_size(&mut self) -> usize {
197 T::outgoing_body_chunk_size(self)
198 }
199}
200
201impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
202 fn ctx(&mut self) -> &mut WasiHttpCtx {
203 T::ctx(self)
204 }
205
206 fn table(&mut self) -> &mut ResourceTable {
207 T::table(self)
208 }
209
210 fn new_response_outparam(
211 &mut self,
212 result: tokio::sync::oneshot::Sender<
213 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
214 >,
215 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
216 T::new_response_outparam(self, result)
217 }
218
219 fn send_request(
220 &mut self,
221 request: hyper::Request<HyperOutgoingBody>,
222 config: OutgoingRequestConfig,
223 ) -> crate::HttpResult<HostFutureIncomingResponse> {
224 T::send_request(self, request, config)
225 }
226
227 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
228 T::is_forbidden_header(self, name)
229 }
230
231 fn outgoing_body_buffer_chunks(&mut self) -> usize {
232 T::outgoing_body_buffer_chunks(self)
233 }
234
235 fn outgoing_body_chunk_size(&mut self) -> usize {
236 T::outgoing_body_chunk_size(self)
237 }
238}
239
240#[repr(transparent)]
253pub struct WasiHttpImpl<T>(pub T);
254
255impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
256 fn ctx(&mut self) -> &mut WasiHttpCtx {
257 self.0.ctx()
258 }
259
260 fn table(&mut self) -> &mut ResourceTable {
261 self.0.table()
262 }
263
264 fn new_response_outparam(
265 &mut self,
266 result: tokio::sync::oneshot::Sender<
267 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
268 >,
269 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
270 self.0.new_response_outparam(result)
271 }
272
273 fn send_request(
274 &mut self,
275 request: hyper::Request<HyperOutgoingBody>,
276 config: OutgoingRequestConfig,
277 ) -> crate::HttpResult<HostFutureIncomingResponse> {
278 self.0.send_request(request, config)
279 }
280
281 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
282 self.0.is_forbidden_header(name)
283 }
284
285 fn outgoing_body_buffer_chunks(&mut self) -> usize {
286 self.0.outgoing_body_buffer_chunks()
287 }
288
289 fn outgoing_body_chunk_size(&mut self) -> usize {
290 self.0.outgoing_body_chunk_size()
291 }
292}
293
294pub const DEFAULT_FORBIDDEN_HEADERS: [http::header::HeaderName; 9] = [
297 hyper::header::CONNECTION,
298 HeaderName::from_static("keep-alive"),
299 hyper::header::PROXY_AUTHENTICATE,
300 hyper::header::PROXY_AUTHORIZATION,
301 HeaderName::from_static("proxy-connection"),
302 hyper::header::TRANSFER_ENCODING,
303 hyper::header::UPGRADE,
304 hyper::header::HOST,
305 HeaderName::from_static("http2-settings"),
306];
307
308pub(crate) fn remove_forbidden_headers(
310 view: &mut dyn WasiHttpView,
311 headers: &mut hyper::HeaderMap,
312) {
313 let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
314 if view.is_forbidden_header(name) {
315 Some(name.clone())
316 } else {
317 None
318 }
319 }));
320
321 for name in forbidden_keys {
322 headers.remove(name);
323 }
324}
325
/// Settings passed to `WasiHttpView::send_request` for one outgoing request.
///
/// `Debug` and `Clone` are derived so the config can be logged and reused,
/// consistent with the other public option types in this module.
#[derive(Debug, Clone)]
pub struct OutgoingRequestConfig {
    /// Whether to use TLS. In the default sender this also selects the
    /// fallback port (443 vs. 80).
    pub use_tls: bool,
    /// Time limit for establishing the connection. The default sender applies
    /// it separately to each connection-setup step.
    pub connect_timeout: Duration,
    /// Time limit for the response head to arrive after sending the request.
    pub first_byte_timeout: Duration,
    /// Time limit between chunks of the response body. The default sender
    /// forwards it to `IncomingResponse::between_bytes_timeout`.
    pub between_bytes_timeout: Duration,
}
337
/// Spawns [`default_send_request_handler`] on the wasmtime-wasi runtime and
/// returns a pending [`HostFutureIncomingResponse`] tracking it.
///
/// The spawned task always yields `Ok`. Request failures travel in the inner
/// `Result<_, types::ErrorCode>`.
#[cfg(feature = "default-send-request")]
pub fn default_send_request(
    request: hyper::Request<HyperOutgoingBody>,
    config: OutgoingRequestConfig,
) -> HostFutureIncomingResponse {
    let outcome_task = wasmtime_wasi::runtime::spawn(async move {
        let outcome = default_send_request_handler(request, config).await;
        Ok(outcome)
    });
    HostFutureIncomingResponse::Pending(outcome_task)
}
352
/// Drives a hyper client connection future on a background task.
///
/// Any error the connection finishes with is logged and otherwise
/// discarded. The caller stores the returned handle in
/// `IncomingResponse::worker`.
#[cfg(feature = "default-send-request")]
fn spawn_connection_worker<C, E>(conn: C) -> AbortOnDropJoinHandle<()>
where
    C: std::future::Future<Output = Result<(), E>> + Send + 'static,
    E: std::fmt::Display + Send + 'static,
{
    wasmtime_wasi::runtime::spawn(async move {
        if let Err(e) = conn.await {
            tracing::warn!("dropping error {e}");
        }
    })
}

/// Default HTTP/1.1 client used by `WasiHttpView::send_request`.
///
/// Steps:
/// 1. Opens a TCP connection to the URI authority, appending port 443
///    (TLS) or 80 when the URI gives none.
/// 2. If `use_tls` is set, performs a TLS handshake using the
///    `webpki_roots` trust anchors.
/// 3. Performs the HTTP/1 handshake.
/// 4. Sends `request` with its URI rewritten to origin-form (path and query
///    only, defaulting to `/`).
///
/// # Errors
///
/// - `HttpRequestUriInvalid` if the URI has no authority.
/// - `ConnectionTimeout` if the TCP connect, the TLS handshake, or the HTTP/1
///   handshake exceeds `connect_timeout`. Each step is bounded separately.
/// - A DNS error (or `ConnectionRefused`) if the TCP connect fails, and a DNS
///   error if the host is not a valid TLS server name.
/// - `TlsProtocolError` if the TLS handshake fails.
/// - `ConnectionReadTimeout` if the response head does not arrive within
///   `first_byte_timeout`.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request_handler(
    mut request: hyper::Request<HyperOutgoingBody>,
    OutgoingRequestConfig {
        use_tls,
        connect_timeout,
        first_byte_timeout,
        between_bytes_timeout,
    }: OutgoingRequestConfig,
) -> Result<IncomingResponse, types::ErrorCode> {
    let (host, authority) = {
        let uri_authority = request
            .uri()
            .authority()
            .ok_or(types::ErrorCode::HttpRequestUriInvalid)?;
        // Bare host (no userinfo, no port) for TLS server-name checks.
        // Splitting the authority on ':' would mangle IPv6 literals and
        // `user:password@` prefixes.
        let host = uri_authority.host().to_owned();
        let authority = if uri_authority.port().is_some() {
            uri_authority.to_string()
        } else {
            let port = if use_tls { 443 } else { 80 };
            format!("{uri_authority}:{port}")
        };
        (host, authority)
    };

    let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(|e| match e.kind() {
            std::io::ErrorKind::AddrNotAvailable => {
                dns_error("address not available".to_string(), 0)
            }

            _ => {
                // Resolver failures surface as generic I/O errors; detect
                // them by message so they are reported as DNS errors.
                if e.to_string()
                    .starts_with("failed to lookup address information")
                {
                    dns_error("address not available".to_string(), 0)
                } else {
                    types::ErrorCode::ConnectionRefused
                }
            }
        })?;

    let (mut sender, worker) = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));

        // `Authority::host` keeps the brackets around IPv6 literals, while
        // `ServerName` expects a bare IP address, so strip them.
        let sni_host = host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(host.as_str());
        let domain = ServerName::try_from(sni_host)
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();

        // Bound the TLS handshake too, so a peer that stalls mid-handshake
        // cannot hang the request indefinitely.
        let stream = timeout(connect_timeout, connector.connect(domain, tcp_stream))
            .await
            .map_err(|_| types::ErrorCode::ConnectionTimeout)?
            .map_err(|e| {
                tracing::warn!("tls protocol error: {e:?}");
                types::ErrorCode::TlsProtocolError
            })?;

        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(TokioIo::new(stream)),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        (sender, spawn_connection_worker(conn))
    } else {
        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(TokioIo::new(tcp_stream)),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        (sender, spawn_connection_worker(conn))
    };

    // The connection is already bound to the authority, so send the request
    // in origin-form: path and query only, defaulting to `/`.
    *request.uri_mut() = http::Uri::builder()
        .path_and_query(
            request
                .uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let resp = timeout(first_byte_timeout, sender.send_request(request))
        .await
        .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
        .map_err(hyper_request_error)?
        .map(|body| body.map_err(hyper_request_error).boxed());

    Ok(IncomingResponse {
        resp,
        worker: Some(worker),
        between_bytes_timeout,
    })
}
487
488impl From<http::Method> for types::Method {
489 fn from(method: http::Method) -> Self {
490 if method == http::Method::GET {
491 types::Method::Get
492 } else if method == hyper::Method::HEAD {
493 types::Method::Head
494 } else if method == hyper::Method::POST {
495 types::Method::Post
496 } else if method == hyper::Method::PUT {
497 types::Method::Put
498 } else if method == hyper::Method::DELETE {
499 types::Method::Delete
500 } else if method == hyper::Method::CONNECT {
501 types::Method::Connect
502 } else if method == hyper::Method::OPTIONS {
503 types::Method::Options
504 } else if method == hyper::Method::TRACE {
505 types::Method::Trace
506 } else if method == hyper::Method::PATCH {
507 types::Method::Patch
508 } else {
509 types::Method::Other(method.to_string())
510 }
511 }
512}
513
514impl TryInto<http::Method> for types::Method {
515 type Error = http::method::InvalidMethod;
516
517 fn try_into(self) -> Result<http::Method, Self::Error> {
518 match self {
519 Method::Get => Ok(http::Method::GET),
520 Method::Head => Ok(http::Method::HEAD),
521 Method::Post => Ok(http::Method::POST),
522 Method::Put => Ok(http::Method::PUT),
523 Method::Delete => Ok(http::Method::DELETE),
524 Method::Connect => Ok(http::Method::CONNECT),
525 Method::Options => Ok(http::Method::OPTIONS),
526 Method::Trace => Ok(http::Method::TRACE),
527 Method::Patch => Ok(http::Method::PATCH),
528 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
529 }
530 }
531}
532
533#[derive(Debug)]
535pub struct HostIncomingRequest {
536 pub(crate) parts: http::request::Parts,
537 pub(crate) scheme: Scheme,
538 pub(crate) authority: String,
539 pub body: Option<HostIncomingBody>,
541}
542
543impl HostIncomingRequest {
544 pub fn new(
546 view: &mut dyn WasiHttpView,
547 mut parts: http::request::Parts,
548 scheme: Scheme,
549 body: Option<HostIncomingBody>,
550 ) -> anyhow::Result<Self> {
551 let authority = match parts.uri.authority() {
552 Some(authority) => authority.to_string(),
553 None => match parts.headers.get(http::header::HOST) {
554 Some(host) => host.to_str()?.to_string(),
555 None => bail!("invalid HTTP request missing authority in URI and host header"),
556 },
557 };
558
559 remove_forbidden_headers(view, &mut parts.headers);
560 Ok(Self {
561 parts,
562 authority,
563 scheme,
564 body,
565 })
566 }
567}
568
569pub struct HostResponseOutparam {
571 pub result:
573 tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
574}
575
576pub struct HostOutgoingResponse {
578 pub status: http::StatusCode,
580 pub headers: FieldMap,
582 pub body: Option<HyperOutgoingBody>,
584}
585
586impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
587 type Error = http::Error;
588
589 fn try_from(
590 resp: HostOutgoingResponse,
591 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
592 use http_body_util::Empty;
593
594 let mut builder = hyper::Response::builder().status(resp.status);
595
596 *builder.headers_mut().unwrap() = resp.headers;
597
598 match resp.body {
599 Some(body) => builder.body(body),
600 None => builder.body(
601 Empty::<bytes::Bytes>::new()
602 .map_err(|_| unreachable!("Infallible error"))
603 .boxed(),
604 ),
605 }
606 }
607}
608
609#[derive(Debug)]
611pub struct HostOutgoingRequest {
612 pub method: Method,
614 pub scheme: Option<Scheme>,
616 pub authority: Option<String>,
618 pub path_with_query: Option<String>,
620 pub headers: FieldMap,
622 pub body: Option<HyperOutgoingBody>,
624}
625
/// Optional per-request timeouts.
///
/// Each field is `None` unless explicitly set; `Default` leaves all unset.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Connection-establishment timeout.
    pub connect_timeout: Option<Duration>,
    /// Timeout for the first byte of the response.
    pub first_byte_timeout: Option<Duration>,
    /// Timeout between successive body chunks.
    pub between_bytes_timeout: Option<Duration>,
}
636
637#[derive(Debug)]
639pub struct HostIncomingResponse {
640 pub status: u16,
642 pub headers: FieldMap,
644 pub body: Option<HostIncomingBody>,
646}
647
648#[derive(Debug)]
650pub enum HostFields {
651 Ref {
653 parent: u32,
655
656 get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
662 },
663 Owned {
665 fields: FieldMap,
667 },
668}
669
670pub type FieldMap = hyper::HeaderMap;
672
673pub type FutureIncomingResponseHandle =
675 AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
676
677#[derive(Debug)]
679pub struct IncomingResponse {
680 pub resp: hyper::Response<HyperIncomingBody>,
682 pub worker: Option<AbortOnDropJoinHandle<()>>,
684 pub between_bytes_timeout: std::time::Duration,
686}
687
688#[derive(Debug)]
690pub enum HostFutureIncomingResponse {
691 Pending(FutureIncomingResponseHandle),
693 Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
697 Consumed,
699}
700
701impl HostFutureIncomingResponse {
702 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
704 Self::Pending(handle)
705 }
706
707 pub fn ready(result: anyhow::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
709 Self::Ready(result)
710 }
711
712 pub fn is_ready(&self) -> bool {
714 matches!(self, Self::Ready(_))
715 }
716
717 pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
719 match self {
720 Self::Ready(res) => res,
721 Self::Pending(_) | Self::Consumed => {
722 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
723 }
724 }
725 }
726}
727
728#[async_trait::async_trait]
729impl Pollable for HostFutureIncomingResponse {
730 async fn ready(&mut self) {
731 if let Self::Pending(handle) = self {
732 *self = Self::Ready(handle.await);
733 }
734 }
735}