1#![deny(missing_docs)]
71#![doc(test(attr(deny(warnings))))]
72#![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))]
73
74use anyhow::{Context, Result};
75use bytes::Bytes;
76use rustls::pki_types::ServerName;
77use std::io;
78use std::sync::Arc;
79use std::task::{ready, Poll};
80use std::{future::Future, mem, pin::Pin, sync::LazyLock};
81use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
82use tokio::sync::Mutex;
83use tokio_rustls::client::TlsStream;
84use wasmtime::component::{Resource, ResourceTable};
85use wasmtime_wasi::pipe::AsyncReadStream;
86use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
87use wasmtime_wasi::OutputStream;
88use wasmtime_wasi::{
89 async_trait,
90 bindings::io::{
91 poll::Pollable as HostPollable,
92 streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream},
93 },
94 Pollable, StreamError,
95};
96
// Host bindings generated from the WIT package in `wit/` for the `imports` world.
//
// `with` reuses wasmtime-wasi's existing `wasi:io` bindings and maps the three
// `wasi:tls/types` resources onto the host types defined in this file.
// `trappable_imports: true` gives every generated host method a `wasmtime::Result` return
// type, and `only_imports: []` leaves all imports synchronous (the host impls below are
// plain, non-async functions).
mod gen_ {
    wasmtime::component::bindgen!({
        path: "wit/",
        world: "imports",
        with: {
            "wasi:io": wasmtime_wasi::bindings::io,
            "wasi:tls/types/client-connection": super::ClientConnection,
            "wasi:tls/types/client-handshake": super::ClientHandShake,
            "wasi:tls/types/future-client-streams": super::FutureClientStreams,
        },
        trappable_imports: true,
        async: {
            only_imports: [],
        }
    });
}
113pub use gen_::wasi::tls::types::LinkOptions;
114use gen_::wasi::tls::{self as generated};
115
116fn default_client_config() -> Arc<rustls::ClientConfig> {
117 static CONFIG: LazyLock<Arc<rustls::ClientConfig>> = LazyLock::new(|| {
118 let roots = rustls::RootCertStore {
119 roots: webpki_roots::TLS_SERVER_ROOTS.into(),
120 };
121 let config = rustls::ClientConfig::builder()
122 .with_root_certificates(roots)
123 .with_no_client_auth();
124 Arc::new(config)
125 });
126 Arc::clone(&CONFIG)
127}
128
/// Host-side context for the `wasi:tls` bindings.
///
/// Borrows the embedder's [`ResourceTable`], which holds the handshake, future, and
/// connection resources created by this interface as well as the `wasi:io` streams it
/// consumes and hands out.
pub struct WasiTlsCtx<'a> {
    // Guest resource handles passed to the host methods are looked up in this table.
    table: &'a mut ResourceTable,
}
133
134impl<'a> WasiTlsCtx<'a> {
135 pub fn new(table: &'a mut ResourceTable) -> Self {
137 Self { table }
138 }
139}
140
// The generated `types::Host` trait needs no items here (the impl is empty); the
// interface's behavior is provided by the per-resource `Host*` impls in this file.
impl<'a> generated::types::Host for WasiTlsCtx<'a> {}
142
143pub fn add_to_linker<T: Send>(
145 l: &mut wasmtime::component::Linker<T>,
146 opts: &mut LinkOptions,
147 f: impl Fn(&mut T) -> WasiTlsCtx + Send + Sync + Copy + 'static,
148) -> Result<()> {
149 generated::types::add_to_linker_get_host(l, &opts, f)?;
150 Ok(())
151}
/// A TLS client handshake that has been configured but not yet started.
///
/// Created by `client-handshake.new`, which takes ownership of the guest's transport
/// streams; `finish` consumes it and starts the handshake.
pub struct ClientHandShake {
    // Name handed to the TLS connector (as a `ServerName`) when the handshake runs.
    server_name: String,
    // Guest-provided transport streams the TLS session runs over.
    streams: WasiStreams,
}
157
158impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
159 fn new(
160 &mut self,
161 server_name: String,
162 input: Resource<BoxInputStream>,
163 output: Resource<BoxOutputStream>,
164 ) -> wasmtime::Result<Resource<ClientHandShake>> {
165 let input = self.table.delete(input)?;
166 let output = self.table.delete(output)?;
167 Ok(self.table.push(ClientHandShake {
168 server_name,
169 streams: WasiStreams {
170 input: StreamState::Ready(input),
171 output: StreamState::Ready(output),
172 },
173 })?)
174 }
175
176 fn finish(
177 &mut self,
178 this: wasmtime::component::Resource<ClientHandShake>,
179 ) -> wasmtime::Result<Resource<FutureClientStreams>> {
180 let handshake = self.table.delete(this)?;
181 let server_name = handshake.server_name;
182 let streams = handshake.streams;
183 let domain = ServerName::try_from(server_name)?;
184
185 Ok(self
186 .table
187 .push(FutureStreams(StreamState::Pending(Box::pin(async move {
188 let connector = tokio_rustls::TlsConnector::from(default_client_config());
189 connector
190 .connect(domain, streams)
191 .await
192 .with_context(|| "connection failed")
193 }))))?)
194 }
195
196 fn drop(
197 &mut self,
198 this: wasmtime::component::Resource<ClientHandShake>,
199 ) -> wasmtime::Result<()> {
200 self.table.delete(this)?;
201 Ok(())
202 }
203}
204
/// A guest-visible future holding the result of a background task.
///
/// Starts as `Pending`, becomes `Ready` with the task's result once polled to completion,
/// and becomes `Closed` after `get` hands out a successful result.
pub struct FutureStreams<T>(StreamState<Result<T>>);
207
/// The future returned by `client-handshake.finish`; it resolves to the TLS stream running
/// over the guest's transport streams.
pub type FutureClientStreams = FutureStreams<TlsStream<WasiStreams>>;
211
212#[async_trait]
213impl<T: Send + 'static> Pollable for FutureStreams<T> {
214 async fn ready(&mut self) {
215 match &mut self.0 {
216 StreamState::Ready(_) | StreamState::Closed => return,
217 StreamState::Pending(task) => self.0 = StreamState::Ready(task.as_mut().await),
218 }
219 }
220}
221
222impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> {
223 fn subscribe(
224 &mut self,
225 this: wasmtime::component::Resource<FutureClientStreams>,
226 ) -> wasmtime::Result<Resource<HostPollable>> {
227 wasmtime_wasi::subscribe(self.table, this)
228 }
229
230 fn get(
231 &mut self,
232 this: wasmtime::component::Resource<FutureClientStreams>,
233 ) -> wasmtime::Result<
234 Option<
235 Result<
236 Result<
237 (
238 Resource<ClientConnection>,
239 Resource<BoxInputStream>,
240 Resource<BoxOutputStream>,
241 ),
242 (),
243 >,
244 (),
245 >,
246 >,
247 > {
248 {
249 let this = self.table.get(&this)?;
250 match &this.0 {
251 StreamState::Pending(_) => return Ok(None),
252 StreamState::Ready(Ok(_)) => (),
253 StreamState::Ready(Err(_)) => {
254 return Ok(Some(Ok(Err(()))));
255 }
256 StreamState::Closed => return Ok(Some(Err(()))),
257 }
258 }
259
260 let StreamState::Ready(Ok(tls_stream)) =
261 mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed)
262 else {
263 unreachable!()
264 };
265
266 let (rx, tx) = tokio::io::split(tls_stream);
267 let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx));
268 let client = ClientConnection {
269 writer: write_stream.clone(),
270 };
271
272 let input = Box::new(AsyncReadStream::new(rx)) as BoxInputStream;
273 let output = Box::new(write_stream) as BoxOutputStream;
274
275 let client = self.table.push(client)?;
276 let input = self.table.push_child(input, &client)?;
277 let output = self.table.push_child(output, &client)?;
278
279 Ok(Some(Ok(Ok((client, input, output)))))
280 }
281
282 fn drop(
283 &mut self,
284 this: wasmtime::component::Resource<FutureClientStreams>,
285 ) -> wasmtime::Result<()> {
286 self.table.delete(this)?;
287 Ok(())
288 }
289}
290
/// An established TLS client connection.
///
/// The guest's input and output streams for the connection are stored as children of this
/// resource in the resource table.
pub struct ClientConnection {
    // Shares the writer behind the guest's output stream, so `close-output` can shut it down.
    writer: AsyncTlsWriteStream,
}
295
296impl<'a> generated::types::HostClientConnection for WasiTlsCtx<'a> {
297 fn close_output(&mut self, this: Resource<ClientConnection>) -> wasmtime::Result<()> {
298 self.table.get_mut(&this)?.writer.close()
299 }
300
301 fn drop(&mut self, this: Resource<ClientConnection>) -> wasmtime::Result<()> {
302 self.table.delete(this)?;
303 Ok(())
304 }
305}
306
/// A value that is either available now or still being produced by a boxed future.
enum StreamState<T> {
    /// The value is available.
    Ready(T),
    /// The value will be produced by this future.
    Pending(Pin<Box<dyn Future<Output = T> + Send>>),
    /// No value is held. This covers a result already handed out, an input stream that
    /// reached EOF or failed, and a transient placeholder used during `mem::replace`.
    Closed,
}
312
/// The guest's `wasi:io` input/output stream pair, adapted to tokio's `AsyncRead`/`AsyncWrite`
/// so it can carry the TLS session.
pub struct WasiStreams {
    // `Pending` while waiting on the stream's pollable; `Closed` after EOF or a read error.
    input: StreamState<BoxInputStream>,
    // `Pending` while waiting for the stream to accept more data.
    output: StreamState<BoxOutputStream>,
}
318
// Lets the TLS layer write to the guest-provided `wasi:io` output stream.
//
// Writes go through `check_write`/`write`. When the stream reports zero capacity, it is moved
// into a boxed future that awaits its `ready()` pollable; polling that future later resumes
// the write.
impl AsyncWrite for WasiStreams {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        loop {
            match &mut self.as_mut().output {
                // `Closed` only exists transiently while the stream is moved out below.
                StreamState::Closed => unreachable!(),
                StreamState::Pending(future) => {
                    // Wait until the stream is writable again, then retry.
                    let value = ready!(future.as_mut().poll(cx));
                    self.as_mut().output = StreamState::Ready(value);
                }
                StreamState::Ready(output) => {
                    match output.check_write() {
                        Ok(0) => {
                            // No capacity: park the stream in a future awaiting its pollable.
                            let StreamState::Ready(mut output) =
                                mem::replace(&mut self.as_mut().output, StreamState::Closed)
                            else {
                                unreachable!()
                            };
                            self.as_mut().output = StreamState::Pending(Box::pin(async move {
                                output.ready().await;
                                output
                            }));
                        }
                        Ok(count) => {
                            // Write no more than the stream permits; report the partial count.
                            let count = count.min(buf.len());
                            return match output.write(Bytes::copy_from_slice(&buf[..count])) {
                                Ok(()) => Poll::Ready(Ok(count)),
                                // A closed stream is reported as a zero-length write.
                                Err(StreamError::Closed) => Poll::Ready(Ok(0)),
                                Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
                                    Poll::Ready(Err(std::io::Error::other(e)))
                                }
                            };
                        }
                        Err(StreamError::Closed) => return Poll::Ready(Ok(0)),
                        Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
                            return Poll::Ready(Err(std::io::Error::other(e)))
                        }
                    };
                }
            }
        }
    }

    // Implemented as an empty write: waits until `check_write` reports non-zero capacity,
    // then writes zero bytes. The wasi output stream's own `flush` method is not called here.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.poll_write(cx, &[]).map(|v| v.map(drop))
    }

    // Only flushes; the underlying wasi output stream is not closed here.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.poll_flush(cx)
    }
}
379
380impl AsyncRead for WasiStreams {
381 fn poll_read(
382 mut self: Pin<&mut Self>,
383 cx: &mut std::task::Context<'_>,
384 buf: &mut tokio::io::ReadBuf<'_>,
385 ) -> std::task::Poll<std::io::Result<()>> {
386 loop {
387 let stream = match &mut self.input {
388 StreamState::Ready(stream) => stream,
389 StreamState::Pending(fut) => {
390 let stream = ready!(fut.as_mut().poll(cx));
391 self.input = StreamState::Ready(stream);
392 if let StreamState::Ready(stream) = &mut self.input {
393 stream
394 } else {
395 unreachable!()
396 }
397 }
398 StreamState::Closed => {
399 return Poll::Ready(Ok(()));
400 }
401 };
402 match stream.read(buf.remaining()) {
403 Ok(bytes) if bytes.is_empty() => {
404 let StreamState::Ready(mut stream) =
405 std::mem::replace(&mut self.input, StreamState::Closed)
406 else {
407 unreachable!()
408 };
409
410 self.input = StreamState::Pending(Box::pin(async move {
411 stream.ready().await;
412 stream
413 }));
414 }
415 Ok(bytes) => {
416 buf.put_slice(&bytes);
417
418 return Poll::Ready(Ok(()));
419 }
420 Err(StreamError::Closed) => {
421 self.input = StreamState::Closed;
422 return Poll::Ready(Ok(()));
423 }
424 Err(e) => {
425 self.input = StreamState::Closed;
426 return Poll::Ready(Err(std::io::Error::other(e)));
427 }
428 }
429 }
430 }
431}
432
/// Write half of the client TLS stream, as produced by `tokio::io::split`.
type TlsWriteHalf = tokio::io::WriteHalf<tokio_rustls::client::TlsStream<WasiStreams>>;
434
/// State machine behind the guest-visible TLS output stream.
///
/// Writes and shutdown of the TLS write half run on spawned tasks.
struct TlsWriter {
    state: WriteState,
}
438
/// Lifecycle states of a `TlsWriter`.
enum WriteState {
    /// Idle; a new write may start.
    Ready(TlsWriteHalf),
    /// A spawned task is writing; it returns the write half when done.
    Writing(AbortOnDropJoinHandle<io::Result<TlsWriteHalf>>),
    /// A spawned task is shutting down the write half.
    Closing(AbortOnDropJoinHandle<io::Result<()>>),
    /// Shut down, cancelled, or its last error has already been reported. This is also a
    /// transient placeholder while the write half is moved out.
    Closed,
    /// A background task failed; the error is reported by the next `check_write`.
    Error(io::Error),
}
/// Write budget `check_write` reports whenever the writer is idle (1 GiB).
const READY_SIZE: usize = 1024 * 1024 * 1024;
447
448impl TlsWriter {
449 fn new(stream: TlsWriteHalf) -> Self {
450 Self {
451 state: WriteState::Ready(stream),
452 }
453 }
454
455 fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
456 let WriteState::Ready(_) = self.state else {
457 return Err(StreamError::Trap(anyhow::anyhow!(
458 "unpermitted: must call check_write first"
459 )));
460 };
461
462 if bytes.is_empty() {
463 return Ok(());
464 }
465
466 let WriteState::Ready(mut stream) = std::mem::replace(&mut self.state, WriteState::Closed)
467 else {
468 unreachable!()
469 };
470
471 self.state = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move {
472 while !bytes.is_empty() {
473 match stream.write(&bytes).await {
474 Ok(n) => {
475 let _ = bytes.split_to(n);
476 }
477 Err(e) => return Err(e.into()),
478 }
479 }
480
481 Ok(stream)
482 }));
483
484 Ok(())
485 }
486
487 fn flush(&mut self) -> Result<(), StreamError> {
488 match self.state {
490 WriteState::Ready(_)
491 | WriteState::Writing(_)
492 | WriteState::Closing(_)
493 | WriteState::Error(_) => Ok(()),
494 WriteState::Closed => Err(StreamError::Closed),
495 }
496 }
497
498 fn check_write(&mut self) -> Result<usize, StreamError> {
499 match &mut self.state {
500 WriteState::Ready(_) => Ok(READY_SIZE),
501 WriteState::Writing(_) => Ok(0),
502 WriteState::Closing(_) => Ok(0),
503 WriteState::Closed => Err(StreamError::Closed),
504 WriteState::Error(_) => {
505 let WriteState::Error(e) = std::mem::replace(&mut self.state, WriteState::Closed)
506 else {
507 unreachable!()
508 };
509
510 Err(StreamError::LastOperationFailed(e.into()))
511 }
512 }
513 }
514
515 fn close(&mut self) {
516 match std::mem::replace(&mut self.state, WriteState::Closed) {
517 WriteState::Ready(mut stream) => {
519 self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move {
520 stream.shutdown().await
521 }));
522 }
523
524 WriteState::Writing(write) => {
526 self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move {
527 let mut stream = write.await?;
528 stream.shutdown().await
529 }));
530 }
531
532 WriteState::Closing(t) => {
533 self.state = WriteState::Closing(t);
534 }
535 WriteState::Closed | WriteState::Error(_) => {}
536 }
537 }
538
539 async fn cancel(&mut self) {
540 match std::mem::replace(&mut self.state, WriteState::Closed) {
541 WriteState::Writing(task) => _ = task.cancel().await,
542 WriteState::Closing(task) => _ = task.cancel().await,
543 _ => {}
544 }
545 }
546
547 async fn ready(&mut self) {
548 match &mut self.state {
549 WriteState::Writing(task) => {
550 self.state = match task.await {
551 Ok(s) => WriteState::Ready(s),
552 Err(e) => WriteState::Error(e),
553 }
554 }
555 WriteState::Closing(task) => {
556 self.state = match task.await {
557 Ok(()) => WriteState::Closed,
558 Err(e) => WriteState::Error(e),
559 }
560 }
561 _ => {}
562 }
563 }
564}
565
/// Cloneable handle to a shared `TlsWriter`.
///
/// One clone backs the guest's output stream; another is kept by `ClientConnection` for
/// `close-output`.
#[derive(Clone)]
struct AsyncTlsWriteStream(Arc<Mutex<TlsWriter>>);
568
569impl AsyncTlsWriteStream {
570 fn new(writer: TlsWriter) -> Self {
571 AsyncTlsWriteStream(Arc::new(Mutex::new(writer)))
572 }
573
574 fn close(&mut self) -> wasmtime::Result<()> {
575 try_lock_for_stream(&self.0)?.close();
576 Ok(())
577 }
578}
579
580#[async_trait]
581impl OutputStream for AsyncTlsWriteStream {
582 fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
583 try_lock_for_stream(&self.0)?.write(bytes)
584 }
585
586 fn flush(&mut self) -> Result<(), StreamError> {
587 try_lock_for_stream(&self.0)?.flush()
588 }
589
590 fn check_write(&mut self) -> Result<usize, StreamError> {
591 try_lock_for_stream(&self.0)?.check_write()
592 }
593
594 async fn cancel(&mut self) {
595 self.0.lock().await.cancel().await
596 }
597}
598
599#[async_trait]
600impl Pollable for AsyncTlsWriteStream {
601 async fn ready(&mut self) {
602 self.0.lock().await.ready().await
603 }
604}
605
606fn try_lock_for_stream<TlsWriter>(
607 mutex: &Mutex<TlsWriter>,
608) -> Result<tokio::sync::MutexGuard<'_, TlsWriter>, StreamError> {
609 mutex
610 .try_lock()
611 .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Dropping an in-flight `ready()` future must not lose the pending task: a later
    /// `ready()` call must still observe the task's result.
    #[tokio::test]
    async fn test_future_client_streams_ready_can_be_canceled() {
        let (tx1, rx1) = oneshot::channel::<()>();

        let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move {
            rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled"))
        })));

        let mut fut = future_streams.ready();

        let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
        assert!(fut.as_mut().poll(&mut cx).is_pending());

        // Cancel the in-flight `ready()` call before the task can complete.
        drop(fut);

        assert!(
            matches!(future_streams.0, StreamState::Pending(_)),
            "cancelling ready() must leave the task pending"
        );

        tx1.send(()).unwrap();
        future_streams.ready().await;

        assert!(
            matches!(future_streams.0, StreamState::Ready(Ok(()))),
            "future should be in Ready(Ok) state once the task completes"
        );
    }
}
649}