1use crate::io::TokioIo;
5use crate::{
6 bindings::http::types::{self, Method, Scheme},
7 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
8 error::dns_error,
9 hyper_request_error,
10};
11use anyhow::bail;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Body;
15use hyper::header::HeaderName;
16use std::any::Any;
17use std::time::Duration;
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20use wasmtime::component::{Resource, ResourceTable};
21use wasmtime_wasi::p2::{IoImpl, IoView, Pollable};
22use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
23
/// Capture the state necessary for use in the `wasi:http` API implementation.
///
/// The context currently carries no data; the private field only prevents
/// construction outside of [`WasiHttpCtx::new`] / [`Default::default`].
#[derive(Debug)]
pub struct WasiHttpCtx {
    _priv: (),
}

impl WasiHttpCtx {
    /// Creates a new, empty context.
    pub fn new() -> Self {
        Self { _priv: () }
    }
}

// A public zero-argument `new` should be mirrored by `Default` so the type
// works with `..Default::default()`, `Option::unwrap_or_default`, etc.
impl Default for WasiHttpCtx {
    fn default() -> Self {
        Self::new()
    }
}
36
37pub trait WasiHttpView: IoView {
79 fn ctx(&mut self) -> &mut WasiHttpCtx;
81
82 fn new_incoming_request<B>(
84 &mut self,
85 scheme: Scheme,
86 req: hyper::Request<B>,
87 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
88 where
89 B: Body<Data = Bytes, Error = hyper::Error> + Send + Sync + 'static,
90 Self: Sized,
91 {
92 let (parts, body) = req.into_parts();
93 let body = body.map_err(crate::hyper_response_error).boxed();
94 let body = HostIncomingBody::new(
95 body,
96 std::time::Duration::from_millis(600 * 1000),
98 );
99 let incoming_req = HostIncomingRequest::new(self, parts, scheme, Some(body))?;
100 Ok(self.table().push(incoming_req)?)
101 }
102
103 fn new_response_outparam(
105 &mut self,
106 result: tokio::sync::oneshot::Sender<
107 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
108 >,
109 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
110 let id = self.table().push(HostResponseOutparam { result })?;
111 Ok(id)
112 }
113
114 fn send_request(
116 &mut self,
117 request: hyper::Request<HyperOutgoingBody>,
118 config: OutgoingRequestConfig,
119 ) -> crate::HttpResult<HostFutureIncomingResponse> {
120 Ok(default_send_request(request, config))
121 }
122
123 fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool {
125 false
126 }
127
128 fn outgoing_body_buffer_chunks(&mut self) -> usize {
132 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
133 }
134
135 fn outgoing_body_chunk_size(&mut self) -> usize {
138 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
139 }
140}
141
/// Default value returned by `WasiHttpView::outgoing_body_buffer_chunks`.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;

/// Default value returned by `WasiHttpView::outgoing_body_chunk_size`
/// (1 MiB, i.e. 2^20).
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1 << 20;
146
147impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
148 fn ctx(&mut self) -> &mut WasiHttpCtx {
149 T::ctx(self)
150 }
151
152 fn new_response_outparam(
153 &mut self,
154 result: tokio::sync::oneshot::Sender<
155 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
156 >,
157 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
158 T::new_response_outparam(self, result)
159 }
160
161 fn send_request(
162 &mut self,
163 request: hyper::Request<HyperOutgoingBody>,
164 config: OutgoingRequestConfig,
165 ) -> crate::HttpResult<HostFutureIncomingResponse> {
166 T::send_request(self, request, config)
167 }
168
169 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
170 T::is_forbidden_header(self, name)
171 }
172
173 fn outgoing_body_buffer_chunks(&mut self) -> usize {
174 T::outgoing_body_buffer_chunks(self)
175 }
176
177 fn outgoing_body_chunk_size(&mut self) -> usize {
178 T::outgoing_body_chunk_size(self)
179 }
180}
181
182impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
183 fn ctx(&mut self) -> &mut WasiHttpCtx {
184 T::ctx(self)
185 }
186
187 fn new_response_outparam(
188 &mut self,
189 result: tokio::sync::oneshot::Sender<
190 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
191 >,
192 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
193 T::new_response_outparam(self, result)
194 }
195
196 fn send_request(
197 &mut self,
198 request: hyper::Request<HyperOutgoingBody>,
199 config: OutgoingRequestConfig,
200 ) -> crate::HttpResult<HostFutureIncomingResponse> {
201 T::send_request(self, request, config)
202 }
203
204 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
205 T::is_forbidden_header(self, name)
206 }
207
208 fn outgoing_body_buffer_chunks(&mut self) -> usize {
209 T::outgoing_body_buffer_chunks(self)
210 }
211
212 fn outgoing_body_chunk_size(&mut self) -> usize {
213 T::outgoing_body_chunk_size(self)
214 }
215}
216
217#[repr(transparent)]
230pub struct WasiHttpImpl<T>(pub IoImpl<T>);
231
232impl<T: IoView> IoView for WasiHttpImpl<T> {
233 fn table(&mut self) -> &mut ResourceTable {
234 T::table(&mut self.0.0)
235 }
236}
237impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
238 fn ctx(&mut self) -> &mut WasiHttpCtx {
239 self.0.0.ctx()
240 }
241
242 fn new_response_outparam(
243 &mut self,
244 result: tokio::sync::oneshot::Sender<
245 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
246 >,
247 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
248 self.0.0.new_response_outparam(result)
249 }
250
251 fn send_request(
252 &mut self,
253 request: hyper::Request<HyperOutgoingBody>,
254 config: OutgoingRequestConfig,
255 ) -> crate::HttpResult<HostFutureIncomingResponse> {
256 self.0.0.send_request(request, config)
257 }
258
259 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
260 self.0.0.is_forbidden_header(name)
261 }
262
263 fn outgoing_body_buffer_chunks(&mut self) -> usize {
264 self.0.0.outgoing_body_buffer_chunks()
265 }
266
267 fn outgoing_body_chunk_size(&mut self) -> usize {
268 self.0.0.outgoing_body_chunk_size()
269 }
270}
271
272pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
274 static FORBIDDEN_HEADERS: [HeaderName; 9] = [
275 hyper::header::CONNECTION,
276 HeaderName::from_static("keep-alive"),
277 hyper::header::PROXY_AUTHENTICATE,
278 hyper::header::PROXY_AUTHORIZATION,
279 HeaderName::from_static("proxy-connection"),
280 hyper::header::TRANSFER_ENCODING,
281 hyper::header::UPGRADE,
282 hyper::header::HOST,
283 HeaderName::from_static("http2-settings"),
284 ];
285
286 FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
287}
288
289pub(crate) fn remove_forbidden_headers(
291 view: &mut dyn WasiHttpView,
292 headers: &mut hyper::HeaderMap,
293) {
294 let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
295 if is_forbidden_header(view, name) {
296 Some(name.clone())
297 } else {
298 None
299 }
300 }));
301
302 for name in forbidden_keys {
303 headers.remove(name);
304 }
305}
306
/// Per-request settings consumed by [`default_send_request_handler`].
///
/// This is plain `bool`/`Duration` data, so it is cheap to copy and is
/// `Debug` for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct OutgoingRequestConfig {
    /// Whether to wrap the TCP connection in TLS (and default to port 443
    /// instead of 80 when the URI has no port).
    pub use_tls: bool,
    /// Time budget for establishing the connection (TCP connect and the
    /// HTTP handshake in `default_send_request_handler`).
    pub connect_timeout: Duration,
    /// Time budget for receiving the response head after the request is sent.
    pub first_byte_timeout: Duration,
    /// Passed through to [`IncomingResponse`]; presumably bounds the gap
    /// between successive body chunks — confirm where it is consumed.
    pub between_bytes_timeout: Duration,
}
318
319pub fn default_send_request(
324 request: hyper::Request<HyperOutgoingBody>,
325 config: OutgoingRequestConfig,
326) -> HostFutureIncomingResponse {
327 let handle = wasmtime_wasi::runtime::spawn(async move {
328 Ok(default_send_request_handler(request, config).await)
329 });
330 HostFutureIncomingResponse::pending(handle)
331}
332
333pub async fn default_send_request_handler(
338 mut request: hyper::Request<HyperOutgoingBody>,
339 OutgoingRequestConfig {
340 use_tls,
341 connect_timeout,
342 first_byte_timeout,
343 between_bytes_timeout,
344 }: OutgoingRequestConfig,
345) -> Result<IncomingResponse, types::ErrorCode> {
346 let authority = if let Some(authority) = request.uri().authority() {
347 if authority.port().is_some() {
348 authority.to_string()
349 } else {
350 let port = if use_tls { 443 } else { 80 };
351 format!("{}:{port}", authority.to_string())
352 }
353 } else {
354 return Err(types::ErrorCode::HttpRequestUriInvalid);
355 };
356 let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
357 .await
358 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
359 .map_err(|e| match e.kind() {
360 std::io::ErrorKind::AddrNotAvailable => {
361 dns_error("address not available".to_string(), 0)
362 }
363
364 _ => {
365 if e.to_string()
366 .starts_with("failed to lookup address information")
367 {
368 dns_error("address not available".to_string(), 0)
369 } else {
370 types::ErrorCode::ConnectionRefused
371 }
372 }
373 })?;
374
375 let (mut sender, worker) = if use_tls {
376 use rustls::pki_types::ServerName;
377
378 let root_cert_store = rustls::RootCertStore {
380 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
381 };
382 let config = rustls::ClientConfig::builder()
383 .with_root_certificates(root_cert_store)
384 .with_no_client_auth();
385 let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
386 let mut parts = authority.split(":");
387 let host = parts.next().unwrap_or(&authority);
388 let domain = ServerName::try_from(host)
389 .map_err(|e| {
390 tracing::warn!("dns lookup error: {e:?}");
391 dns_error("invalid dns name".to_string(), 0)
392 })?
393 .to_owned();
394 let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
395 tracing::warn!("tls protocol error: {e:?}");
396 types::ErrorCode::TlsProtocolError
397 })?;
398 let stream = TokioIo::new(stream);
399
400 let (sender, conn) = timeout(
401 connect_timeout,
402 hyper::client::conn::http1::handshake(stream),
403 )
404 .await
405 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
406 .map_err(hyper_request_error)?;
407
408 let worker = wasmtime_wasi::runtime::spawn(async move {
409 match conn.await {
410 Ok(()) => {}
411 Err(e) => tracing::warn!("dropping error {e}"),
414 }
415 });
416
417 (sender, worker)
418 } else {
419 let tcp_stream = TokioIo::new(tcp_stream);
420 let (sender, conn) = timeout(
421 connect_timeout,
422 hyper::client::conn::http1::handshake(tcp_stream),
424 )
425 .await
426 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
427 .map_err(hyper_request_error)?;
428
429 let worker = wasmtime_wasi::runtime::spawn(async move {
430 match conn.await {
431 Ok(()) => {}
432 Err(e) => tracing::warn!("dropping error {e}"),
434 }
435 });
436
437 (sender, worker)
438 };
439
440 *request.uri_mut() = http::Uri::builder()
444 .path_and_query(
445 request
446 .uri()
447 .path_and_query()
448 .map(|p| p.as_str())
449 .unwrap_or("/"),
450 )
451 .build()
452 .expect("comes from valid request");
453
454 let resp = timeout(first_byte_timeout, sender.send_request(request))
455 .await
456 .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
457 .map_err(hyper_request_error)?
458 .map(|body| body.map_err(hyper_request_error).boxed());
459
460 Ok(IncomingResponse {
461 resp,
462 worker: Some(worker),
463 between_bytes_timeout,
464 })
465}
466
467impl From<http::Method> for types::Method {
468 fn from(method: http::Method) -> Self {
469 if method == http::Method::GET {
470 types::Method::Get
471 } else if method == hyper::Method::HEAD {
472 types::Method::Head
473 } else if method == hyper::Method::POST {
474 types::Method::Post
475 } else if method == hyper::Method::PUT {
476 types::Method::Put
477 } else if method == hyper::Method::DELETE {
478 types::Method::Delete
479 } else if method == hyper::Method::CONNECT {
480 types::Method::Connect
481 } else if method == hyper::Method::OPTIONS {
482 types::Method::Options
483 } else if method == hyper::Method::TRACE {
484 types::Method::Trace
485 } else if method == hyper::Method::PATCH {
486 types::Method::Patch
487 } else {
488 types::Method::Other(method.to_string())
489 }
490 }
491}
492
493impl TryInto<http::Method> for types::Method {
494 type Error = http::method::InvalidMethod;
495
496 fn try_into(self) -> Result<http::Method, Self::Error> {
497 match self {
498 Method::Get => Ok(http::Method::GET),
499 Method::Head => Ok(http::Method::HEAD),
500 Method::Post => Ok(http::Method::POST),
501 Method::Put => Ok(http::Method::PUT),
502 Method::Delete => Ok(http::Method::DELETE),
503 Method::Connect => Ok(http::Method::CONNECT),
504 Method::Options => Ok(http::Method::OPTIONS),
505 Method::Trace => Ok(http::Method::TRACE),
506 Method::Patch => Ok(http::Method::PATCH),
507 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
508 }
509 }
510}
511
512#[derive(Debug)]
514pub struct HostIncomingRequest {
515 pub(crate) parts: http::request::Parts,
516 pub(crate) scheme: Scheme,
517 pub(crate) authority: String,
518 pub body: Option<HostIncomingBody>,
520}
521
522impl HostIncomingRequest {
523 pub fn new(
525 view: &mut dyn WasiHttpView,
526 mut parts: http::request::Parts,
527 scheme: Scheme,
528 body: Option<HostIncomingBody>,
529 ) -> anyhow::Result<Self> {
530 let authority = match parts.uri.authority() {
531 Some(authority) => authority.to_string(),
532 None => match parts.headers.get(http::header::HOST) {
533 Some(host) => host.to_str()?.to_string(),
534 None => bail!("invalid HTTP request missing authority in URI and host header"),
535 },
536 };
537
538 remove_forbidden_headers(view, &mut parts.headers);
539 Ok(Self {
540 parts,
541 authority,
542 scheme,
543 body,
544 })
545 }
546}
547
/// Resource backing a response out-parameter: the sending half of a one-shot
/// channel on which the final response (or an error code) is delivered.
pub struct HostResponseOutparam {
    /// Channel used to deliver the response or error; a `oneshot::Sender`, so
    /// at most one value can be sent.
    pub result:
        tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}
