wasmtime_wasi_tls/
lib.rs

1//! # Wasmtime's [wasi-tls] (Transport Layer Security) Implementation
2//!
3//! This crate provides the Wasmtime host implementation for the [wasi-tls] API.
4//! The [wasi-tls] world allows WebAssembly modules to perform SSL/TLS operations,
//! such as establishing secure connections to servers. TLS usually relies on other WASI
//! networking interfaces to provide the underlying stream, so it is common to also enable the
//! [wasi:cli] world with its networking features.
7//!
//! # Example
//!
//! The following shows how to configure [wasi-tls]:
9//!
10//! ```rust
11//! use wasmtime_wasi::{IoView, WasiCtx, WasiCtxBuilder, WasiView};
12//! use wasmtime::{
13//!     component::{Linker, ResourceTable},
14//!     Store, Engine, Result, Config
15//! };
16//! use wasmtime_wasi_tls::{LinkOptions, WasiTlsCtx};
17//!
18//! struct Ctx {
19//!     table: ResourceTable,
20//!     wasi_ctx: WasiCtx,
21//! }
22//!
23//! impl IoView for Ctx {
24//!     fn table(&mut self) -> &mut ResourceTable {
25//!         &mut self.table
26//!     }
27//! }
28//!
29//! impl WasiView for Ctx {
30//!     fn ctx(&mut self) -> &mut WasiCtx {
31//!         &mut self.wasi_ctx
32//!     }
33//! }
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<()> {
37//!     let ctx = Ctx {
38//!         table: ResourceTable::new(),
39//!         wasi_ctx: WasiCtxBuilder::new()
40//!             .inherit_stderr()
41//!             .inherit_network()
42//!             .allow_ip_name_lookup(true)
43//!             .build(),
44//!     };
45//!
46//!     let mut config = Config::new();
47//!     config.async_support(true);
48//!     let engine = Engine::new(&config)?;
49//!
50//!     // Set up wasi-cli
51//!     let mut store = Store::new(&engine, ctx);
52//!     let mut linker = Linker::new(&engine);
53//!     wasmtime_wasi::add_to_linker_async(&mut linker)?;
54//!
55//!     // Add wasi-tls types and turn on the feature in linker
56//!     let mut opts = LinkOptions::default();
57//!     opts.tls(true);
58//!     wasmtime_wasi_tls::add_to_linker(&mut linker, &mut opts, |h: &mut Ctx| {
59//!         WasiTlsCtx::new(&mut h.table)
60//!     })?;
61//!
62//!     // ... use `linker` to instantiate within `store` ...
63//!     Ok(())
64//! }
65//!
66//! ```
67//! [wasi-tls]: https://github.com/WebAssembly/wasi-tls
68//! [wasi:cli]: https://docs.rs/wasmtime-wasi/latest
69
70#![deny(missing_docs)]
71#![doc(test(attr(deny(warnings))))]
72#![doc(test(attr(allow(dead_code, unused_variables, unused_mut))))]
73
74use anyhow::{Context, Result};
75use bytes::Bytes;
76use rustls::pki_types::ServerName;
77use std::io;
78use std::sync::Arc;
79use std::task::{ready, Poll};
80use std::{future::Future, mem, pin::Pin, sync::LazyLock};
81use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
82use tokio::sync::Mutex;
83use tokio_rustls::client::TlsStream;
84use wasmtime::component::{Resource, ResourceTable};
85use wasmtime_wasi::pipe::AsyncReadStream;
86use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
87use wasmtime_wasi::OutputStream;
88use wasmtime_wasi::{
89    async_trait,
90    bindings::io::{
91        poll::Pollable as HostPollable,
92        streams::{InputStream as BoxInputStream, OutputStream as BoxOutputStream},
93    },
94    Pollable, StreamError,
95};
96
// Component bindings generated from the WIT files in `wit/` for the `imports` world.
mod gen_ {
    wasmtime::component::bindgen!({
        path: "wit/",
        world: "imports",
        // Reuse wasmtime-wasi's `wasi:io` bindings and map the `wasi:tls` resources onto the
        // host types defined in this crate.
        with: {
            "wasi:io": wasmtime_wasi::bindings::io,
            "wasi:tls/types/client-connection": super::ClientConnection,
            "wasi:tls/types/client-handshake": super::ClientHandShake,
            "wasi:tls/types/future-client-streams": super::FutureClientStreams,
        },
        // Host functions return `wasmtime::Result`, so they may trap.
        trappable_imports: true,
        // An empty `only_imports` list: the host trait impls in this file are all synchronous.
        async: {
            only_imports: [],
        }
    });
}
113pub use gen_::wasi::tls::types::LinkOptions;
114use gen_::wasi::tls::{self as generated};
115
116fn default_client_config() -> Arc<rustls::ClientConfig> {
117    static CONFIG: LazyLock<Arc<rustls::ClientConfig>> = LazyLock::new(|| {
118        let roots = rustls::RootCertStore {
119            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
120        };
121        let config = rustls::ClientConfig::builder()
122            .with_root_certificates(roots)
123            .with_no_client_auth();
124        Arc::new(config)
125    });
126    Arc::clone(&CONFIG)
127}
128
129/// Wasi TLS context needed fro internal `wasi-tls`` state
130pub struct WasiTlsCtx<'a> {
131    table: &'a mut ResourceTable,
132}
133
134impl<'a> WasiTlsCtx<'a> {
135    /// Create a new Wasi TLS context
136    pub fn new(table: &'a mut ResourceTable) -> Self {
137        Self { table }
138    }
139}
140
141impl<'a> generated::types::Host for WasiTlsCtx<'a> {}
142
143/// Add the `wasi-tls` world's types to a [`wasmtime::component::Linker`].
144pub fn add_to_linker<T: Send>(
145    l: &mut wasmtime::component::Linker<T>,
146    opts: &mut LinkOptions,
147    f: impl Fn(&mut T) -> WasiTlsCtx + Send + Sync + Copy + 'static,
148) -> Result<()> {
149    generated::types::add_to_linker_get_host(l, &opts, f)?;
150    Ok(())
151}
/// A `client-handshake` resource: the state collected before a TLS handshake is started.
///
/// It records the server name to connect to and owns the guest's transport streams. The
/// handshake started by `finish` reads from and writes to those streams.
pub struct ClientHandShake {
    // Name of the server; used as the TLS server name when the handshake starts.
    server_name: String,
    // Transport streams taken out of the resource table by `new`.
    streams: WasiStreams,
}
157
158impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
159    fn new(
160        &mut self,
161        server_name: String,
162        input: Resource<BoxInputStream>,
163        output: Resource<BoxOutputStream>,
164    ) -> wasmtime::Result<Resource<ClientHandShake>> {
165        let input = self.table.delete(input)?;
166        let output = self.table.delete(output)?;
167        Ok(self.table.push(ClientHandShake {
168            server_name,
169            streams: WasiStreams {
170                input: StreamState::Ready(input),
171                output: StreamState::Ready(output),
172            },
173        })?)
174    }
175
176    fn finish(
177        &mut self,
178        this: wasmtime::component::Resource<ClientHandShake>,
179    ) -> wasmtime::Result<Resource<FutureClientStreams>> {
180        let handshake = self.table.delete(this)?;
181        let server_name = handshake.server_name;
182        let streams = handshake.streams;
183        let domain = ServerName::try_from(server_name)?;
184
185        Ok(self
186            .table
187            .push(FutureStreams(StreamState::Pending(Box::pin(async move {
188                let connector = tokio_rustls::TlsConnector::from(default_client_config());
189                connector
190                    .connect(domain, streams)
191                    .await
192                    .with_context(|| "connection failed")
193            }))))?)
194    }
195
196    fn drop(
197        &mut self,
198        this: wasmtime::component::Resource<ClientHandShake>,
199    ) -> wasmtime::Result<()> {
200        self.table.delete(this)?;
201        Ok(())
202    }
203}
204
/// A future that provides the TLS stream once the handshake has completed.
///
/// The wrapped state is in one of three variants:
/// - `Pending` while the handshake runs;
/// - `Ready` with the outcome once it has resolved;
/// - `Closed` after `get` has moved the result out.
pub struct FutureStreams<T>(StreamState<Result<T>>);

