wasmtime_wasi/sockets/
mod.rs

1use core::future::Future;
2use core::ops::Deref;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6use wasmtime::component::{HasData, ResourceTable};
7
8mod tcp;
9mod udp;
10pub(crate) mod util;
11
12#[cfg(feature = "p3")]
13pub(crate) use tcp::NonInheritedOptions;
14pub use tcp::TcpSocket;
15pub use udp::UdpSocket;
16
/// A helper struct which implements [`HasData`] for the `wasi:sockets` APIs.
///
/// This is intended to be used as the `D` type parameter when calling
/// `add_to_linker` functions directly, such as
/// [`wasmtime_wasi::p2::bindings::sockets::tcp::add_to_linker`]. See
/// [`HasData`] for more information about the type parameter's purpose.
///
/// When using this type, your store's data does not need to implement the
/// [`WasiSocketsView`] trait. Instead, the closure passed to `add_to_linker`
/// builds the [`WasiSocketsCtxView`] itself, as in the example below.
///
/// # Examples
///
/// ```
/// use wasmtime::component::{Linker, ResourceTable};
/// use wasmtime::{Engine, Result, Config};
/// use wasmtime_wasi::sockets::*;
///
/// struct MyStoreState {
///     table: ResourceTable,
///     sockets: WasiSocketsCtx,
/// }
///
/// fn main() -> Result<()> {
///     let mut config = Config::new();
///     config.async_support(true);
///     let engine = Engine::new(&config)?;
///     let mut linker = Linker::new(&engine);
///
///     wasmtime_wasi::p2::bindings::sockets::tcp::add_to_linker::<MyStoreState, WasiSockets>(
///         &mut linker,
///         |state| WasiSocketsCtxView {
///             ctx: &mut state.sockets,
///             table: &mut state.table,
///         },
///     )?;
///     Ok(())
/// }
/// ```
pub struct WasiSockets;
56
// The data handed to `wasi:sockets` host code is a `WasiSocketsCtxView`,
// which pairs the sockets context with the resource table.
impl HasData for WasiSockets {
    type Data<'a> = WasiSocketsCtxView<'a>;
}
60
/// Default backlog used when a TCP socket starts listening.
///
/// The value is the one used by the Rust standard library.
pub(crate) const DEFAULT_TCP_BACKLOG: u32 = 128;

/// Theoretical maximum size of a UDP datagram, in bytes (`u16::MAX`).
///
/// The real limit is lower. For simplicity, this value ignores overhead
/// such as the transport layer. In practice, datagrams are typically
/// smaller than 1500 bytes.
pub(crate) const MAX_UDP_DATAGRAM_SIZE: usize = u16::MAX as usize;
68
/// Context for the `wasi:sockets` APIs.
///
/// Holds the policy that decides which socket addresses may be used, and
/// which kinds of network use (IP name lookup, UDP, TCP) are allowed.
///
/// The [`Default`] value rejects every socket address and disables IP name
/// lookup, while leaving UDP and TCP enabled.
#[derive(Clone, Default)]
pub struct WasiSocketsCtx {
    // Callback consulted for every socket address before it is used.
    pub(crate) socket_addr_check: SocketAddrCheck,
    // Per-protocol on/off switches.
    pub(crate) allowed_network_uses: AllowedNetworkUses,
}
74
/// Mutable borrows of the store state that the `wasi:sockets` APIs operate on.
pub struct WasiSocketsCtxView<'a> {
    /// The sockets context (address check and allowed network uses).
    pub ctx: &'a mut WasiSocketsCtx,
    /// The resource table used by the sockets implementation.
    pub table: &'a mut ResourceTable,
}
79
/// Implemented by store data that can expose its `wasi:sockets` state.
///
/// Implementors must be [`Send`].
pub trait WasiSocketsView: Send {
    /// Returns a view borrowing this value's [`WasiSocketsCtx`] and [`ResourceTable`].
    fn sockets(&mut self) -> WasiSocketsCtxView<'_>;
}
83
/// Switches controlling which kinds of network use are permitted.
#[derive(Clone, Copy)]
pub(crate) struct AllowedNetworkUses {
    /// Whether IP name lookup is permitted.
    pub(crate) ip_name_lookup: bool,
    /// Whether UDP sockets may be used.
    pub(crate) udp: bool,
    /// Whether TCP sockets may be used.
    pub(crate) tcp: bool,
}