554
555pub struct HostOutgoingResponse {
557 pub status: http::StatusCode,
559 pub headers: FieldMap,
561 pub body: Option<HyperOutgoingBody>,
563}
564
565impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
566 type Error = http::Error;
567
568 fn try_from(
569 resp: HostOutgoingResponse,
570 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
571 use http_body_util::Empty;
572
573 let mut builder = hyper::Response::builder().status(resp.status);
574
575 *builder.headers_mut().unwrap() = resp.headers;
576
577 match resp.body {
578 Some(body) => builder.body(body),
579 None => builder.body(
580 Empty::<bytes::Bytes>::new()
581 .map_err(|_| unreachable!("Infallible error"))
582 .boxed(),
583 ),
584 }
585 }
586}
587
/// An outgoing HTTP request being assembled before it is sent.
#[derive(Debug)]
pub struct HostOutgoingRequest {
    /// Request method.
    pub method: Method,
    /// Request scheme, if one has been set.
    pub scheme: Option<Scheme>,
    /// Request authority, if one has been set.
    pub authority: Option<String>,
    /// Path plus query string, if one has been set.
    pub path_with_query: Option<String>,
    /// Request headers.
    pub headers: FieldMap,
    /// Request body, if one has been attached (presumably `None` until the
    /// body is created or after it is taken — confirm at the use sites).
    pub body: Option<HyperOutgoingBody>,
}

