wasmtime_wasi_http/p3/host/
handler.rs1use crate::get_content_length;
2use crate::p3::bindings::http::handler::{Host, HostWithStore};
3use crate::p3::bindings::http::types::{ErrorCode, Request, Response};
4use crate::p3::body::{Body, GuestBody};
5use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView};
6use anyhow::Context as _;
7use bytes::Bytes;
8use core::pin::Pin;
9use core::task::{Context, Poll, Waker, ready};
10use http::header::HOST;
11use http::{HeaderValue, Uri};
12use http_body_util::BodyExt as _;
13use std::sync::Arc;
14use tokio::sync::oneshot;
15use tokio::task::{self, JoinHandle};
16use tracing::debug;
17use wasmtime::component::{Accessor, Resource};
18
19struct AbortOnDropJoinHandle(JoinHandle<()>);
22
23impl Drop for AbortOnDropJoinHandle {
24 fn drop(&mut self) {
25 self.0.abort();
26 }
27}
28
29struct BodyWithState<T, U> {
31 body: T,
32 _state: U,
33}
34
35impl<T, U> http_body::Body for BodyWithState<T, U>
36where
37 T: http_body::Body + Unpin,
38 U: Unpin,
39{
40 type Data = T::Data;
41 type Error = T::Error;
42
43 #[inline]
44 fn poll_frame(
45 self: Pin<&mut Self>,
46 cx: &mut Context<'_>,
47 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
48 Pin::new(&mut self.get_mut().body).poll_frame(cx)
49 }
50
51 #[inline]
52 fn is_end_stream(&self) -> bool {
53 self.body.is_end_stream()
54 }
55
56 #[inline]
57 fn size_hint(&self) -> http_body::SizeHint {
58 self.body.size_hint()
59 }
60}
61
62struct BodyWithContentLength<T, E> {
64 body: T,
65 error_tx: Option<oneshot::Sender<E>>,
66 make_error: fn(Option<u64>) -> E,
67 limit: u64,
69 sent: u64,
71}
72
73impl<T, E> BodyWithContentLength<T, E> {
74 fn send_error<V>(&mut self, sent: Option<u64>) -> Poll<Option<Result<V, E>>> {
77 if let Some(error_tx) = self.error_tx.take() {
78 _ = error_tx.send((self.make_error)(sent));
79 }
80 Poll::Ready(Some(Err((self.make_error)(sent))))
81 }
82}
83
84impl<T, E> http_body::Body for BodyWithContentLength<T, E>
85where
86 T: http_body::Body<Data = Bytes, Error = E> + Unpin,
87{
88 type Data = T::Data;
89 type Error = T::Error;
90
91 #[inline]
92 fn poll_frame(
93 mut self: Pin<&mut Self>,
94 cx: &mut Context<'_>,
95 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
96 match ready!(Pin::new(&mut self.as_mut().body).poll_frame(cx)) {
97 Some(Ok(frame)) => {
98 let Some(data) = frame.data_ref() else {
99 return Poll::Ready(Some(Ok(frame)));
100 };
101 let Ok(sent) = data.len().try_into() else {
102 return self.send_error(None);
103 };
104 let Some(sent) = self.sent.checked_add(sent) else {
105 return self.send_error(None);
106 };
107 if sent > self.limit {
108 return self.send_error(Some(sent));
109 }
110 self.sent = sent;
111 Poll::Ready(Some(Ok(frame)))
112 }
113 Some(Err(err)) => Poll::Ready(Some(Err(err))),
114 None if self.limit != self.sent => {
115 let sent = self.sent;
117 self.send_error(Some(sent))
118 }
119 None => Poll::Ready(None),
120 }
121 }
122
123 #[inline]
124 fn is_end_stream(&self) -> bool {
125 self.body.is_end_stream()
126 }
127
128 #[inline]
129 fn size_hint(&self) -> http_body::SizeHint {
130 let n = self.limit.saturating_sub(self.sent);
131 let mut hint = self.body.size_hint();
132 if hint.lower() >= n {
133 hint.set_exact(n)
134 } else if let Some(max) = hint.upper() {
135 hint.set_upper(n.min(max))
136 } else {
137 hint.set_upper(n)
138 }
139 hint
140 }
141}
142
143trait BodyExt {
144 fn with_state<T>(self, state: T) -> BodyWithState<Self, T>
145 where
146 Self: Sized,
147 {
148 BodyWithState {
149 body: self,
150 _state: state,
151 }
152 }
153
154 fn with_content_length<E>(
155 self,
156 limit: u64,
157 error_tx: oneshot::Sender<E>,
158 make_error: fn(Option<u64>) -> E,
159 ) -> BodyWithContentLength<Self, E>
160 where
161 Self: Sized,
162 {
163 BodyWithContentLength {
164 body: self,
165 error_tx: Some(error_tx),
166 make_error,
167 limit,
168 sent: 0,
169 }
170 }
171}
172
173impl<T> BodyExt for T {}
174
175async fn io_task_result(
176 rx: oneshot::Receiver<(
177 Arc<AbortOnDropJoinHandle>,
178 oneshot::Receiver<Result<(), ErrorCode>>,
179 )>,
180) -> Result<(), ErrorCode> {
181 let Ok((_io, io_result_rx)) = rx.await else {
182 return Ok(());
183 };
184 io_result_rx.await.unwrap_or(Ok(()))
185}
186
impl HostWithStore for WasiHttp {
    /// Sends the guest's outgoing request and returns the response as a new
    /// table resource.
    ///
    /// The request resource is removed from the table, converted into an
    /// `http::Request`, and handed to the context's `send_request`. The I/O
    /// future returned alongside the response is polled once. If it has not
    /// finished, it is spawned as a tokio task. The task's handle is shared
    /// (via `Arc`) with the request body's state, the request body's result
    /// future, and the response body.
    ///
    /// # Errors
    ///
    /// Traps if the request cannot be removed from the table, or the response
    /// cannot be pushed to it. Returns an `ErrorCode` in these cases:
    /// - the content length cannot be determined from the headers
    ///   (`InternalError`);
    /// - the authority is not a valid `Host` header value (`InternalError`);
    /// - the scheme is missing with no default, or is unsupported
    ///   (`HttpProtocolError`);
    /// - the URI cannot be built (`HttpRequestUriInvalid`);
    /// - `send_request`, or an immediate result of its I/O future, reports
    ///   an error.
    async fn handle<T>(
        store: &Accessor<T, Self>,
        req: Resource<Request>,
    ) -> HttpResult<Resource<Response>> {
        // Carries the spawned I/O task handle to the request body, which holds
        // the receiver as its state (presumably so the handle, and therefore
        // the task, stays alive as long as the request body does).
        let (io_task_tx, io_task_rx) = oneshot::channel();

        // Carries the I/O task handle plus a receiver for the task's result to
        // `io_task_result`, which reports it as the request body's result.
        let (io_result_tx, io_result_rx) = oneshot::channel();

        // `res_result_tx` becomes the response body's `result_tx`; the
        // receiving end is awaited by the future passed to `send_request`.
        let (res_result_tx, res_result_rx) = oneshot::channel();

        let getter = store.getter();
        let fut = store.with(|mut store| {
            let WasiHttpCtxView { table, .. } = store.get();
            let Request {
                method,
                scheme,
                authority,
                path_with_query,
                headers,
                options,
                body,
            } = table
                .delete(req)
                .context("failed to delete request from table")
                .map_err(HttpError::trap)?;
            let content_length = match get_content_length(&headers) {
                Ok(content_length) => content_length,
                Err(err) => {
                    // Release the body before bailing out (presumably frees
                    // any store-side resources it holds).
                    body.drop(&mut store);
                    return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into());
                }
            };
            let mut headers = Arc::unwrap_or_clone(headers);
            let body = match body {
                // Guest-provided body: `GuestBody` receives the content length
                // and `HttpRequestBodySize`, presumably to enforce the length
                // itself.
                Body::Guest {
                    contents_rx,
                    trailers_rx,
                    result_tx,
                } => GuestBody::new(
                    &mut store,
                    contents_rx,
                    trailers_rx,
                    result_tx,
                    io_task_result(io_result_rx),
                    content_length,
                    ErrorCode::HttpRequestBodySize,
                    getter,
                )
                .with_state(io_task_rx)
                .boxed(),
                Body::Host { body, result_tx } => {
                    if let Some(limit) = content_length {
                        let (http_result_tx, http_result_rx) = oneshot::channel();
                        // The body's result first reports a content-length
                        // violation detected by the wrapper, if any; otherwise
                        // it reports the I/O task's result.
                        _ = result_tx.send(Box::new(async move {
                            if let Ok(err) = http_result_rx.await {
                                return Err(err);
                            };
                            io_task_result(io_result_rx).await
                        }));
                        body.with_content_length(
                            limit,
                            http_result_tx,
                            ErrorCode::HttpRequestBodySize,
                        )
                        .with_state(io_task_rx)
                        .boxed()
                    } else {
                        _ = result_tx.send(Box::new(io_task_result(io_result_rx)));
                        body.with_state(io_task_rx).boxed()
                    }
                }
            };

            let WasiHttpCtxView { ctx, .. } = store.get();
            if ctx.set_host_header() {
                // Use the request authority as `Host`, or an empty value when
                // there is none.
                let host = if let Some(authority) = authority.as_ref() {
                    HeaderValue::try_from(authority.as_str())
                        .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?
                } else {
                    HeaderValue::from_static("")
                };
                headers.insert(HOST, host);
            }
            // Fall back to the context's default scheme; reject schemes the
            // context does not support.
            let scheme = match scheme {
                None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?,
                Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme,
                Some(..) => return Err(ErrorCode::HttpProtocolError.into()),
            };
            let mut uri = Uri::builder().scheme(scheme);
            if let Some(authority) = authority {
                uri = uri.authority(authority)
            };
            if let Some(path_with_query) = path_with_query {
                uri = uri.path_and_query(path_with_query)
            };
            let uri = uri.build().map_err(|err| {
                debug!(?err, "failed to build request URI");
                ErrorCode::HttpRequestUriInvalid
            })?;
            // A freshly created builder has no error yet, so
            // `headers_mut()` is `Some`.
            let mut req = http::Request::builder();
            *req.headers_mut().unwrap() = headers;
            let req = req
                .method(method)
                .uri(uri)
                .body(body)
                .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
            HttpResult::Ok(store.get().ctx.send_request(
                req,
                options.as_deref().copied(),
                // Resolves to the response body's reported result, or
                // `Ok(())` if none is ever sent.
                Box::new(async {
                    let Ok(fut) = res_result_rx.await else {
                        return Ok(());
                    };
                    Box::into_pin(fut).await
                }),
            ))
        })?;
        let (res, io) = Box::into_pin(fut).await?;
        let (
            http::response::Parts {
                status, headers, ..
            },
            body,
        ) = res.into_parts();

