1use crate::io::TokioIo;
5use crate::{
6 bindings::http::types::{self, Method, Scheme},
7 body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
8 error::dns_error,
9 hyper_request_error,
10};
11use anyhow::bail;
12use bytes::Bytes;
13use http_body_util::BodyExt;
14use hyper::body::Body;
15use hyper::header::HeaderName;
16use std::any::Any;
17use std::time::Duration;
18use tokio::net::TcpStream;
19use tokio::time::timeout;
20use wasmtime::component::{Resource, ResourceTable};
21use wasmtime_wasi::p2::{IoImpl, IoView, Pollable};
22use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
23
/// Per-view state for the wasi-http implementation, handed out by
/// [`WasiHttpView::ctx`].
///
/// It currently carries no data; the private unit field only prevents code
/// outside this module from constructing it with a struct literal, so
/// [`WasiHttpCtx::new`] / [`Default::default`] are the only constructors.
#[derive(Debug, Default)]
pub struct WasiHttpCtx {
    _priv: (),
}

impl WasiHttpCtx {
    /// Creates a new, empty context (equivalent to `WasiHttpCtx::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
36
37pub trait WasiHttpView: IoView {
79 fn ctx(&mut self) -> &mut WasiHttpCtx;
81
82 fn new_incoming_request<B>(
84 &mut self,
85 scheme: Scheme,
86 req: hyper::Request<B>,
87 ) -> wasmtime::Result<Resource<HostIncomingRequest>>
88 where
89 B: Body<Data = Bytes, Error = hyper::Error> + Send + Sync + 'static,
90 Self: Sized,
91 {
92 let (parts, body) = req.into_parts();
93 let body = body.map_err(crate::hyper_response_error).boxed();
94 let body = HostIncomingBody::new(
95 body,
96 std::time::Duration::from_millis(600 * 1000),
98 );
99 let incoming_req = HostIncomingRequest::new(self, parts, scheme, Some(body))?;
100 Ok(self.table().push(incoming_req)?)
101 }
102
103 fn new_response_outparam(
105 &mut self,
106 result: tokio::sync::oneshot::Sender<
107 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
108 >,
109 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
110 let id = self.table().push(HostResponseOutparam { result })?;
111 Ok(id)
112 }
113
114 fn send_request(
116 &mut self,
117 request: hyper::Request<HyperOutgoingBody>,
118 config: OutgoingRequestConfig,
119 ) -> crate::HttpResult<HostFutureIncomingResponse> {
120 Ok(default_send_request(request, config))
121 }
122
123 fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool {
125 false
126 }
127
128 fn outgoing_body_buffer_chunks(&mut self) -> usize {
132 DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
133 }
134
135 fn outgoing_body_chunk_size(&mut self) -> usize {
138 DEFAULT_OUTGOING_BODY_CHUNK_SIZE
139 }
140}
141
/// Default returned by `WasiHttpView::outgoing_body_buffer_chunks`.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Default returned by `WasiHttpView::outgoing_body_chunk_size`: 1 MiB
/// (presumably measured in bytes).
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1 << 20;
146
147impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
148 fn ctx(&mut self) -> &mut WasiHttpCtx {
149 T::ctx(self)
150 }
151
152 fn new_response_outparam(
153 &mut self,
154 result: tokio::sync::oneshot::Sender<
155 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
156 >,
157 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
158 T::new_response_outparam(self, result)
159 }
160
161 fn send_request(
162 &mut self,
163 request: hyper::Request<HyperOutgoingBody>,
164 config: OutgoingRequestConfig,
165 ) -> crate::HttpResult<HostFutureIncomingResponse> {
166 T::send_request(self, request, config)
167 }
168
169 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
170 T::is_forbidden_header(self, name)
171 }
172
173 fn outgoing_body_buffer_chunks(&mut self) -> usize {
174 T::outgoing_body_buffer_chunks(self)
175 }
176
177 fn outgoing_body_chunk_size(&mut self) -> usize {
178 T::outgoing_body_chunk_size(self)
179 }
180}
181
182impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
183 fn ctx(&mut self) -> &mut WasiHttpCtx {
184 T::ctx(self)
185 }
186
187 fn new_response_outparam(
188 &mut self,
189 result: tokio::sync::oneshot::Sender<
190 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
191 >,
192 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
193 T::new_response_outparam(self, result)
194 }
195
196 fn send_request(
197 &mut self,
198 request: hyper::Request<HyperOutgoingBody>,
199 config: OutgoingRequestConfig,
200 ) -> crate::HttpResult<HostFutureIncomingResponse> {
201 T::send_request(self, request, config)
202 }
203
204 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
205 T::is_forbidden_header(self, name)
206 }
207
208 fn outgoing_body_buffer_chunks(&mut self) -> usize {
209 T::outgoing_body_buffer_chunks(self)
210 }
211
212 fn outgoing_body_chunk_size(&mut self) -> usize {
213 T::outgoing_body_chunk_size(self)
214 }
215}
216
/// Newtype wrapper around [`IoImpl<T>`]; its `IoView` / `WasiHttpView`
/// implementations forward to the innermost `T`.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `IoImpl<T>`.
#[repr(transparent)]
pub struct WasiHttpImpl<T>(pub IoImpl<T>);
231
232impl<T: IoView> IoView for WasiHttpImpl<T> {
233 fn table(&mut self) -> &mut ResourceTable {
234 T::table(&mut self.0 .0)
235 }
236}
237impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
238 fn ctx(&mut self) -> &mut WasiHttpCtx {
239 self.0 .0.ctx()
240 }
241
242 fn new_response_outparam(
243 &mut self,
244 result: tokio::sync::oneshot::Sender<
245 Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
246 >,
247 ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
248 self.0 .0.new_response_outparam(result)
249 }
250
251 fn send_request(
252 &mut self,
253 request: hyper::Request<HyperOutgoingBody>,
254 config: OutgoingRequestConfig,
255 ) -> crate::HttpResult<HostFutureIncomingResponse> {
256 self.0 .0.send_request(request, config)
257 }
258
259 fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
260 self.0 .0.is_forbidden_header(name)
261 }
262
263 fn outgoing_body_buffer_chunks(&mut self) -> usize {
264 self.0 .0.outgoing_body_buffer_chunks()
265 }
266
267 fn outgoing_body_chunk_size(&mut self) -> usize {
268 self.0 .0.outgoing_body_chunk_size()
269 }
270}
271
272pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
274 static FORBIDDEN_HEADERS: [HeaderName; 10] = [
275 hyper::header::CONNECTION,
276 HeaderName::from_static("keep-alive"),
277 hyper::header::PROXY_AUTHENTICATE,
278 hyper::header::PROXY_AUTHORIZATION,
279 HeaderName::from_static("proxy-connection"),
280 hyper::header::TE,
281 hyper::header::TRANSFER_ENCODING,
282 hyper::header::UPGRADE,
283 hyper::header::HOST,
284 HeaderName::from_static("http2-settings"),
285 ];
286
287 FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
288}
289
290pub(crate) fn remove_forbidden_headers(
292 view: &mut dyn WasiHttpView,
293 headers: &mut hyper::HeaderMap,
294) {
295 let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
296 if is_forbidden_header(view, name) {
297 Some(name.clone())
298 } else {
299 None
300 }
301 }));
302
303 for name in forbidden_keys {
304 headers.remove(name);
305 }
306}
307
/// Connection settings for sending one outgoing request; consumed by
/// `default_send_request_handler`.
///
/// All fields are `Copy`, so the whole config is cheaply copyable, and it is
/// `Debug` for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct OutgoingRequestConfig {
    /// Wrap the TCP connection in TLS; also selects default port 443
    /// (instead of 80) when the URI has no explicit port.
    pub use_tls: bool,
    /// Bound on establishing the connection (TCP connect and handshakes).
    pub connect_timeout: Duration,
    /// Bound on waiting for the response head after the request is sent.
    pub first_byte_timeout: Duration,
    /// Timeout forwarded to the incoming response for gaps between body
    /// chunks.
    pub between_bytes_timeout: Duration,
}
319
320pub fn default_send_request(
325 request: hyper::Request<HyperOutgoingBody>,
326 config: OutgoingRequestConfig,
327) -> HostFutureIncomingResponse {
328 let handle = wasmtime_wasi::runtime::spawn(async move {
329 Ok(default_send_request_handler(request, config).await)
330 });
331 HostFutureIncomingResponse::pending(handle)
332}
333
334pub async fn default_send_request_handler(
339 mut request: hyper::Request<HyperOutgoingBody>,
340 OutgoingRequestConfig {
341 use_tls,
342 connect_timeout,
343 first_byte_timeout,
344 between_bytes_timeout,
345 }: OutgoingRequestConfig,
346) -> Result<IncomingResponse, types::ErrorCode> {
347 let authority = if let Some(authority) = request.uri().authority() {
348 if authority.port().is_some() {
349 authority.to_string()
350 } else {
351 let port = if use_tls { 443 } else { 80 };
352 format!("{}:{port}", authority.to_string())
353 }
354 } else {
355 return Err(types::ErrorCode::HttpRequestUriInvalid);
356 };
357 let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
358 .await
359 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
360 .map_err(|e| match e.kind() {
361 std::io::ErrorKind::AddrNotAvailable => {
362 dns_error("address not available".to_string(), 0)
363 }
364
365 _ => {
366 if e.to_string()
367 .starts_with("failed to lookup address information")
368 {
369 dns_error("address not available".to_string(), 0)
370 } else {
371 types::ErrorCode::ConnectionRefused
372 }
373 }
374 })?;
375
376 let (mut sender, worker) = if use_tls {
377 use rustls::pki_types::ServerName;
378
379 let root_cert_store = rustls::RootCertStore {
381 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
382 };
383 let config = rustls::ClientConfig::builder()
384 .with_root_certificates(root_cert_store)
385 .with_no_client_auth();
386 let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
387 let mut parts = authority.split(":");
388 let host = parts.next().unwrap_or(&authority);
389 let domain = ServerName::try_from(host)
390 .map_err(|e| {
391 tracing::warn!("dns lookup error: {e:?}");
392 dns_error("invalid dns name".to_string(), 0)
393 })?
394 .to_owned();
395 let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
396 tracing::warn!("tls protocol error: {e:?}");
397 types::ErrorCode::TlsProtocolError
398 })?;
399 let stream = TokioIo::new(stream);
400
401 let (sender, conn) = timeout(
402 connect_timeout,
403 hyper::client::conn::http1::handshake(stream),
404 )
405 .await
406 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
407 .map_err(hyper_request_error)?;
408
409 let worker = wasmtime_wasi::runtime::spawn(async move {
410 match conn.await {
411 Ok(()) => {}
412 Err(e) => tracing::warn!("dropping error {e}"),
415 }
416 });
417
418 (sender, worker)
419 } else {
420 let tcp_stream = TokioIo::new(tcp_stream);
421 let (sender, conn) = timeout(
422 connect_timeout,
423 hyper::client::conn::http1::handshake(tcp_stream),
425 )
426 .await
427 .map_err(|_| types::ErrorCode::ConnectionTimeout)?
428 .map_err(hyper_request_error)?;
429
430 let worker = wasmtime_wasi::runtime::spawn(async move {
431 match conn.await {
432 Ok(()) => {}
433 Err(e) => tracing::warn!("dropping error {e}"),
435 }
436 });
437
438 (sender, worker)
439 };
440
441 *request.uri_mut() = http::Uri::builder()
445 .path_and_query(
446 request
447 .uri()
448 .path_and_query()
449 .map(|p| p.as_str())
450 .unwrap_or("/"),
451 )
452 .build()
453 .expect("comes from valid request");
454
455 let resp = timeout(first_byte_timeout, sender.send_request(request))
456 .await
457 .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
458 .map_err(hyper_request_error)?
459 .map(|body| body.map_err(hyper_request_error).boxed());
460
461 Ok(IncomingResponse {
462 resp,
463 worker: Some(worker),
464 between_bytes_timeout,
465 })
466}
467
468impl From<http::Method> for types::Method {
469 fn from(method: http::Method) -> Self {
470 if method == http::Method::GET {
471 types::Method::Get
472 } else if method == hyper::Method::HEAD {
473 types::Method::Head
474 } else if method == hyper::Method::POST {
475 types::Method::Post
476 } else if method == hyper::Method::PUT {
477 types::Method::Put
478 } else if method == hyper::Method::DELETE {
479 types::Method::Delete
480 } else if method == hyper::Method::CONNECT {
481 types::Method::Connect
482 } else if method == hyper::Method::OPTIONS {
483 types::Method::Options
484 } else if method == hyper::Method::TRACE {
485 types::Method::Trace
486 } else if method == hyper::Method::PATCH {
487 types::Method::Patch
488 } else {
489 types::Method::Other(method.to_string())
490 }
491 }
492}
493
494impl TryInto<http::Method> for types::Method {
495 type Error = http::method::InvalidMethod;
496
497 fn try_into(self) -> Result<http::Method, Self::Error> {
498 match self {
499 Method::Get => Ok(http::Method::GET),
500 Method::Head => Ok(http::Method::HEAD),
501 Method::Post => Ok(http::Method::POST),
502 Method::Put => Ok(http::Method::PUT),
503 Method::Delete => Ok(http::Method::DELETE),
504 Method::Connect => Ok(http::Method::CONNECT),
505 Method::Options => Ok(http::Method::OPTIONS),
506 Method::Trace => Ok(http::Method::TRACE),
507 Method::Patch => Ok(http::Method::PATCH),
508 Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
509 }
510 }
511}
512
513#[derive(Debug)]
515pub struct HostIncomingRequest {
516 pub(crate) parts: http::request::Parts,
517 pub(crate) scheme: Scheme,
518 pub(crate) authority: String,
519 pub body: Option<HostIncomingBody>,
521}
522
523impl HostIncomingRequest {
524 pub fn new(
526 view: &mut dyn WasiHttpView,
527 mut parts: http::request::Parts,
528 scheme: Scheme,
529 body: Option<HostIncomingBody>,
530 ) -> anyhow::Result<Self> {
531 let authority = match parts.uri.authority() {
532 Some(authority) => authority.to_string(),
533 None => match parts.headers.get(http::header::HOST) {
534 Some(host) => host.to_str()?.to_string(),
535 None => bail!("invalid HTTP request missing authority in URI and host header"),
536 },
537 };
538
539 remove_forbidden_headers(view, &mut parts.headers);
540 Ok(Self {
541 parts,
542 authority,
543 scheme,
544 body,
545 })
546 }
547}
548
/// Resource wrapping the one-shot channel on which the final response (or an
/// [`types::ErrorCode`]) is delivered to whoever holds the receiving end.
pub struct HostResponseOutparam {
    /// Sending half of the channel; consumed when a result is sent.
    pub result:
        tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}
