1use crate::p3::bindings::http::types::ErrorCode;
2use crate::p3::body::Body;
3use bytes::Bytes;
4use core::time::Duration;
5use http::uri::{Authority, PathAndQuery, Scheme};
6use http::{HeaderMap, Method};
7use http_body_util::BodyExt as _;
8use http_body_util::combinators::BoxBody;
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
/// Optional timeout settings that can be attached to an outgoing request.
///
/// Every field is an `Option`; `RequestOptions::default()` leaves all of them
/// unset. Interpreting an unset value is up to the sender — the
/// `default_send_request` function in this module substitutes 600 seconds.
#[derive(Debug, Clone, Copy, Default)]
pub struct RequestOptions {
    /// Upper bound for establishing the connection.
    pub connect_timeout: Option<Duration>,
    /// Upper bound for waiting on the beginning of the response.
    pub first_byte_timeout: Option<Duration>,
    /// Upper bound for waiting between successive chunks of the response body.
    pub between_bytes_timeout: Option<Duration>,
}
22
/// An outgoing HTTP request as represented on the host side.
///
/// Construct one with [`Request::new`] or [`Request::from_http`].
pub struct Request {
    /// Request method.
    pub method: Method,
    /// URI scheme, if one was provided.
    pub scheme: Option<Scheme>,
    /// URI authority, if one was provided.
    pub authority: Option<Authority>,
    /// URI path and query, if provided.
    pub path_with_query: Option<PathAndQuery>,
    /// Request headers, held behind an `Arc` so they can be shared.
    pub headers: Arc<HeaderMap>,
    /// Optional per-request timeout settings.
    pub options: Option<Arc<RequestOptions>>,
    /// Request body; crate-internal representation.
    pub(crate) body: Body,
}
40
41impl Request {
42 pub fn new(
49 method: Method,
50 scheme: Option<Scheme>,
51 authority: Option<Authority>,
52 path_with_query: Option<PathAndQuery>,
53 headers: impl Into<Arc<HeaderMap>>,
54 options: Option<Arc<RequestOptions>>,
55 body: impl Into<BoxBody<Bytes, ErrorCode>>,
56 ) -> (
57 Self,
58 impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
59 ) {
60 let (tx, rx) = oneshot::channel();
61 (
62 Self {
63 method,
64 scheme,
65 authority,
66 path_with_query,
67 headers: headers.into(),
68 options,
69 body: Body::Host {
70 body: body.into(),
71 result_tx: tx,
72 },
73 },
74 async {
75 let Ok(fut) = rx.await else { return Ok(()) };
76 Box::into_pin(fut).await
77 },
78 )
79 }
80
81 pub fn from_http<T>(
88 req: http::Request<T>,
89 ) -> (
90 Self,
91 impl Future<Output = Result<(), ErrorCode>> + Send + 'static,
92 )
93 where
94 T: http_body::Body<Data = Bytes> + Send + Sync + 'static,
95 T::Error: Into<ErrorCode>,
96 {
97 let (
98 http::request::Parts {
99 method,
100 uri,
101 headers,
102 ..
103 },
104 body,
105 ) = req.into_parts();
106 let http::uri::Parts {
107 scheme,
108 authority,
109 path_and_query,
110 ..
111 } = uri.into_parts();
112 Self::new(
113 method,
114 scheme,
115 authority,
116 path_and_query,
117 headers,
118 None,
119 body.map_err(Into::into).boxed(),
120 )
121 }
122}
123
/// Sends `req` over a freshly established HTTP/1.1 connection.
///
/// The request URI must contain an authority. If the authority has no port,
/// 443 is used for the `https` scheme and 80 otherwise. For `https` the TCP
/// stream is wrapped in TLS (rustls, `webpki_roots` trust anchors, no client
/// authentication). Before sending, the request URI is rewritten to contain
/// only its path and query (`/` when absent).
///
/// Timeouts come from `options`; any unset value falls back to 600 seconds:
/// - `connect_timeout` separately bounds the TCP connect, the TLS handshake
///   and the HTTP/1.1 handshake;
/// - `first_byte_timeout` bounds the wait for the response head;
/// - `between_bytes_timeout` bounds each wait for the next response body frame.
///   A zero value is clamped to one nanosecond because
///   `tokio::time::interval` panics on a zero period.
///
/// On success, returns the response together with a future that finishes
/// driving the hyper connection (if it had not already completed while the
/// response head was awaited). The response body yields
/// `ErrorCode::ConnectionReadTimeout` when the between-bytes timeout elapses.
///
/// # Errors
///
/// - `HttpRequestUriInvalid` if the URI has no authority.
/// - `DnsError` if name resolution fails or the host is not a valid TLS
///   server name.
/// - `ConnectionRefused` for other connect failures.
/// - `ConnectionTimeout` if the TCP connect, TLS handshake or HTTP handshake
///   exceeds `connect_timeout`.
/// - `TlsProtocolError` if the TLS handshake fails.
/// - `ConnectionReadTimeout` if the response head does not arrive within
///   `first_byte_timeout`.
/// - Hyper errors mapped through `ErrorCode::from_hyper_request_error`.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request(
    mut req: http::Request<impl http_body::Body<Data = Bytes, Error = ErrorCode> + Send + 'static>,
    options: Option<RequestOptions>,
) -> Result<
    (
        http::Response<impl http_body::Body<Data = Bytes, Error = ErrorCode>>,
        impl Future<Output = Result<(), ErrorCode>> + Send,
    ),
    ErrorCode,