impl Default for AllowedNetworkUses {
    /// TCP and UDP are permitted by default; IP name lookup is not.
    fn default() -> Self {
        Self {
            tcp: true,
            udp: true,
            ip_name_lookup: false,
        }
    }
}

impl AllowedNetworkUses {
    /// Succeeds when UDP is permitted.
    ///
    /// # Errors
    ///
    /// Returns a [`std::io::ErrorKind::PermissionDenied`] error otherwise.
    pub(crate) fn check_allowed_udp(&self) -> std::io::Result<()> {
        Self::require(self.udp, "UDP is not allowed")
    }

    /// Succeeds when TCP is permitted.
    ///
    /// # Errors
    ///
    /// Returns a [`std::io::ErrorKind::PermissionDenied`] error otherwise.
    pub(crate) fn check_allowed_tcp(&self) -> std::io::Result<()> {
        Self::require(self.tcp, "TCP is not allowed")
    }

    /// Turns a permission flag into `Ok(())`, or into a permission-denied
    /// error carrying `message`.
    fn require(allowed: bool, message: &'static str) -> std::io::Result<()> {
        if allowed {
            Ok(())
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                message,
            ))
        }
    }
}
124
125/// A check that will be called for each socket address that is used of whether the address is permitted.
126#[derive(Clone)]
127pub(crate) struct SocketAddrCheck(
128    Arc<
129        dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
130            + Send
131            + Sync,
132    >,
133);
134
135impl SocketAddrCheck {
136    /// A check that will be called for each socket address that is used.
137    ///
138    /// Returning `true` will permit socket connections to the `SocketAddr`,
139    /// while returning `false` will reject the connection.
140    pub(crate) fn new(
141        f: impl Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
142        + Send
143        + Sync
144        + 'static,
145    ) -> Self {
146        Self(Arc::new(f))
147    }
148
149    pub(crate) async fn check(
150        &self,
151        addr: SocketAddr,
152        reason: SocketAddrUse,
153    ) -> std::io::Result<()> {
154        if (self.0)(addr, reason).await {
155            Ok(())
156        } else {
157            Err(std::io::Error::new(
158                std::io::ErrorKind::PermissionDenied,
159                "An address was not permitted by the socket address check.",
160            ))
161        }
162    }
163}
164
165impl Deref for SocketAddrCheck {
166    type Target = dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
167        + Send
168        + Sync;
169
170    fn deref(&self) -> &Self::Target {
171        self.0.as_ref()
172    }
173}
174
175impl Default for SocketAddrCheck {
176    fn default() -> Self {
177        Self(Arc::new(|_, _| Box::pin(async { false })))
178    }
179}
180
/// What a socket address is being used for.
///
/// This is passed to the socket address check alongside the address, so a
/// policy can permit an address for some uses (for example, connecting)
/// while rejecting it for others. Values can be compared and hashed, so
/// policies can store them in sets or maps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SocketAddrUse {
    /// Binding a TCP socket.
    TcpBind,
    /// Connecting a TCP socket.
    TcpConnect,
    /// Binding a UDP socket.
    UdpBind,
    /// Connecting a UDP socket.
    UdpConnect,
    /// Sending a datagram on a non-connected UDP socket.
    UdpOutgoingDatagram,
}
195
/// IP address family of a socket.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum SocketAddressFamily {
    /// IPv4.
    Ipv4,
    /// IPv6.
    Ipv6,
}