555
/// An outgoing response under construction. It is converted into a
/// `hyper::Response` via `TryFrom`.
pub struct HostOutgoingResponse {
    /// Response status code.
    pub status: http::StatusCode,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body. `None` is converted into an empty body.
    pub body: Option<HyperOutgoingBody>,
}
565
566impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
567 type Error = http::Error;
568
569 fn try_from(
570 resp: HostOutgoingResponse,
571 ) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
572 use http_body_util::Empty;
573
574 let mut builder = hyper::Response::builder().status(resp.status);
575
576 *builder.headers_mut().unwrap() = resp.headers;
577
578 match resp.body {
579 Some(body) => builder.body(body),
580 None => builder.body(
581 Empty::<bytes::Bytes>::new()
582 .map_err(|_| unreachable!("Infallible error"))
583 .boxed(),
584 ),
585 }
586 }
587}
588
/// An outgoing request under construction.
#[derive(Debug)]
pub struct HostOutgoingRequest {
    /// Request method.
    pub method: Method,
    /// Request scheme, if set. NOTE(review): what `None` resolves to is
    /// decided elsewhere — confirm.
    pub scheme: Option<Scheme>,
    /// Request authority (`host[:port]`), if set.
    pub authority: Option<String>,
    /// Path plus optional query string, if set.
    pub path_with_query: Option<String>,
    /// Request headers.
    pub headers: FieldMap,
    /// Request body, if any (presumably `None` once taken).
    pub body: Option<HyperOutgoingBody>,
}
605
/// Optional per-request timeouts. Each is `None` (the `Default`) when unset.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
    /// Connect timeout, if one was set.
    pub connect_timeout: Option<Duration>,
    /// First-byte timeout, if one was set.
    pub first_byte_timeout: Option<Duration>,
    /// Between-bytes timeout, if one was set.
    pub between_bytes_timeout: Option<Duration>,
}
616
/// An incoming HTTP response as stored in the resource table.
#[derive(Debug)]
pub struct HostIncomingResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Response headers.
    pub headers: FieldMap,
    /// Response body, if any (presumably `None` once taken).
    pub body: Option<HostIncomingBody>,
}
627
/// A set of header fields, either stored inside another resource or owned
/// directly.
#[derive(Debug)]
pub enum HostFields {
    /// Fields that live inside a parent resource.
    Ref {
        /// Identifies the parent. NOTE(review): presumably the parent's
        /// resource-table index — confirm.
        parent: u32,

        /// Projects the parent (type-erased as `dyn Any`) to its
        /// [`FieldMap`].
        get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
    },
    /// Fields owned by this value.
    Owned {
        /// The owned header map.
        fields: FieldMap,
    },
}
649
/// Header storage used throughout this module.
pub type FieldMap = hyper::HeaderMap;

