Skip to main content

wasmtime_wasi/sockets/
udp.rs

1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4    receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5    set_unicast_hop_limit, udp_bind, udp_connect, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use cap_net_ext::AddressFamily;
9use io_lifetimes::AsSocketlike as _;
10use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
11use rustix::io::Errno;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use tracing::debug;
15
16/// The state of a UDP socket.
17///
18/// This represents the various states a socket can be in during the
19/// activities of binding, and connecting.
20enum UdpState {
21    /// The initial state for a newly-created socket.
22    Default,
23
24    /// A `bind` operation has started but has yet to complete with
25    /// `finish_bind`.
26    BindStarted,
27
28    /// Binding finished via `finish_bind`. The socket has an address but
29    /// is not yet listening for connections.
30    Bound,
31
32    /// The socket is "connected" to a peer address.
33    #[cfg_attr(
34        not(feature = "p3"),
35        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
36    )]
37    Connected(SocketAddr),
38}
39
40/// A host UDP socket, plus associated bookkeeping.
41///
42/// The inner state is wrapped in an Arc because the same underlying socket is
43/// used for implementing the stream types.
44pub struct UdpSocket {
45    socket: Arc<tokio::net::UdpSocket>,
46
47    /// The current state in the bind/connect progression.
48    udp_state: UdpState,
49
50    /// Socket address family.
51    family: SocketAddressFamily,
52
53    /// If set, use this custom check for addrs, otherwise use what's in
54    /// `WasiSocketsCtx`.
55    socket_addr_check: Option<SocketAddrCheck>,
56}
57
58impl UdpSocket {
59    /// Create a new socket in the given family.
60    pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {
61        cx.allowed_network_uses.check_allowed_udp()?;
62
63        // Delegate socket creation to cap_net_ext. They handle a couple of things for us:
64        // - On Windows: call WSAStartup if not done before.
65        // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,
66        //   or afterwards using ioctl or fcntl. Exact method depends on the platform.
67
68        let fd = udp_socket(family)?;
69
70        let socket_address_family = match family {
71            AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
72            AddressFamily::Ipv6 => {
73                rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
74                SocketAddressFamily::Ipv6
75            }
76        };
77
78        let socket = with_ambient_tokio_runtime(|| {
79            tokio::net::UdpSocket::try_from(unsafe {
80                std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
81            })
82        })?;
83
84        Ok(Self {
85            socket: Arc::new(socket),
86            udp_state: UdpState::Default,
87            family: socket_address_family,
88            socket_addr_check: None,
89        })
90    }
91
92    pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
93        if !matches!(self.udp_state, UdpState::Default) {
94            return Err(ErrorCode::InvalidState);
95        }
96        if !is_valid_address_family(addr.ip(), self.family) {
97            return Err(ErrorCode::InvalidArgument);
98        }
99        udp_bind(&self.socket, addr)?;
100        self.udp_state = UdpState::BindStarted;
101        Ok(())
102    }
103
104    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
105        match self.udp_state {
106            UdpState::BindStarted => {
107                self.udp_state = UdpState::Bound;
108                Ok(())
109            }
110            _ => Err(ErrorCode::NotInProgress),
111        }
112    }
113
114    pub(crate) fn is_connected(&self) -> bool {
115        matches!(self.udp_state, UdpState::Connected(..))
116    }
117
118    pub(crate) fn is_bound(&self) -> bool {
119        matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
120    }
121
122    pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
123        if !self.is_connected() {
124            return Err(ErrorCode::InvalidState);
125        }
126        udp_disconnect(&self.socket)?;
127        self.udp_state = UdpState::Bound;
128        Ok(())
129    }
130
131    /// Connect using p2 semantics. (no implicit bind)
132    pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
133        match self.udp_state {
134            UdpState::Bound | UdpState::Connected(_) => {}
135            _ => return Err(ErrorCode::InvalidState),
136        }
137
138        self.connect_common(addr)
139    }
140
141    /// Connect using p3 semantics. (with implicit bind)
142    #[cfg(feature = "p3")]
143    pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
144        match self.udp_state {
145            UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
146            _ => return Err(ErrorCode::InvalidState),
147        }
148
149        self.connect_common(addr)
150    }
151
152    fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
153        if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
154            return Err(ErrorCode::InvalidArgument);
155        }
156
157        match udp_connect(&self.socket, addr) {
158            Ok(()) => {
159                self.udp_state = UdpState::Connected(addr);
160                Ok(())
161            }
162            Err(e) => {
163                // Revert to a consistent state:
164                _ = udp_disconnect(&self.socket);
165                self.udp_state = UdpState::Bound;
166
167                Err(match e {
168                    Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.
169                    Errno::INPROGRESS => {
170                        debug!("UDP connect returned EINPROGRESS, which should never happen");
171                        ErrorCode::Unknown
172                    }
173                    err => err.into(),
174                })
175            }
176        }
177    }
178
179    /// Send data using p3 semantics. (with implicit bind)
180    #[cfg(feature = "p3")]
181    pub(crate) fn send_p3(
182        &mut self,
183        buf: Vec<u8>,
184        addr: Option<SocketAddr>,
185    ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
186        enum Mode {
187            Send(Arc<tokio::net::UdpSocket>),
188            SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
189        }
190        let mut socket = match (&self.udp_state, addr) {
191            (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
192            (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
193            (UdpState::Default | UdpState::Bound, Some(addr)) => {
194                Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
195            }
196            (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
197            (UdpState::Connected(caddr), Some(addr)) => {
198                if addr == *caddr {
199                    Ok(Mode::Send(Arc::clone(&self.socket)))
200                } else {
201                    Err(ErrorCode::InvalidArgument)
202                }
203            }
204        };
205
206        // Send may be called without a prior bind or connect. In that case, the
207        // first send will automatically assign a free local port. This is
208        // normally performed by the OS itself. However, if the `send` syscall
209        // failed, we can't reliably know which state the socket is in at the
210        // kernel level and our own `udp_state` bookkeeping may have become
211        // out-of-sync.
212        // To avoid that, we perform the implicit bind ourselves here. This way,
213        // we always leave the socket in a consistent state: Bound.
214        if socket.is_ok()
215            && let UdpState::Default = self.udp_state
216        {
217            let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
218            match udp_bind(&self.socket, implicit_addr) {
219                Ok(()) => {
220                    self.udp_state = UdpState::Bound;
221                }
222                Err(e) => {
223                    socket = Err(e);
224                }
225            }
226        }
227
228        async move {
229            match socket? {
230                Mode::Send(socket) => send(&socket, &buf).await,
231                Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
232            }
233        }
234    }
235
236    /// Receive data using p3 semantics.
237    #[cfg(feature = "p3")]
238    pub(crate) fn receive_p3(
239        &self,
240    ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
241        enum Mode {
242            Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
243            RecvFrom(Arc<tokio::net::UdpSocket>),
244        }
245        let socket = match self.udp_state {
246            UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
247            UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
248            UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
249        };
250        async move {
251            let socket = socket?;
252            let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
253            let (n, addr) = match socket {
254                Mode::Recv(socket, addr) => {
255                    let n = socket.recv(&mut buf).await?;
256                    (n, addr)
257                }
258                Mode::RecvFrom(socket) => {
259                    let (n, addr) = socket.recv_from(&mut buf).await?;
260                    (n, addr)
261                }
262            };
263            buf.truncate(n);
264            Ok((buf, addr))
265        }
266    }
267
268    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
269        if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
270            return Err(ErrorCode::InvalidState);
271        }
272        let addr = self
273            .socket
274            .as_socketlike_view::<std::net::UdpSocket>()
275            .local_addr()?;
276        Ok(addr)
277    }
278
279    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
280        if !matches!(self.udp_state, UdpState::Connected(..)) {
281            return Err(ErrorCode::InvalidState);
282        }
283        let addr = self
284            .socket
285            .as_socketlike_view::<std::net::UdpSocket>()
286            .peer_addr()?;
287        Ok(addr)
288    }
289
290    pub(crate) fn address_family(&self) -> SocketAddressFamily {
291        self.family
292    }
293
294    pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
295        let n = get_unicast_hop_limit(&self.socket, self.family)?;
296        Ok(n)
297    }
298
299    pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
300        set_unicast_hop_limit(&self.socket, self.family, value)?;
301        Ok(())
302    }
303
304    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
305        let n = receive_buffer_size(&self.socket)?;
306        Ok(n)
307    }
308
309    pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
310        set_receive_buffer_size(&self.socket, value)?;
311        Ok(())
312    }
313
314    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
315        let n = send_buffer_size(&self.socket)?;
316        Ok(n)
317    }
318
319    pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
320        set_send_buffer_size(&self.socket, value)?;
321        Ok(())
322    }
323
324    pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
325        &self.socket
326    }
327
328    pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
329        self.socket_addr_check.as_ref()
330    }
331
332    pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
333        self.socket_addr_check = check;
334    }
335}
336
337#[cfg(feature = "p3")]
338async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
339    let n = socket.send(buf).await?;
340    // From Rust stdlib docs:
341    // > Note that the operating system may refuse buffers larger than 65507.
342    // > However, partial writes are not possible until buffer sizes above `i32::MAX`.
343    //
344    // For example, on Windows, at most `i32::MAX` bytes will be written
345    if n != buf.len() {
346        Err(ErrorCode::Unknown)
347    } else {
348        Ok(())
349    }
350}
351
352#[cfg(feature = "p3")]
353async fn send_to(
354    socket: &tokio::net::UdpSocket,
355    buf: &[u8],
356    addr: SocketAddr,
357) -> Result<(), ErrorCode> {
358    let n = socket.send_to(buf, addr).await?;
359    // See [`send`] documentation
360    if n != buf.len() {
361        Err(ErrorCode::Unknown)
362    } else {
363        Ok(())
364    }
365}