        // Poll the I/O future once with a no-op waker. If it has already
        // finished (an error is propagated by `?`), nothing is spawned.
        // Otherwise it is moved into a tokio task that keeps driving it.
        let mut io = Box::into_pin(io);
        let body = match io.as_mut().poll(&mut Context::from_waker(Waker::noop()))? {
            Poll::Ready(()) => body,
            Poll::Pending => {
                let (tx, rx) = oneshot::channel();
                let io = task::spawn(async move {
                    let res = io.await;
                    debug!(?res, "`send_request` I/O future finished");
                    _ = tx.send(res);
                });
                // The task is aborted once the last clone of this `Arc` is
                // dropped (see `AbortOnDropJoinHandle`).
                let io = Arc::new(AbortOnDropJoinHandle(io));
                _ = io_result_tx.send((Arc::clone(&io), rx));
                _ = io_task_tx.send(Arc::clone(&io));
                body.with_state(io).boxed()
            }
        };
        let res = Response {
            status,
            headers: Arc::new(headers),
            body: Body::Host {
                body,
                result_tx: res_result_tx,
            },
        };
        store.with(|mut store| {
            store
                .get()
                .table
                .push(res)
                .context("failed to push response to table")
                .map_err(HttpError::trap)
        })
    }
}
355
/// `Host` for the handler interface, with no methods overridden here; request
/// handling is implemented by the `HostWithStore` impl for `WasiHttp`.
impl Host for WasiHttpCtxView<'_> {}