/// Handle to a spawned task producing the outcome of an outgoing request.
/// The outer `anyhow::Result` wraps the inner `ErrorCode` result.
/// `default_send_request` always returns `Ok` at the outer level.
pub type FutureIncomingResponseHandle =
    AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
656
/// A response received for an outgoing request.
#[derive(Debug)]
pub struct IncomingResponse {
    /// The hyper response, with its body boxed.
    pub resp: hyper::Response<HyperIncomingBody>,
    /// Handle to the task driving the underlying connection
    /// (`default_send_request_handler` always sets it). Judging by the type
    /// name, dropping it aborts that task.
    pub worker: Option<AbortOnDropJoinHandle<()>>,
    /// Timeout to apply between body chunks.
    pub between_bytes_timeout: std::time::Duration,
}
667
/// State of a future incoming response.
#[derive(Debug)]
pub enum HostFutureIncomingResponse {
    /// The request task is still running.
    Pending(FutureIncomingResponseHandle),
    /// The task finished and its result is available.
    Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
    /// Presumably set once the result has been taken out; `unwrap_ready`
    /// panics in this state.
    Consumed,
}
680
681impl HostFutureIncomingResponse {
682 pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
684 Self::Pending(handle)
685 }
686
687 pub fn ready(result: anyhow::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
689 Self::Ready(result)
690 }
691
692 pub fn is_ready(&self) -> bool {
694 matches!(self, Self::Ready(_))
695 }
696
697 pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
699 match self {
700 Self::Ready(res) => res,
701 Self::Pending(_) | Self::Consumed => {
702 panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
703 }
704 }
705 }
706}
707
708#[async_trait::async_trait]
709impl Pollable for HostFutureIncomingResponse {
710 async fn ready(&mut self) {
711 if let Self::Pending(handle) = self {
712 *self = Self::Ready(handle.await);
713 }
714 }
715}