/// The concrete [`FutureStreams`] backing the `future-client-streams` resource.
///
/// This alias exists because the component bindings generator cannot be given a generic type.
pub type FutureClientStreams = FutureStreams<TlsStream<WasiStreams>>;
211
212#[async_trait]
213impl<T: Send + 'static> Pollable for FutureStreams<T> {
214    async fn ready(&mut self) {
215        match &mut self.0 {
216            StreamState::Ready(_) | StreamState::Closed => return,
217            StreamState::Pending(task) => self.0 = StreamState::Ready(task.as_mut().await),
218        }
219    }
220}
221
222impl<'a> generated::types::HostFutureClientStreams for WasiTlsCtx<'a> {
223    fn subscribe(
224        &mut self,
225        this: wasmtime::component::Resource<FutureClientStreams>,
226    ) -> wasmtime::Result<Resource<HostPollable>> {
227        wasmtime_wasi::subscribe(self.table, this)
228    }
229
230    fn get(
231        &mut self,
232        this: wasmtime::component::Resource<FutureClientStreams>,
233    ) -> wasmtime::Result<
234        Option<
235            Result<
236                Result<
237                    (
238                        Resource<ClientConnection>,
239                        Resource<BoxInputStream>,
240                        Resource<BoxOutputStream>,
241                    ),
242                    (),
243                >,
244                (),
245            >,
246        >,
247    > {
248        {
249            let this = self.table.get(&this)?;
250            match &this.0 {
251                StreamState::Pending(_) => return Ok(None),
252                StreamState::Ready(Ok(_)) => (),
253                StreamState::Ready(Err(_)) => {
254                    return Ok(Some(Ok(Err(()))));
255                }
256                StreamState::Closed => return Ok(Some(Err(()))),
257            }
258        }
259
260        let StreamState::Ready(Ok(tls_stream)) =
261            mem::replace(&mut self.table.get_mut(&this)?.0, StreamState::Closed)
262        else {
263            unreachable!()
264        };
265
266        let (rx, tx) = tokio::io::split(tls_stream);
267        let write_stream = AsyncTlsWriteStream::new(TlsWriter::new(tx));
268        let client = ClientConnection {
269            writer: write_stream.clone(),
270        };
271
272        let input = Box::new(AsyncReadStream::new(rx)) as BoxInputStream;
273        let output = Box::new(write_stream) as BoxOutputStream;
274
275        let client = self.table.push(client)?;
276        let input = self.table.push_child(input, &client)?;
277        let output = self.table.push_child(output, &client)?;
278
279        Ok(Some(Ok(Ok((client, input, output)))))
280    }
281
282    fn drop(
283        &mut self,
284        this: wasmtime::component::Resource<FutureClientStreams>,
285    ) -> wasmtime::Result<()> {
286        self.table.delete(this)?;
287        Ok(())
288    }
289}
290
/// A `client-connection` resource representing an established TLS connection.
///
/// It holds a handle to the connection's TLS writer, which is shared with the connection's
/// output stream, so that `close-output` can shut down the write side.
pub struct ClientConnection {
    writer: AsyncTlsWriteStream,
}
295
296impl<'a> generated::types::HostClientConnection for WasiTlsCtx<'a> {
297    fn close_output(&mut self, this: Resource<ClientConnection>) -> wasmtime::Result<()> {
298        self.table.get_mut(&this)?.writer.close()
299    }
300
301    fn drop(&mut self, this: Resource<ClientConnection>) -> wasmtime::Result<()> {
302        self.table.delete(this)?;
303        Ok(())
304    }
305}
306
/// A value that is available now, still being produced by a boxed future, or absent.
enum StreamState<T> {
    /// The value is available.
    Ready(T),
    /// The value will be produced when this future resolves.
    Pending(Pin<Box<dyn Future<Output = T> + Send>>),
    /// No value is held: it was consumed or closed, or it is briefly moved out during a
    /// state transition.
    Closed,
}
312
/// Adapts a wasi input/output stream pair to tokio's `AsyncRead`/`AsyncWrite`.
///
/// This lets the pair serve as the transport underneath the TLS session.
pub struct WasiStreams {
    // Transport input states:
    // - `Ready` when it can be read;
    // - `Pending` while waiting for readiness;
    // - `Closed` after end-of-stream or an error.
    input: StreamState<BoxInputStream>,
    // Transport output: `Ready` when it can be written, `Pending` while waiting for capacity.
    output: StreamState<BoxOutputStream>,
}
318
/// Tokio `AsyncWrite` over the wrapped wasi output stream. The TLS session writes its records
/// through this impl.
impl AsyncWrite for WasiStreams {
    /// Writes as much of `buf` as the output stream currently permits.
    ///
    /// This follows the wasi `check_write` / `write` protocol. When `check_write` reports no
    /// capacity, the stream is moved into a `Pending` future that awaits its readiness. Once
    /// that future resolves, the write is retried. Returns `Ok(0)` if the stream is closed.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> {
        loop {
            match &mut self.as_mut().output {
                // `output` is only `Closed` briefly inside the `Ok(0)` arm below, where it is
                // immediately replaced with `Pending`. This state is never observed here.
                StreamState::Closed => unreachable!(),
                StreamState::Pending(future) => {
                    let value = ready!(future.as_mut().poll(cx));
                    self.as_mut().output = StreamState::Ready(value);
                }
                StreamState::Ready(output) => {
                    match output.check_write() {
                        Ok(0) => {
                            // No capacity. Move the stream into a future that waits for it to
                            // become writable, then loop around to poll that future.
                            let StreamState::Ready(mut output) =
                                mem::replace(&mut self.as_mut().output, StreamState::Closed)
                            else {
                                unreachable!()
                            };
                            self.as_mut().output = StreamState::Pending(Box::pin(async move {
                                output.ready().await;
                                output
                            }));
                        }
                        Ok(count) => {
                            // Never write more than the caller supplied or the stream permits.
                            let count = count.min(buf.len());
                            return match output.write(Bytes::copy_from_slice(&buf[..count])) {
                                Ok(()) => Poll::Ready(Ok(count)),
                                Err(StreamError::Closed) => Poll::Ready(Ok(0)),
                                Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
                                    Poll::Ready(Err(std::io::Error::other(e)))
                                }
                            };
                        }
                        Err(StreamError::Closed) => return Poll::Ready(Ok(0)),
                        Err(StreamError::LastOperationFailed(e) | StreamError::Trap(e)) => {
                            return Poll::Ready(Err(std::io::Error::other(e)))
                        }
                    };
                }
            }
        }
    }

    /// Implemented as a zero-length `poll_write`.
    ///
    /// It resolves once the output stream reports non-zero `check_write` capacity (waiting for
    /// readiness if needed), then issues an empty write.
    ///
    /// NOTE(review): this never calls the wasi stream's `flush`, so data buffered inside the
    /// wasi stream itself is not explicitly flushed — confirm this is acceptable.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.poll_write(cx, &[]).map(|v| v.map(drop))
    }

    /// Only flushes; the wrapped wasi output stream is not closed here.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), std::io::Error>> {
        self.poll_flush(cx)
    }
}
379
380impl AsyncRead for WasiStreams {
381    fn poll_read(
382        mut self: Pin<&mut Self>,
383        cx: &mut std::task::Context<'_>,
384        buf: &mut tokio::io::ReadBuf<'_>,
385    ) -> std::task::Poll<std::io::Result<()>> {
386        loop {
387            let stream = match &mut self.input {
388                StreamState::Ready(stream) => stream,
389                StreamState::Pending(fut) => {
390                    let stream = ready!(fut.as_mut().poll(cx));
391                    self.input = StreamState::Ready(stream);
392                    if let StreamState::Ready(stream) = &mut self.input {
393                        stream
394                    } else {
395                        unreachable!()
396                    }
397                }
398                StreamState::Closed => {
399                    return Poll::Ready(Ok(()));
400                }
401            };
402            match stream.read(buf.remaining()) {
403                Ok(bytes) if bytes.is_empty() => {
404                    let StreamState::Ready(mut stream) =
405                        std::mem::replace(&mut self.input, StreamState::Closed)
406                    else {
407                        unreachable!()
408                    };
409
410                    self.input = StreamState::Pending(Box::pin(async move {
411                        stream.ready().await;
412                        stream
413                    }));
414                }
415                Ok(bytes) => {
416                    buf.put_slice(&bytes);
417
418                    return Poll::Ready(Ok(()));
419                }
420                Err(StreamError::Closed) => {
421                    self.input = StreamState::Closed;
422                    return Poll::Ready(Ok(()));
423                }
424                Err(e) => {
425                    self.input = StreamState::Closed;
426                    return Poll::Ready(Err(std::io::Error::other(e)));
427                }
428            }
429        }
430    }
431}
432
/// The write half of a client TLS stream layered over [`WasiStreams`].
type TlsWriteHalf = tokio::io::WriteHalf<tokio_rustls::client::TlsStream<WasiStreams>>;

