wasmtime_wasi_http/p3/host/
handler.rs

1use crate::get_content_length;
2use crate::p3::bindings::http::handler::{Host, HostWithStore};
3use crate::p3::bindings::http::types::{ErrorCode, Request, Response};
4use crate::p3::body::{Body, GuestBody};
5use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView};
6use anyhow::Context as _;
7use bytes::Bytes;
8use core::pin::Pin;
9use core::task::{Context, Poll, Waker, ready};
10use http::header::HOST;
11use http::{HeaderValue, Uri};
12use http_body_util::BodyExt as _;
13use std::sync::Arc;
14use tokio::sync::oneshot;
15use tokio::task::{self, JoinHandle};
16use tracing::debug;
17use wasmtime::component::{Accessor, Resource};
18
19/// A wrapper around [`JoinHandle`], which will [`JoinHandle::abort`] the task
20/// when dropped
21struct AbortOnDropJoinHandle(JoinHandle<()>);
22
23impl Drop for AbortOnDropJoinHandle {
24    fn drop(&mut self) {
25        self.0.abort();
26    }
27}
28
29/// A wrapper around [http_body::Body], which allows attaching arbitrary state to it
30struct BodyWithState<T, U> {
31    body: T,
32    _state: U,
33}
34
35impl<T, U> http_body::Body for BodyWithState<T, U>
36where
37    T: http_body::Body + Unpin,
38    U: Unpin,
39{
40    type Data = T::Data;
41    type Error = T::Error;
42
43    #[inline]
44    fn poll_frame(
45        self: Pin<&mut Self>,
46        cx: &mut Context<'_>,
47    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
48        Pin::new(&mut self.get_mut().body).poll_frame(cx)
49    }
50
51    #[inline]
52    fn is_end_stream(&self) -> bool {
53        self.body.is_end_stream()
54    }
55
56    #[inline]
57    fn size_hint(&self) -> http_body::SizeHint {
58        self.body.size_hint()
59    }
60}
61
62/// A wrapper around [http_body::Body], which validates `Content-Length`
63struct BodyWithContentLength<T, E> {
64    body: T,
65    error_tx: Option<oneshot::Sender<E>>,
66    make_error: fn(Option<u64>) -> E,
67    /// Limit of bytes to be sent
68    limit: u64,
69    /// Number of bytes sent
70    sent: u64,
71}
72
73impl<T, E> BodyWithContentLength<T, E> {
74    /// Sends the error constructed by [Self::make_error] on [Self::error_tx].
75    /// Does nothing if an error has already been sent on [Self::error_tx].
76    fn send_error<V>(&mut self, sent: Option<u64>) -> Poll<Option<Result<V, E>>> {
77        if let Some(error_tx) = self.error_tx.take() {
78            _ = error_tx.send((self.make_error)(sent));
79        }
80        Poll::Ready(Some(Err((self.make_error)(sent))))
81    }
82}
83
84impl<T, E> http_body::Body for BodyWithContentLength<T, E>
85where
86    T: http_body::Body<Data = Bytes, Error = E> + Unpin,
87{
88    type Data = T::Data;
89    type Error = T::Error;
90
91    #[inline]
92    fn poll_frame(
93        mut self: Pin<&mut Self>,
94        cx: &mut Context<'_>,
95    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
96        match ready!(Pin::new(&mut self.as_mut().body).poll_frame(cx)) {
97            Some(Ok(frame)) => {
98                let Some(data) = frame.data_ref() else {
99                    return Poll::Ready(Some(Ok(frame)));
100                };
101                let Ok(sent) = data.len().try_into() else {
102                    return self.send_error(None);
103                };
104                let Some(sent) = self.sent.checked_add(sent) else {
105                    return self.send_error(None);
106                };
107                if sent > self.limit {
108                    return self.send_error(Some(sent));
109                }
110                self.sent = sent;
111                Poll::Ready(Some(Ok(frame)))
112            }
113            Some(Err(err)) => Poll::Ready(Some(Err(err))),
114            None if self.limit != self.sent => {
115                // short write
116                let sent = self.sent;
117                self.send_error(Some(sent))
118            }
119            None => Poll::Ready(None),
120        }
121    }
122
123    #[inline]
124    fn is_end_stream(&self) -> bool {
125        self.body.is_end_stream()
126    }
127
128    #[inline]
129    fn size_hint(&self) -> http_body::SizeHint {
130        let n = self.limit.saturating_sub(self.sent);
131        let mut hint = self.body.size_hint();
132        if hint.lower() >= n {
133            hint.set_exact(n)
134        } else if let Some(max) = hint.upper() {
135            hint.set_upper(n.min(max))
136        } else {
137            hint.set_upper(n)
138        }
139        hint
140    }
141}
142
143trait BodyExt {
144    fn with_state<T>(self, state: T) -> BodyWithState<Self, T>
145    where
146        Self: Sized,
147    {
148        BodyWithState {
149            body: self,
150            _state: state,
151        }
152    }
153
154    fn with_content_length<E>(
155        self,
156        limit: u64,
157        error_tx: oneshot::Sender<E>,
158        make_error: fn(Option<u64>) -> E,
159    ) -> BodyWithContentLength<Self, E>
160    where
161        Self: Sized,
162    {
163        BodyWithContentLength {
164            body: self,
165            error_tx: Some(error_tx),
166            make_error,
167            limit,
168            sent: 0,
169        }
170    }
171}
172
173impl<T> BodyExt for T {}
174
175async fn io_task_result(
176    rx: oneshot::Receiver<(
177        Arc<AbortOnDropJoinHandle>,
178        oneshot::Receiver<Result<(), ErrorCode>>,
179    )>,
180) -> Result<(), ErrorCode> {
181    let Ok((_io, io_result_rx)) = rx.await else {
182        return Ok(());
183    };
184    io_result_rx.await.unwrap_or(Ok(()))
185}
186
187impl HostWithStore for WasiHttp {
188    async fn handle<T>(
189        store: &Accessor<T, Self>,
190        req: Resource<Request>,
191    ) -> HttpResult<Resource<Response>> {
192        // A handle to the I/O task, if spawned, will be sent on this channel
193        // and kept as part of request body state
194        let (io_task_tx, io_task_rx) = oneshot::channel();
195
196        // A handle to the I/O task, if spawned, will be sent on this channel
197        // along with the result receiver
198        let (io_result_tx, io_result_rx) = oneshot::channel();
199
200        // Response processing result will be sent on this channel
201        let (res_result_tx, res_result_rx) = oneshot::channel();
202
203        let getter = store.getter();
204        let fut = store.with(|mut store| {
205            let WasiHttpCtxView { table, .. } = store.get();
206            let Request {
207                method,
208                scheme,
209                authority,
210                path_with_query,
211                headers,
212                options,
213                body,
214            } = table
215                .delete(req)
216                .context("failed to delete request from table")
217                .map_err(HttpError::trap)?;
218            // `Content-Length` header value is validated in `fields` implementation
219            let content_length = match get_content_length(&headers) {
220                Ok(content_length) => content_length,
221                Err(err) => {
222                    body.drop(&mut store);
223                    return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into());
224                }
225            };
226            let mut headers = Arc::unwrap_or_clone(headers);
227            let body = match body {
228                Body::Guest {
229                    contents_rx,
230                    trailers_rx,
231                    result_tx,
232                } => GuestBody::new(
233                    &mut store,
234                    contents_rx,
235                    trailers_rx,
236                    result_tx,
237                    io_task_result(io_result_rx),
238                    content_length,
239                    ErrorCode::HttpRequestBodySize,
240                    getter,
241                )
242                .with_state(io_task_rx)
243                .boxed(),
244                Body::Host { body, result_tx } => {
245                    if let Some(limit) = content_length {
246                        let (http_result_tx, http_result_rx) = oneshot::channel();
247                        _ = result_tx.send(Box::new(async move {
248                            if let Ok(err) = http_result_rx.await {
249                                return Err(err);
250                            };
251                            io_task_result(io_result_rx).await
252                        }));
253                        body.with_content_length(
254                            limit,
255                            http_result_tx,
256                            ErrorCode::HttpRequestBodySize,
257                        )
258                        .with_state(io_task_rx)
259                        .boxed()
260                    } else {
261                        _ = result_tx.send(Box::new(io_task_result(io_result_rx)));
262                        body.with_state(io_task_rx).boxed()
263                    }
264                }
265            };
266
267            let WasiHttpCtxView { ctx, .. } = store.get();
268            if ctx.set_host_header() {
269                let host = if let Some(authority) = authority.as_ref() {
270                    HeaderValue::try_from(authority.as_str())
271                        .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?
272                } else {
273                    HeaderValue::from_static("")
274                };
275                headers.insert(HOST, host);
276            }
277            let scheme = match scheme {
278                None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?,
279                Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme,
280                Some(..) => return Err(ErrorCode::HttpProtocolError.into()),
281            };
282            let mut uri = Uri::builder().scheme(scheme);
283            if let Some(authority) = authority {
284                uri = uri.authority(authority)
285            };
286            if let Some(path_with_query) = path_with_query {
287                uri = uri.path_and_query(path_with_query)
288            };
289            let uri = uri.build().map_err(|err| {
290                debug!(?err, "failed to build request URI");
291                ErrorCode::HttpRequestUriInvalid
292            })?;
293            let mut req = http::Request::builder();
294            *req.headers_mut().unwrap() = headers;
295            let req = req
296                .method(method)
297                .uri(uri)
298                .body(body)
299                .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
300            HttpResult::Ok(store.get().ctx.send_request(
301                req,
302                options.as_deref().copied(),
303                Box::new(async {
304                    // Forward the response processing result to `WasiHttpCtx` implementation
305                    let Ok(fut) = res_result_rx.await else {
306                        return Ok(());
307                    };
308                    Box::into_pin(fut).await
309                }),
310            ))
311        })?;
312        let (res, io) = Box::into_pin(fut).await?;
313        let (
314            http::response::Parts {
315                status, headers, ..
316            },
317            body,
318        ) = res.into_parts();
319
320        let mut io = Box::into_pin(io);
321        let body = match io.as_mut().poll(&mut Context::from_waker(Waker::noop()))? {
322            Poll::Ready(()) => body,
323            Poll::Pending => {
324                // I/O driver still needs to be polled, spawn a task and send handles to it
325                let (tx, rx) = oneshot::channel();
326                let io = task::spawn(async move {
327                    let res = io.await;
328                    debug!(?res, "`send_request` I/O future finished");
329                    _ = tx.send(res);
330                });
331                let io = Arc::new(AbortOnDropJoinHandle(io));
332                _ = io_result_tx.send((Arc::clone(&io), rx));
333                _ = io_task_tx.send(Arc::clone(&io));
334                body.with_state(io).boxed()
335            }
336        };
337        let res = Response {
338            status,
339            headers: Arc::new(headers),
340            body: Body::Host {
341                body,
342                result_tx: res_result_tx,
343            },
344        };
345        store.with(|mut store| {
346            store
347                .get()
348                .table
349                .push(res)
350                .context("failed to push response to table")
351                .map_err(HttpError::trap)
352        })
353    }
354}
355
/// Implements the `Host` trait of the handler bindings for the context view.
/// No methods are defined here; `handle` is provided by the [`HostWithStore`]
/// implementation for [`WasiHttp`].
impl Host for WasiHttpCtxView<'_> {}