1use crate::io::TokioIo;
5use crate::{
6 bindings::http::types::{self, Method, Scheme},
7 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
8 error::dns_error,
9 hyper_request_error,
10};
11use anyhow::bail;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Body;
15use hyper::header::HeaderName;
16use std::any::Any;
17use std::time::Duration;
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20use wasmtime::component::{Resource, ResourceTable};
21use wasmtime_wasi::{runtime::AbortOnDropJoinHandle, IoImpl, IoView, Pollable};
22
/// State owned by the embedder for the wasi-http host implementation.
///
/// The only field is private, so values cannot be built with a struct
/// literal outside this module; use [`WasiHttpCtx::new`] instead.
#[derive(Debug)]
pub struct WasiHttpCtx {
    _priv: (),
}
28
29impl WasiHttpCtx {
30 pub fn new() -> Self {
32 Self { _priv: () }
33 }
34}
35
36pub trait WasiHttpView: IoView {
78 fn ctx(&mut self) -> &mut WasiHttpCtx;
80
81 fn new_incoming_request<B>(
83 &mut self,
84 scheme: Scheme,
85 req: hyper::Request<B>,
86 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
87 where
88 B: Body<Data = Bytes, Error = hyper::Error> + Send + Sync + 'static,
89 Self: Sized,
90 {
91 let (parts, body) = req.into_parts();
92 let body = body.map_err(crate::hyper_response_error).boxed();
93 let body = HostIncomingBody::new(
94 body,
95 std::time::Duration::from_millis(600 * 1000),
97 );
98 let incoming_req = HostIncomingRequest::new(self, parts, scheme, Some(body))?;
99 Ok(self.table().push(incoming_req)?)
100 }
101
102 fn new_response_outparam(
104 &mut self,
105 result: tokio::sync::oneshot::Sender<
106 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
107 >,
108 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
109 let id = self.table().push(HostResponseOutparam { result })?;
110 Ok(id)
111 }
112
113 fn send_request(
115 &mut self,
116 request: hyper::Request<HyperOutgoingBody>,
117 config: OutgoingRequestConfig,
118 ) -> crate::HttpResult<HostFutureIncomingResponse> {
119 Ok(default_send_request(request, config))
120 }
121
122 fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool {
124 false
125 }
126
127 fn outgoing_body_buffer_chunks(&mut self) -> usize {
131 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
132 }
133
134 fn outgoing_body_chunk_size(&mut self) -> usize {
137 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
138 }
139}
140
/// Value returned by the default [`WasiHttpView::outgoing_body_buffer_chunks`].
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Value returned by the default [`WasiHttpView::outgoing_body_chunk_size`]:
/// 1 MiB.
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1 << 20;
145
146impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
147 fn ctx(&mut self) -> &mut WasiHttpCtx {
148 T::ctx(self)
149 }
150
151 fn new_response_outparam(
152 &mut self,
153 result: tokio::sync::oneshot::Sender<
154 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
155 >,
156 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
157 T::new_response_outparam(self, result)
158 }
159
160 fn send_request(
161 &mut self,
162 request: hyper::Request<HyperOutgoingBody>,
163 config: OutgoingRequestConfig,
164 ) -> crate::HttpResult<HostFutureIncomingResponse> {
165 T::send_request(self, request, config)
166 }
167
168 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
169 T::is_forbidden_header(self, name)
170 }
171
172 fn outgoing_body_buffer_chunks(&mut self) -> usize {
173 T::outgoing_body_buffer_chunks(self)
174 }
175
176 fn outgoing_body_chunk_size(&mut self) -> usize {
177 T::outgoing_body_chunk_size(self)
178 }
179}
180
181impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
182 fn ctx(&mut self) -> &mut WasiHttpCtx {
183 T::ctx(self)
184 }
185
186 fn new_response_outparam(
187 &mut self,
188 result: tokio::sync::oneshot::Sender<
189 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
190 >,
191 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
192 T::new_response_outparam(self, result)
193 }
194
195 fn send_request(
196 &mut self,
197 request: hyper::Request<HyperOutgoingBody>,
198 config: OutgoingRequestConfig,
199 ) -> crate::HttpResult<HostFutureIncomingResponse> {
200 T::send_request(self, request, config)
201 }
202
203 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
204 T::is_forbidden_header(self, name)
205 }
206
207 fn outgoing_body_buffer_chunks(&mut self) -> usize {
208 T::outgoing_body_buffer_chunks(self)
209 }
210
211 fn outgoing_body_chunk_size(&mut self) -> usize {
212 T::outgoing_body_chunk_size(self)
213 }
214}
215
216#[repr(transparent)]
229pub struct WasiHttpImpl<T>(pub IoImpl<T>);
230
231impl<T: IoView> IoView for WasiHttpImpl<T> {
232 fn table(&mut self) -> &mut ResourceTable {
233 T::table(&mut self.0 .0)
234 }
235}
236impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
237 fn ctx(&mut self) -> &mut WasiHttpCtx {
238 self.0 .0.ctx()
239 }
240
241 fn new_response_outparam(
242 &mut self,
243 result: tokio::sync::oneshot::Sender<
244 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
245 >,
246 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
247 self.0 .0.new_response_outparam(result)
248 }
249
250 fn send_request(
251 &mut self,
252 request: hyper::Request<HyperOutgoingBody>,
253 config: OutgoingRequestConfig,
254 ) -> crate::HttpResult<HostFutureIncomingResponse> {
255 self.0 .0.send_request(request, config)
256 }
257
258 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
259 self.0 .0.is_forbidden_header(name)
260 }
261
262 fn outgoing_body_buffer_chunks(&mut self) -> usize {
263 self.0 .0.outgoing_body_buffer_chunks()
264 }
265
266 fn outgoing_body_chunk_size(&mut self) -> usize {
267 self.0 .0.outgoing_body_chunk_size()
268 }
269}
270
271pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
273 static FORBIDDEN_HEADERS: [HeaderName; 10] = [
274 hyper::header::CONNECTION,
275 HeaderName::from_static("keep-alive"),
276 hyper::header::PROXY_AUTHENTICATE,
277 hyper::header::PROXY_AUTHORIZATION,
278 HeaderName::from_static("proxy-connection"),
279 hyper::header::TE,
280 hyper::header::TRANSFER_ENCODING,
281 hyper::header::UPGRADE,
282 hyper::header::HOST,
283 HeaderName::from_static("http2-settings"),
284 ];
285
286 FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
287}
288
289pub(crate) fn remove_forbidden_headers(
291 view: &mut dyn WasiHttpView,
292 headers: &mut hyper::HeaderMap,
293) {
294 let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
295 if is_forbidden_header(view, name) {
296 Some(name.clone())
297 } else {
298 None
299 }
300 }));
301
302 for name in forbidden_keys {
303 headers.remove(name);
304 }
305}
306
/// Per-request connection settings consumed by [`default_send_request_handler`].
pub struct OutgoingRequestConfig {
    /// Wrap the connection in TLS; also selects default port 443 (vs. 80)
    /// when the URI has no explicit port.
    pub use_tls: bool,
    /// Upper bound applied to connection establishment (TCP connect and the
    /// HTTP/1 handshake in the default handler).
    pub connect_timeout: Duration,
    /// Upper bound on waiting for the response head after the request is sent.
    pub first_byte_timeout: Duration,
    /// Forwarded to the resulting [`IncomingResponse`] for use while the body
    /// is read.
    pub between_bytes_timeout: Duration,
}
318
319pub fn default_send_request(
324 request: hyper::Request<HyperOutgoingBody>,
325 config: OutgoingRequestConfig,
326) -> HostFutureIncomingResponse {
327 let handle = wasmtime_wasi::runtime::spawn(async move {
328 Ok(default_send_request_handler(request, config).await)
329 });
330 HostFutureIncomingResponse::pending(handle)
331}
332
333pub async fn default_send_request_handler(
338 mut request: hyper::Request<HyperOutgoingBody>,
339 OutgoingRequestConfig {
340 use_tls,
341 connect_timeout,
342 first_byte_timeout,
343 between_bytes_timeout,
344 }: OutgoingRequestConfig,
345) -> Result<IncomingResponse, types::ErrorCode> {
346 let authority = if let Some(authority) = request.uri().authority() {
347 if authority.port().is_some() {
348 authority.to_string()
349 } else {
350 let port = if use_tls { 443 } else { 80 };
351 format!("{}:{port}", authority.to_string())
352 }
353 } else {
354 return Err(types::ErrorCode::HttpRequestUriInvalid);
355 };
356 let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
357 .await
358 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
359 .map_err(|e| match e.kind() {
360 std::io::ErrorKind::AddrNotAvailable => {
361 dns_error("address not available".to_string(), 0)
362 }
363
364 _ => {
365 if e.to_string()
366 .starts_with("failed to lookup address information")
367 {
368 dns_error("address not available".to_string(), 0)
369 } else {
370 types::ErrorCode::ConnectionRefused
371 }
372 }
373 })?;
374
375 let (mut sender, worker) = if use_tls {
376 use rustls::pki_types::ServerName;
377
378 let root_cert_store = rustls::RootCertStore {
380 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
381 };
382 let config = rustls::ClientConfig::builder()
383 .with_root_certificates(root_cert_store)
384 .with_no_client_auth();
385 let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
386 let mut parts = authority.split(":");
387 let host = parts.next().unwrap_or(&authority);
388 let domain = ServerName::try_from(host)
389 .map_err(|e| {
390 tracing::warn!("dns lookup error: {e:?}");
391 dns_error("invalid dns name".to_string(), 0)
392 })?
393 .to_owned();
394 let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
395 tracing::warn!("tls protocol error: {e:?}");
396 types::ErrorCode::TlsProtocolError
397 })?;
398 let stream = TokioIo::new(stream);
399
400 let (sender, conn) = timeout(
401 connect_timeout,
402 hyper::client::conn::http1::handshake(stream),
403 )
404 .await
405 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
406 .map_err(hyper_request_error)?;
407
408 let worker = wasmtime_wasi::runtime::spawn(async move {
409 match conn.await {
410 Ok(()) => {}
411 Err(e) => tracing::warn!("dropping error {e}"),
414 }
415 });
416
417 (sender, worker)
418 } else {
419 let tcp_stream = TokioIo::new(tcp_stream);
420 let (sender, conn) = timeout(
421 connect_timeout,
422 hyper::client::conn::http1::handshake(tcp_stream),
424 )
425 .await
426 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
427 .map_err(hyper_request_error)?;
428
429 let worker = wasmtime_wasi::runtime::spawn(async move {
430 match conn.await {
431 Ok(()) => {}
432 Err(e) => tracing::warn!("dropping error {e}"),
434 }
435 });
436
437 (sender, worker)
438 };
439
440 *request.uri_mut() = http::Uri::builder()
444 .path_and_query(
445 request
446 .uri()
447 .path_and_query()
448 .map(|p| p.as_str())
449 .unwrap_or("/"),
450 )
451 .build()
452 .expect("comes from valid request");
453
454 let resp = timeout(first_byte_timeout, sender.send_request(request))
455 .await
456 .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
457 .map_err(hyper_request_error)?
458 .map(|body| body.map_err(hyper_request_error).boxed());
459
460 Ok(IncomingResponse {
461 resp,
462 worker: Some(worker),
463 between_bytes_timeout,
464 })
465}
466
467impl From<http::Method> for types::Method {
468 fn from(method: http::Method) -> Self {
469 if method == http::Method::GET {
470 types::Method::Get
471 } else if method == hyper::Method::HEAD {
472 types::Method::Head
473 } else if method == hyper::Method::POST {
474 types::Method::Post
475 } else if method == hyper::Method::PUT {
476 types::Method::Put
477 } else if method == hyper::Method::DELETE {
478 types::Method::Delete
479 } else if method == hyper::Method::CONNECT {
480 types::Method::Connect
481 } else if method == hyper::Method::OPTIONS {
482 types::Method::Options
483 } else if method == hyper::Method::TRACE {
484 types::Method::Trace
485 } else if method == hyper::Method::PATCH {
486 types::Method::Patch
487 } else {
488 types::Method::Other(method.to_string())
489 }
490 }
491}
492
493impl TryInto<http::Method> for types::Method {
494 type Error = http::method::InvalidMethod;
495
496 fn try_into(self) -> Result<http::Method, Self::Error> {
497 match self {
498 Method::Get => Ok(http::Method::GET),
499 Method::Head => Ok(http::Method::HEAD),
500 Method::Post => Ok(http::Method::POST),
501 Method::Put => Ok(http::Method::PUT),
502 Method::Delete => Ok(http::Method::DELETE),
503 Method::Connect => Ok(http::Method::CONNECT),
504 Method::Options => Ok(http::Method::OPTIONS),
505 Method::Trace => Ok(http::Method::TRACE),
506 Method::Patch => Ok(http::Method::PATCH),
507 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
508 }
509 }
510}
511
512#[derive(Debug)]
514pub struct HostIncomingRequest {
515 pub(crate) parts: http::request::Parts,
516 pub(crate) scheme: Scheme,
517 pub(crate) authority: String,
518 pub body: Option<HostIncomingBody>,
520}
521
522impl HostIncomingRequest {
523 pub fn new(
525 view: &mut dyn WasiHttpView,
526 mut parts: http::request::Parts,
527 scheme: Scheme,
528 body: Option<HostIncomingBody>,
529 ) -> anyhow::Result<Self> {
530 let authority = match parts.uri.authority() {
531 Some(authority) => authority.to_string(),
532 None => match parts.headers.get(http::header::HOST) {
533 Some(host) => host.to_str()?.to_string(),
534 None => bail!("invalid HTTP request missing authority in URI and host header"),
535 },
536 };
537
538 remove_forbidden_headers(view, &mut parts.headers);
539 Ok(Self {
540 parts,
541 authority,
542 scheme,
543 body,
544 })
545 }
546}
547
548pub struct HostResponseOutparam {
550 pub result:
552 tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
553}
554
555pub struct HostOutgoingResponse {
557 pub status: http::StatusCode,
559 pub headers: FieldMap,
561 pub body: Option<HyperOutgoingBody>,
563}
564
565impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
566 type Error = http::Error;
567
568 fn try_from(
569 resp: HostOutgoingResponse,
570 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
571 use http_body_util::Empty;
572
573 let mut builder = hyper::Response::builder().status(resp.status);
574
575 *builder.headers_mut().unwrap() = resp.headers;
576
577 match resp.body {
578 Some(body) => builder.body(body),
579 None => builder.body(
580 Empty::<bytes::Bytes>::new()
581 .map_err(|_| unreachable!("Infallible error"))
582 .boxed(),
583 ),
584 }
585 }
586}
587
588#[derive(Debug)]
590pub struct HostOutgoingRequest {
591 pub method: Method,
593 pub scheme: Option<Scheme>,
595 pub authority: Option<String>,
597 pub path_with_query: Option<String>,
599 pub headers: FieldMap,
601 pub body: Option<HyperOutgoingBody>,
603}
604
/// Optional per-request timeout overrides; every field starts as `None`.
/// `None` presumably means "use the host default" — confirm at the use site.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Connection-establishment timeout override.
    pub connect_timeout: Option<std::time::Duration>,
    /// Timeout override for receiving the response head.
    pub first_byte_timeout: Option<std::time::Duration>,
    /// Timeout override between body chunks.
    pub between_bytes_timeout: Option<std::time::Duration>,
}
615
616#[derive(Debug)]
618pub struct HostIncomingResponse {
619 pub status: u16,
621 pub headers: FieldMap,
623 pub body: Option<HostIncomingBody>,
625}
626
627#[derive(Debug)]
629pub enum HostFields {
630 Ref {
632 parent: u32,
634
635 get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
641 },
642 Owned {
644 fields: FieldMap,
646 },
647}
648
649pub type FieldMap = hyper::HeaderMap;
651
652pub type FutureIncomingResponseHandle =
654 AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
655
656#[derive(Debug)]
658pub struct IncomingResponse {
659 pub resp: hyper::Response<HyperIncomingBody>,
661 pub worker: Option<AbortOnDropJoinHandle<()>>,
663 pub between_bytes_timeout: std::time::Duration,
665}
666
667#[derive(Debug)]
669pub enum HostFutureIncomingResponse {
670 Pending(FutureIncomingResponseHandle),
672 Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
676 Consumed,
678}
679
680impl HostFutureIncomingResponse {
681 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
683 Self::Pending(handle)
684 }
685
686 pub fn ready(result: anyhow::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
688 Self::Ready(result)
689 }
690
691 pub fn is_ready(&self) -> bool {
693 matches!(self, Self::Ready(_))
694 }
695
696 pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
698 match self {
699 Self::Ready(res) => res,
700 Self::Pending(_) | Self::Consumed => {
701 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
702 }
703 }
704 }
705}
706
707#[async_trait::async_trait]
708impl Pollable for HostFutureIncomingResponse {
709 async fn ready(&mut self) {
710 if let Self::Pending(handle) = self {
711 *self = Self::Ready(handle.await);
712 }
713 }
714}