/// State machine that drives writes and shutdown on a [`TlsWriteHalf`].
struct TlsWriter {
    state: WriteState,
}

/// Lifecycle states of a [`TlsWriter`].
enum WriteState {
    /// Idle: the write half is available for a new write or for shutdown.
    Ready(TlsWriteHalf),
    /// A background task is writing; it hands the write half back when done.
    Writing(AbortOnDropJoinHandle<io::Result<TlsWriteHalf>>),
    /// A background task is shutting the write half down.
    Closing(AbortOnDropJoinHandle<io::Result<()>>),
    /// Closed. Also used briefly while the state is being replaced.
    Closed,
    /// A background task failed; the error is surfaced by `check_write`.
    Error(io::Error),
}
/// Write budget reported by `check_write` while the writer is idle (1 GiB).
const READY_SIZE: usize = 1024 * 1024 * 1024;
447
448impl TlsWriter {
449    fn new(stream: TlsWriteHalf) -> Self {
450        Self {
451            state: WriteState::Ready(stream),
452        }
453    }
454
455    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
456        let WriteState::Ready(_) = self.state else {
457            return Err(StreamError::Trap(anyhow::anyhow!(
458                "unpermitted: must call check_write first"
459            )));
460        };
461
462        if bytes.is_empty() {
463            return Ok(());
464        }
465
466        let WriteState::Ready(mut stream) = std::mem::replace(&mut self.state, WriteState::Closed)
467        else {
468            unreachable!()
469        };
470
471        self.state = WriteState::Writing(wasmtime_wasi::runtime::spawn(async move {
472            while !bytes.is_empty() {
473                match stream.write(&bytes).await {
474                    Ok(n) => {
475                        let _ = bytes.split_to(n);
476                    }
477                    Err(e) => return Err(e.into()),
478                }
479            }
480
481            Ok(stream)
482        }));
483
484        Ok(())
485    }
486
487    fn flush(&mut self) -> Result<(), StreamError> {
488        // `flush` is a no-op here, as we're not managing any internal buffer.
489        match self.state {
490            WriteState::Ready(_)
491            | WriteState::Writing(_)
492            | WriteState::Closing(_)
493            | WriteState::Error(_) => Ok(()),
494            WriteState::Closed => Err(StreamError::Closed),
495        }
496    }
497
498    fn check_write(&mut self) -> Result<usize, StreamError> {
499        match &mut self.state {
500            WriteState::Ready(_) => Ok(READY_SIZE),
501            WriteState::Writing(_) => Ok(0),
502            WriteState::Closing(_) => Ok(0),
503            WriteState::Closed => Err(StreamError::Closed),
504            WriteState::Error(_) => {
505                let WriteState::Error(e) = std::mem::replace(&mut self.state, WriteState::Closed)
506                else {
507                    unreachable!()
508                };
509
510                Err(StreamError::LastOperationFailed(e.into()))
511            }
512        }
513    }
514
515    fn close(&mut self) {
516        match std::mem::replace(&mut self.state, WriteState::Closed) {
517            // No write in progress, immediately shut down:
518            WriteState::Ready(mut stream) => {
519                self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move {
520                    stream.shutdown().await
521                }));
522            }
523
524            // Schedule the shutdown after the current write has finished:
525            WriteState::Writing(write) => {
526                self.state = WriteState::Closing(wasmtime_wasi::runtime::spawn(async move {
527                    let mut stream = write.await?;
528                    stream.shutdown().await
529                }));
530            }
531
532            WriteState::Closing(t) => {
533                self.state = WriteState::Closing(t);
534            }
535            WriteState::Closed | WriteState::Error(_) => {}
536        }
537    }
538
539    async fn cancel(&mut self) {
540        match std::mem::replace(&mut self.state, WriteState::Closed) {
541            WriteState::Writing(task) => _ = task.cancel().await,
542            WriteState::Closing(task) => _ = task.cancel().await,
543            _ => {}
544        }
545    }
546
547    async fn ready(&mut self) {
548        match &mut self.state {
549            WriteState::Writing(task) => {
550                self.state = match task.await {
551                    Ok(s) => WriteState::Ready(s),
552                    Err(e) => WriteState::Error(e),
553                }
554            }
555            WriteState::Closing(task) => {
556                self.state = match task.await {
557                    Ok(()) => WriteState::Closed,
558                    Err(e) => WriteState::Error(e),
559                }
560            }
561            _ => {}
562        }
563    }
564}
565
566#[derive(Clone)]
567struct AsyncTlsWriteStream(Arc<Mutex<TlsWriter>>);
568
569impl AsyncTlsWriteStream {
570    fn new(writer: TlsWriter) -> Self {
571        AsyncTlsWriteStream(Arc::new(Mutex::new(writer)))
572    }
573
574    fn close(&mut self) -> wasmtime::Result<()> {
575        try_lock_for_stream(&self.0)?.close();
576        Ok(())
577    }
578}
579
580#[async_trait]
581impl OutputStream for AsyncTlsWriteStream {
582    fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
583        try_lock_for_stream(&self.0)?.write(bytes)
584    }
585
586    fn flush(&mut self) -> Result<(), StreamError> {
587        try_lock_for_stream(&self.0)?.flush()
588    }
589
590    fn check_write(&mut self) -> Result<usize, StreamError> {
591        try_lock_for_stream(&self.0)?.check_write()
592    }
593
594    async fn cancel(&mut self) {
595        self.0.lock().await.cancel().await
596    }
597}
598
599#[async_trait]
600impl Pollable for AsyncTlsWriteStream {
601    async fn ready(&mut self) {
602        self.0.lock().await.ready().await
603    }
604}
605
606fn try_lock_for_stream<TlsWriter>(
607    mutex: &Mutex<TlsWriter>,
608) -> Result<tokio::sync::MutexGuard<'_, TlsWriter>, StreamError> {
609    mutex
610        .try_lock()
611        .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    /// Dropping a `ready()` future mid-poll must not lose the pending handshake. A later
    /// `ready()` call should still drive it to completion.
    #[tokio::test]
    async fn test_future_client_streams_ready_can_be_canceled() {
        let (tx1, rx1) = oneshot::channel::<()>();

        let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move {
            rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled"))
        })));

        let mut fut = future_streams.ready();

        let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
        assert!(fut.as_mut().poll(&mut cx).is_pending());

        // Cancel the readiness check.
        drop(fut);

        assert!(
            !matches!(future_streams.0, StreamState::Closed),
            "First future should be in Pending/ready state"
        );

        // Make it ready and wait for it to progress.
        tx1.send(()).unwrap();
        future_streams.ready().await;

        assert!(
            matches!(future_streams.0, StreamState::Ready(Ok(()))),
            "First future should be in Ready(Ok) state"
        );
    }
}