/// Optional per-request timeouts; `None` presumably means "use the default"
/// — confirm where these are turned into an `OutgoingRequestConfig`.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Optional connect timeout.
    pub connect_timeout: Option<std::time::Duration>,
    /// Optional first-byte timeout.
    pub first_byte_timeout: Option<std::time::Duration>,
    /// Optional between-bytes timeout.
    pub between_bytes_timeout: Option<std::time::Duration>,
}

/// A received HTTP response as held in the resource table.
#[derive(Debug)]
pub struct HostIncomingResponse {
    /// Numeric status code.
    pub status: u16,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body, if it has not been taken.
    pub body: Option<HostIncomingBody>,
}
626
/// Header fields, either owned directly or reached through another resource.
#[derive(Debug)]
pub enum HostFields {
    /// Fields stored inside another resource.
    Ref {
        /// Identifies the parent resource (presumably its resource-table
        /// index — confirm at the use sites).
        parent: u32,

        /// Projects the parent resource, viewed as `dyn Any`, to its
        /// `FieldMap`.
        get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
    },
    /// Fields owned by this resource.
    Owned {
        /// The owned header map.
        fields: FieldMap,
    },
}

/// Header storage used throughout this module: a plain `hyper::HeaderMap`.
pub type FieldMap = hyper::HeaderMap;

