wasmtime_wasi/sockets/
mod.rs

1use core::future::Future;
2use core::ops::Deref;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6use wasmtime::component::{HasData, ResourceTable};
7
8mod tcp;
9mod udp;
10pub(crate) mod util;
11
12#[cfg(feature = "p3")]
13pub(crate) use tcp::NonInheritedOptions;
14pub use tcp::TcpSocket;
15pub use udp::UdpSocket;
16
17pub(crate) struct WasiSockets;
18
19impl HasData for WasiSockets {
20    type Data<'a> = WasiSocketsCtxView<'a>;
21}
22
/// Default TCP listen backlog: the maximum length of the queue of pending,
/// not-yet-accepted connections handed to `listen(2)`.
///
/// The value mirrors the backlog Rust's standard library hard-codes for
/// `std::net::TcpListener::bind` (128 on most targets).
/// NOTE(review): presumably used when no explicit backlog has been
/// configured for a socket — confirm at the call sites.
pub(crate) const DEFAULT_TCP_BACKLOG: u32 = 128;
25
/// Upper bound, in bytes, on the size of a single UDP datagram.
///
/// UDP's 16-bit length field limits a datagram (8-byte UDP header plus
/// payload) to `u16::MAX` (65,535) bytes. The largest usable payload is
/// therefore smaller. For example, it is 65,507 bytes over IPv4 once the
/// UDP and IPv4 headers are subtracted. For simplicity, those header
/// overheads are deliberately not accounted for here. In practice,
/// datagrams are typically less than 1500 bytes.
pub(crate) const MAX_UDP_DATAGRAM_SIZE: usize = u16::MAX as usize;
30
31#[derive(Clone, Default)]
32pub struct WasiSocketsCtx {
33    pub(crate) socket_addr_check: SocketAddrCheck,
34    pub(crate) allowed_network_uses: AllowedNetworkUses,
35}
36
37pub struct WasiSocketsCtxView<'a> {
38    pub ctx: &'a mut WasiSocketsCtx,
39    pub table: &'a mut ResourceTable,
40}
41
42pub trait WasiSocketsView: Send {
43    fn sockets(&mut self) -> WasiSocketsCtxView<'_>;
44}
45
/// Coarse-grained switches that enable or disable whole categories of
/// network access.
#[derive(Clone, Copy)]
pub(crate) struct AllowedNetworkUses {
    /// Whether IP name lookup is permitted (disabled by default).
    pub(crate) ip_name_lookup: bool,
    /// Whether UDP sockets may be used (enabled by default).
    pub(crate) udp: bool,
    /// Whether TCP sockets may be used (enabled by default).
    pub(crate) tcp: bool,
}

impl Default for AllowedNetworkUses {
    /// Enables TCP and UDP while leaving IP name lookup switched off.
    fn default() -> Self {
        Self {
            tcp: true,
            udp: true,
            ip_name_lookup: false,
        }
    }
}

impl AllowedNetworkUses {
    /// Succeeds when UDP is enabled.
    ///
    /// # Errors
    /// Returns a [`std::io::ErrorKind::PermissionDenied`] error when UDP has
    /// been disabled.
    pub(crate) fn check_allowed_udp(&self) -> std::io::Result<()> {
        if self.udp {
            Ok(())
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "UDP is not allowed",
            ))
        }
    }

    /// Succeeds when TCP is enabled.
    ///
    /// # Errors
    /// Returns a [`std::io::ErrorKind::PermissionDenied`] error when TCP has
    /// been disabled.
    pub(crate) fn check_allowed_tcp(&self) -> std::io::Result<()> {
        if self.tcp {
            Ok(())
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "TCP is not allowed",
            ))
        }
    }
}
86
87/// A check that will be called for each socket address that is used of whether the address is permitted.
88#[derive(Clone)]
89pub(crate) struct SocketAddrCheck(
90    Arc<
91        dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
92            + Send
93            + Sync,
94    >,
95);
96
97impl SocketAddrCheck {
98    /// A check that will be called for each socket address that is used.
99    ///
100    /// Returning `true` will permit socket connections to the `SocketAddr`,
101    /// while returning `false` will reject the connection.
102    pub(crate) fn new(
103        f: impl Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
104        + Send
105        + Sync
106        + 'static,
107    ) -> Self {
108        Self(Arc::new(f))
109    }
110
111    pub(crate) async fn check(
112        &self,
113        addr: SocketAddr,
114        reason: SocketAddrUse,
115    ) -> std::io::Result<()> {
116        if (self.0)(addr, reason).await {
117            Ok(())
118        } else {
119            Err(std::io::Error::new(
120                std::io::ErrorKind::PermissionDenied,
121                "An address was not permitted by the socket address check.",
122            ))
123        }
124    }
125}
126
127impl Deref for SocketAddrCheck {
128    type Target = dyn Fn(SocketAddr, SocketAddrUse) -> Pin<Box<dyn Future<Output = bool> + Send + Sync>>
129        + Send
130        + Sync;
131
132    fn deref(&self) -> &Self::Target {
133        self.0.as_ref()
134    }
135}
136
137impl Default for SocketAddrCheck {
138    fn default() -> Self {
139        Self(Arc::new(|_, _| Box::pin(async { false })))
140    }
141}
142
/// The purpose a socket address is being used for.
///
/// This is passed to the socket address check alongside the address, so a
/// policy can distinguish, for example, binding from connecting. It
/// implements `PartialEq`, `Eq` and `Hash`, so policies can compare values
/// directly or keep them in sets and maps.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SocketAddrUse {
    /// Binding a TCP socket.
    TcpBind,
    /// Connecting a TCP socket.
    TcpConnect,
    /// Binding a UDP socket.
    UdpBind,
    /// Connecting a UDP socket.
    UdpConnect,
    /// Sending a datagram on a non-connected UDP socket.
    UdpOutgoingDatagram,
}
157
/// IP address family of a socket.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum SocketAddressFamily {
    /// IPv4.
    Ipv4,
    /// IPv6.
    Ipv6,
}