wasmtime_wasi/sockets/
udp.rs

1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4    receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5    set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use cap_net_ext::AddressFamily;
9use io_lifetimes::AsSocketlike as _;
10use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
11use rustix::io::Errno;
12use rustix::net::connect;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tracing::debug;
16
/// The state of a UDP socket.
///
/// This tracks where a socket is in its bind/connect lifecycle. UDP is
/// connectionless, so "connected" here only means that a default peer
/// address has been associated with the socket.
enum UdpState {
    /// The initial state for a newly-created socket: no local address has
    /// been assigned yet and no peer is associated.
    Default,

    /// A `bind` operation has started but has yet to complete with
    /// `finish_bind`.
    BindStarted,

    /// The socket has a local address but no associated peer. Reached via
    /// `finish_bind`, via an implicit bind performed before the first send,
    /// or by disconnecting a connected socket.
    Bound,

    /// The socket is "connected" to the contained peer address.
    ///
    /// With the p3 API, this address is the only permitted send destination
    /// and is reported as the source of received datagrams.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
40
41/// A host UDP socket, plus associated bookkeeping.
42///
43/// The inner state is wrapped in an Arc because the same underlying socket is
44/// used for implementing the stream types.
45pub struct UdpSocket {
46    socket: Arc<tokio::net::UdpSocket>,
47
48    /// The current state in the bind/connect progression.
49    udp_state: UdpState,
50
51    /// Socket address family.
52    family: SocketAddressFamily,
53
54    /// If set, use this custom check for addrs, otherwise use what's in
55    /// `WasiSocketsCtx`.
56    socket_addr_check: Option<SocketAddrCheck>,
57}
58
59impl UdpSocket {
60    /// Create a new socket in the given family.
61    pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {
62        cx.allowed_network_uses.check_allowed_udp()?;
63
64        // Delegate socket creation to cap_net_ext. They handle a couple of things for us:
65        // - On Windows: call WSAStartup if not done before.
66        // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,
67        //   or afterwards using ioctl or fcntl. Exact method depends on the platform.
68
69        let fd = udp_socket(family)?;
70
71        let socket_address_family = match family {
72            AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
73            AddressFamily::Ipv6 => {
74                rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
75                SocketAddressFamily::Ipv6
76            }
77        };
78
79        let socket = with_ambient_tokio_runtime(|| {
80            tokio::net::UdpSocket::try_from(unsafe {
81                std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
82            })
83        })?;
84
85        Ok(Self {
86            socket: Arc::new(socket),
87            udp_state: UdpState::Default,
88            family: socket_address_family,
89            socket_addr_check: None,
90        })
91    }
92
93    pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
94        if !matches!(self.udp_state, UdpState::Default) {
95            return Err(ErrorCode::InvalidState);
96        }
97        if !is_valid_address_family(addr.ip(), self.family) {
98            return Err(ErrorCode::InvalidArgument);
99        }
100        udp_bind(&self.socket, addr)?;
101        self.udp_state = UdpState::BindStarted;
102        Ok(())
103    }
104
105    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
106        match self.udp_state {
107            UdpState::BindStarted => {
108                self.udp_state = UdpState::Bound;
109                Ok(())
110            }
111            _ => Err(ErrorCode::NotInProgress),
112        }
113    }
114
115    pub(crate) fn is_connected(&self) -> bool {
116        matches!(self.udp_state, UdpState::Connected(..))
117    }
118
119    pub(crate) fn is_bound(&self) -> bool {
120        matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
121    }
122
123    pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
124        if !self.is_connected() {
125            return Err(ErrorCode::InvalidState);
126        }
127        udp_disconnect(&self.socket)?;
128        self.udp_state = UdpState::Bound;
129        Ok(())
130    }
131
132    /// Connect using p2 semantics. (no implicit bind)
133    pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
134        match self.udp_state {
135            UdpState::Bound | UdpState::Connected(_) => {}
136            _ => return Err(ErrorCode::InvalidState),
137        }
138
139        self.connect_common(addr)
140    }
141
142    /// Connect using p3 semantics. (with implicit bind)
143    #[cfg(feature = "p3")]
144    pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
145        match self.udp_state {
146            UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
147            _ => return Err(ErrorCode::InvalidState),
148        }
149
150        self.connect_common(addr)
151    }
152
153    fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
154        if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
155            return Err(ErrorCode::InvalidArgument);
156        }
157
158        // We disconnect & (re)connect in two distinct steps for two reasons:
159        // - To leave our socket instance in a consistent state in case the
160        //   connect fails.
161        // - When reconnecting to a different address, Linux sometimes fails
162        //   if there isn't a disconnect in between.
163
164        // Step #1: Disconnect
165        if let UdpState::Connected(..) = self.udp_state {
166            udp_disconnect(&self.socket)?;
167            self.udp_state = UdpState::Bound;
168        }
169        // Step #2: (Re)connect
170        connect(&self.socket, &addr).map_err(|error| match error {
171            Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.
172            Errno::INPROGRESS => {
173                debug!("UDP connect returned EINPROGRESS, which should never happen");
174                ErrorCode::Unknown
175            }
176            err => err.into(),
177        })?;
178        self.udp_state = UdpState::Connected(addr);
179        Ok(())
180    }
181
182    /// Send data using p3 semantics. (with implicit bind)
183    #[cfg(feature = "p3")]
184    pub(crate) fn send_p3(
185        &mut self,
186        buf: Vec<u8>,
187        addr: Option<SocketAddr>,
188    ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
189        enum Mode {
190            Send(Arc<tokio::net::UdpSocket>),
191            SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
192        }
193        let mut socket = match (&self.udp_state, addr) {
194            (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
195            (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
196            (UdpState::Default | UdpState::Bound, Some(addr)) => {
197                Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
198            }
199            (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
200            (UdpState::Connected(caddr), Some(addr)) => {
201                if addr == *caddr {
202                    Ok(Mode::Send(Arc::clone(&self.socket)))
203                } else {
204                    Err(ErrorCode::InvalidArgument)
205                }
206            }
207        };
208
209        // Send may be called without a prior bind or connect. In that case, the
210        // first send will automatically assign a free local port. This is
211        // normally performed by the OS itself. However, if the `send` syscall
212        // failed, we can't reliably know which state the socket is in at the
213        // kernel level and our own `udp_state` bookkeeping may have become
214        // out-of-sync.
215        // To avoid that, we perform the implicit bind ourselves here. This way,
216        // we always leave the socket in a consistent state: Bound.
217        if socket.is_ok()
218            && let UdpState::Default = self.udp_state
219        {
220            let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
221            match udp_bind(&self.socket, implicit_addr) {
222                Ok(()) => {
223                    self.udp_state = UdpState::Bound;
224                }
225                Err(e) => {
226                    socket = Err(e);
227                }
228            }
229        }
230
231        async move {
232            match socket? {
233                Mode::Send(socket) => send(&socket, &buf).await,
234                Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
235            }
236        }
237    }
238
239    /// Receive data using p3 semantics.
240    #[cfg(feature = "p3")]
241    pub(crate) fn receive_p3(
242        &self,
243    ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
244        enum Mode {
245            Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
246            RecvFrom(Arc<tokio::net::UdpSocket>),
247        }
248        let socket = match self.udp_state {
249            UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
250            UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
251            UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
252        };
253        async move {
254            let socket = socket?;
255            let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
256            let (n, addr) = match socket {
257                Mode::Recv(socket, addr) => {
258                    let n = socket.recv(&mut buf).await?;
259                    (n, addr)
260                }
261                Mode::RecvFrom(socket) => {
262                    let (n, addr) = socket.recv_from(&mut buf).await?;
263                    (n, addr)
264                }
265            };
266            buf.truncate(n);
267            Ok((buf, addr))
268        }
269    }
270
271    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
272        if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
273            return Err(ErrorCode::InvalidState);
274        }
275        let addr = self
276            .socket
277            .as_socketlike_view::<std::net::UdpSocket>()
278            .local_addr()?;
279        Ok(addr)
280    }
281
282    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
283        if !matches!(self.udp_state, UdpState::Connected(..)) {
284            return Err(ErrorCode::InvalidState);
285        }
286        let addr = self
287            .socket
288            .as_socketlike_view::<std::net::UdpSocket>()
289            .peer_addr()?;
290        Ok(addr)
291    }
292
293    pub(crate) fn address_family(&self) -> SocketAddressFamily {
294        self.family
295    }
296
297    pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
298        let n = get_unicast_hop_limit(&self.socket, self.family)?;
299        Ok(n)
300    }
301
302    pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
303        set_unicast_hop_limit(&self.socket, self.family, value)?;
304        Ok(())
305    }
306
307    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
308        let n = receive_buffer_size(&self.socket)?;
309        Ok(n)
310    }
311
312    pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
313        set_receive_buffer_size(&self.socket, value)?;
314        Ok(())
315    }
316
317    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
318        let n = send_buffer_size(&self.socket)?;
319        Ok(n)
320    }
321
322    pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
323        set_send_buffer_size(&self.socket, value)?;
324        Ok(())
325    }
326
327    pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
328        &self.socket
329    }
330
331    pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
332        self.socket_addr_check.as_ref()
333    }
334
335    pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
336        self.socket_addr_check = check;
337    }
338}
339
/// Send `buf` as a single datagram to the socket's connected peer.
///
/// A send that does not transmit the whole buffer is reported as
/// `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let written = socket.send(buf).await?;
    // Per the Rust stdlib docs:
    // > Note that the operating system may refuse buffers larger than 65507.
    // > However, partial writes are not possible until buffer sizes above `i32::MAX`.
    //
    // For example, on Windows, at most `i32::MAX` bytes will be written.
    // A short write is therefore unexpected and is surfaced as `Unknown`.
    let fully_sent = written == buf.len();
    fully_sent.then_some(()).ok_or(ErrorCode::Unknown)
}
354
/// Send `buf` as a single datagram to `addr`.
///
/// As with `send`, transmitting fewer bytes than `buf.len()` is reported as
/// `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let written = socket.send_to(buf, addr).await?;
    // Partial datagram writes are treated as failures; see [`send`] for why.
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}