/// Join handle for a spawned task that produces an outgoing request's result.
/// The outer `anyhow::Result` is for host failures and the inner `Result` for
/// HTTP-level error codes.
pub type FutureIncomingResponseHandle =
    AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
655
/// A response received for an outgoing request.
#[derive(Debug)]
pub struct IncomingResponse {
    /// The `hyper` response, including its streaming body.
    pub resp: hyper::Response<HyperIncomingBody>,
    /// Task driving the underlying connection, when there is one
    /// (`default_send_request_handler` always sets it).
    pub worker: Option<AbortOnDropJoinHandle<()>>,
    /// Timeout carried over from `OutgoingRequestConfig`; presumably applied
    /// between body chunks — confirm where it is consumed.
    pub between_bytes_timeout: std::time::Duration,
}

/// State of an in-flight outgoing request.
#[derive(Debug)]
pub enum HostFutureIncomingResponse {
    /// The request task is still running.
    Pending(FutureIncomingResponseHandle),
    /// The request task finished with this result.
    Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
    /// The result has already been taken (presumably by the guest — confirm
    /// where this variant is set).
    Consumed,
}
679
680impl HostFutureIncomingResponse {
681 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
683 Self::Pending(handle)
684 }
685
686 pub fn ready(result: anyhow::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
688 Self::Ready(result)
689 }
690
691 pub fn is_ready(&self) -> bool {
693 matches!(self, Self::Ready(_))
694 }
695
696 pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
698 match self {
699 Self::Ready(res) => res,
700 Self::Pending(_) | Self::Consumed => {
701 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
702 }
703 }
704 }
705}
706
707#[async_trait::async_trait]
708impl Pollable for HostFutureIncomingResponse {
709 async fn ready(&mut self) {
710 if let Self::Pending(handle) = self {
711 *self = Self::Ready(handle.await);
712 }
713 }
714}