> {
    use core::future::poll_fn;
    use core::pin::{Pin, pin};
    use core::task::{Poll, ready};
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;

    /// Fallback used for every timeout left unset in `options`.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(600);

    /// Common object-safe stream type so the plain-TCP and TLS branches below
    /// produce the same type.
    trait TokioStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
        fn boxed(self) -> Box<dyn TokioStream>
        where
            Self: Sized,
        {
            Box::new(self)
        }
    }
    impl<T> TokioStream for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

    fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
        ErrorCode::DnsError(crate::p3::bindings::http::types::DnsErrorPayload {
            rcode: Some(rcode),
            info_code: Some(info_code),
        })
    }

    let uri = req.uri();
    let authority = uri.authority().ok_or(ErrorCode::HttpRequestUriInvalid)?;
    let use_tls = uri.scheme() == Some(&Scheme::HTTPS);
    // Name presented for TLS server verification. `Authority::host` excludes
    // any userinfo and port; IPv6 literals keep their surrounding brackets,
    // which `ServerName` does not accept, so strip them.
    let tls_host = {
        let host = authority.host();
        host.strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host)
    };
    let authority = if authority.port().is_some() {
        authority.to_string()
    } else {
        let port = if use_tls { 443 } else { 80 };
        format!("{authority}:{port}")
    };

    let options = options.unwrap_or_default();
    let connect_timeout = options.connect_timeout.unwrap_or(DEFAULT_TIMEOUT);
    let first_byte_timeout = options.first_byte_timeout.unwrap_or(DEFAULT_TIMEOUT);
    // `tokio::time::interval` panics on a zero period; clamp to the smallest
    // non-zero duration instead of crashing the host.
    let between_bytes_timeout = options
        .between_bytes_timeout
        .unwrap_or(DEFAULT_TIMEOUT)
        .max(Duration::from_nanos(1));

    let stream = match tokio::time::timeout(connect_timeout, TcpStream::connect(&authority)).await {
        Ok(Ok(stream)) => stream,
        Ok(Err(err))
            if err.kind() == std::io::ErrorKind::AddrNotAvailable
                || err
                    .to_string()
                    .starts_with("failed to lookup address information") =>
        {
            return Err(dns_error("address not available".to_string(), 0));
        }
        Ok(Err(err)) => {
            tracing::warn!(?err, "connection refused");
            return Err(ErrorCode::ConnectionRefused);
        }
        Err(..) => return Err(ErrorCode::ConnectionTimeout),
    };
    let stream = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
        let domain = ServerName::try_from(tls_host)
            .map_err(|e| {
                tracing::warn!("dns lookup error: {e:?}");
                dns_error("invalid dns name".to_string(), 0)
            })?
            .to_owned();
        // Bound the TLS handshake as well; otherwise a peer that accepts the
        // TCP connection but never completes the handshake stalls this call
        // indefinitely.
        let stream = tokio::time::timeout(connect_timeout, connector.connect(domain, stream))
            .await
            .map_err(|_| ErrorCode::ConnectionTimeout)?
            .map_err(|e| {
                tracing::warn!("tls protocol error: {e:?}");
                ErrorCode::TlsProtocolError
            })?;
        stream.boxed()
    } else {
        stream.boxed()
    };
    let (mut sender, conn) = tokio::time::timeout(
        connect_timeout,
        hyper::client::conn::http1::Builder::new().handshake(crate::io::TokioIo::new(stream)),
    )
    .await
    .map_err(|_| ErrorCode::ConnectionTimeout)?
    .map_err(ErrorCode::from_hyper_request_error)?;

    // Send only the path and query on the wire.
    *req.uri_mut() = http::Uri::builder()
        .path_and_query(
            req.uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let send = async move {
        use core::task::Context;

        /// Response body that turns hyper errors into `ErrorCode` and reports
        /// `ConnectionReadTimeout` when no frame arrives within one
        /// `between_bytes_timeout` period.
        struct IncomingResponseBody {
            incoming: hyper::body::Incoming,
            timeout: tokio::time::Interval,
        }
        impl http_body::Body for IncomingResponseBody {
            type Data = <hyper::body::Incoming as http_body::Body>::Data;
            type Error = ErrorCode;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                match Pin::new(&mut self.as_mut().incoming).poll_frame(cx) {
                    Poll::Ready(None) => Poll::Ready(None),
                    Poll::Ready(Some(Err(err))) => {
                        Poll::Ready(Some(Err(ErrorCode::from_hyper_response_error(err))))
                    }
                    Poll::Ready(Some(Ok(frame))) => {
                        // A frame arrived: restart the between-bytes window.
                        self.timeout.reset();
                        Poll::Ready(Some(Ok(frame)))
                    }
                    Poll::Pending => {
                        ready!(self.timeout.poll_tick(cx));
                        Poll::Ready(Some(Err(ErrorCode::ConnectionReadTimeout)))
                    }
                }
            }
            fn is_end_stream(&self) -> bool {
                self.incoming.is_end_stream()
            }
            fn size_hint(&self) -> http_body::SizeHint {
                self.incoming.size_hint()
            }
        }

        let res = tokio::time::timeout(first_byte_timeout, sender.send_request(req))
            .await
            .map_err(|_| ErrorCode::ConnectionReadTimeout)?
            .map_err(ErrorCode::from_hyper_request_error)?;
        let mut timeout = tokio::time::interval(between_bytes_timeout);
        // `interval` ticks immediately; push the first tick one full period out.
        timeout.reset();
        Ok(res.map(|incoming| IncomingResponseBody { incoming, timeout }))
    };
    let mut send = pin!(send);
    let mut conn = Some(conn);
    // Poll the request future and the connection future together until the
    // response head is available; the connection future performs the I/O the
    // request future is waiting on.
    let res = poll_fn(|cx| match send.as_mut().poll(cx) {
        Poll::Ready(res) => Poll::Ready(res),
        Poll::Pending => {
            // Connection already finished; only `send` is left to wait on.
            let Some(fut) = conn.as_mut() else {
                return Poll::Pending;
            };
            let res = ready!(Pin::new(fut).poll(cx));
            conn = None;
            match res {
                // Connection closed cleanly: give `send` another chance to
                // observe its outcome.
                Ok(()) => send.as_mut().poll(cx),
                Err(err) => Poll::Ready(Err(ErrorCode::from_hyper_request_error(err))),
            }
        }
    })
    .await?;
    Ok((res, async move {
        let Some(conn) = conn.take() else {
            return Ok(());
        };
        conn.await.map_err(ErrorCode::from_hyper_response_error)
    }))
}