Skip to main content

wasmtime_wasi/sockets/
udp.rs

1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4    receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5    set_unicast_hop_limit, udp_bind, udp_connect, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use io_lifetimes::AsSocketlike as _;
9use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
10use rustix::io::Errno;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tracing::debug;
14
/// The state of a UDP socket.
///
/// This tracks where a socket is in the bind/connect progression. Only the
/// `Bound` and `Connected` states have a local address that can be queried.
enum UdpState {
    /// The initial state for a newly-created socket: not yet bound.
    Default,

    /// A `bind` operation has been issued but has yet to be completed with
    /// `finish_bind`.
    BindStarted,

    /// The socket has a local address but no connected peer. This state is
    /// reached by `bind` + `finish_bind`, by an implicit bind, by
    /// `disconnect`, or after a failed connect attempt.
    Bound,

    /// The socket is "connected" to the contained peer address. Under p3
    /// semantics this address is the implicit destination for sends and is
    /// reported as the source of received datagrams.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
38
39/// A host UDP socket, plus associated bookkeeping.
40///
41/// The inner state is wrapped in an Arc because the same underlying socket is
42/// used for implementing the stream types.
43pub struct UdpSocket {
44    socket: Arc<tokio::net::UdpSocket>,
45
46    /// The current state in the bind/connect progression.
47    udp_state: UdpState,
48
49    /// Socket address family.
50    family: SocketAddressFamily,
51
52    /// If set, use this custom check for addrs, otherwise use what's in
53    /// `WasiSocketsCtx`.
54    socket_addr_check: Option<SocketAddrCheck>,
55}
56
57impl UdpSocket {
58    /// Create a new socket in the given family.
59    pub(crate) fn new(cx: &WasiSocketsCtx, family: SocketAddressFamily) -> Result<Self, ErrorCode> {
60        cx.allowed_network_uses.check_allowed_udp()?;
61
62        let fd = udp_socket(family)?;
63
64        if family == SocketAddressFamily::Ipv6 {
65            rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
66        }
67
68        let socket = with_ambient_tokio_runtime(|| {
69            tokio::net::UdpSocket::try_from(unsafe {
70                std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
71            })
72        })?;
73
74        Ok(Self {
75            socket: Arc::new(socket),
76            udp_state: UdpState::Default,
77            family,
78            socket_addr_check: None,
79        })
80    }
81
82    pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
83        if !matches!(self.udp_state, UdpState::Default) {
84            return Err(ErrorCode::InvalidState);
85        }
86        if !is_valid_address_family(addr.ip(), self.family) {
87            return Err(ErrorCode::InvalidArgument);
88        }
89        udp_bind(&self.socket, addr)?;
90        self.udp_state = UdpState::BindStarted;
91        Ok(())
92    }
93
94    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
95        match self.udp_state {
96            UdpState::BindStarted => {
97                self.udp_state = UdpState::Bound;
98                Ok(())
99            }
100            _ => Err(ErrorCode::NotInProgress),
101        }
102    }
103
104    pub(crate) fn is_connected(&self) -> bool {
105        matches!(self.udp_state, UdpState::Connected(..))
106    }
107
108    pub(crate) fn is_bound(&self) -> bool {
109        matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
110    }
111
112    pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
113        if !self.is_connected() {
114            return Err(ErrorCode::InvalidState);
115        }
116        udp_disconnect(&self.socket)?;
117        self.udp_state = UdpState::Bound;
118        Ok(())
119    }
120
121    /// Connect using p2 semantics. (no implicit bind)
122    pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
123        match self.udp_state {
124            UdpState::Bound | UdpState::Connected(_) => {}
125            _ => return Err(ErrorCode::InvalidState),
126        }
127
128        self.connect_common(addr)
129    }
130
131    /// Connect using p3 semantics. (with implicit bind)
132    #[cfg(feature = "p3")]
133    pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
134        match self.udp_state {
135            UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
136            _ => return Err(ErrorCode::InvalidState),
137        }
138
139        self.connect_common(addr)
140    }
141
142    fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
143        if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
144            return Err(ErrorCode::InvalidArgument);
145        }
146
147        match udp_connect(&self.socket, addr) {
148            Ok(()) => {
149                self.udp_state = UdpState::Connected(addr);
150                Ok(())
151            }
152            Err(e) => {
153                // Revert to a consistent state:
154                _ = udp_disconnect(&self.socket);
155                self.udp_state = UdpState::Bound;
156
157                Err(match e {
158                    Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.
159                    Errno::INPROGRESS => {
160                        debug!("UDP connect returned EINPROGRESS, which should never happen");
161                        ErrorCode::Unknown
162                    }
163                    err => err.into(),
164                })
165            }
166        }
167    }
168
169    /// Send data using p3 semantics. (with implicit bind)
170    #[cfg(feature = "p3")]
171    pub(crate) fn send_p3(
172        &mut self,
173        buf: Vec<u8>,
174        addr: Option<SocketAddr>,
175    ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
176        enum Mode {
177            Send(Arc<tokio::net::UdpSocket>),
178            SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
179        }
180        let mut socket = match (&self.udp_state, addr) {
181            (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
182            (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
183            (UdpState::Default | UdpState::Bound, Some(addr)) => {
184                Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
185            }
186            (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
187            (UdpState::Connected(caddr), Some(addr)) => {
188                if addr == *caddr {
189                    Ok(Mode::Send(Arc::clone(&self.socket)))
190                } else {
191                    Err(ErrorCode::InvalidArgument)
192                }
193            }
194        };
195
196        // Send may be called without a prior bind or connect. In that case, the
197        // first send will automatically assign a free local port. This is
198        // normally performed by the OS itself. However, if the `send` syscall
199        // failed, we can't reliably know which state the socket is in at the
200        // kernel level and our own `udp_state` bookkeeping may have become
201        // out-of-sync.
202        // To avoid that, we perform the implicit bind ourselves here. This way,
203        // we always leave the socket in a consistent state: Bound.
204        if socket.is_ok()
205            && let UdpState::Default = self.udp_state
206        {
207            let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
208            match udp_bind(&self.socket, implicit_addr) {
209                Ok(()) => {
210                    self.udp_state = UdpState::Bound;
211                }
212                Err(e) => {
213                    socket = Err(e);
214                }
215            }
216        }
217
218        async move {
219            match socket? {
220                Mode::Send(socket) => send(&socket, &buf).await,
221                Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
222            }
223        }
224    }
225
226    /// Receive data using p3 semantics.
227    #[cfg(feature = "p3")]
228    pub(crate) fn receive_p3(
229        &self,
230    ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
231        enum Mode {
232            Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
233            RecvFrom(Arc<tokio::net::UdpSocket>),
234        }
235        let socket = match self.udp_state {
236            UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
237            UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
238            UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
239        };
240        async move {
241            let socket = socket?;
242            let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
243            let (n, addr) = match socket {
244                Mode::Recv(socket, addr) => {
245                    let n = socket.recv(&mut buf).await?;
246                    (n, addr)
247                }
248                Mode::RecvFrom(socket) => {
249                    let (n, addr) = socket.recv_from(&mut buf).await?;
250                    (n, addr)
251                }
252            };
253            buf.truncate(n);
254            Ok((buf, addr))
255        }
256    }
257
258    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
259        if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
260            return Err(ErrorCode::InvalidState);
261        }
262        let addr = self
263            .socket
264            .as_socketlike_view::<std::net::UdpSocket>()
265            .local_addr()?;
266        Ok(addr)
267    }
268
269    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
270        if !matches!(self.udp_state, UdpState::Connected(..)) {
271            return Err(ErrorCode::InvalidState);
272        }
273        let addr = self
274            .socket
275            .as_socketlike_view::<std::net::UdpSocket>()
276            .peer_addr()?;
277        Ok(addr)
278    }
279
280    pub(crate) fn address_family(&self) -> SocketAddressFamily {
281        self.family
282    }
283
284    pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
285        let n = get_unicast_hop_limit(&self.socket, self.family)?;
286        Ok(n)
287    }
288
289    pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
290        set_unicast_hop_limit(&self.socket, self.family, value)?;
291        Ok(())
292    }
293
294    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
295        let n = receive_buffer_size(&self.socket)?;
296        Ok(n)
297    }
298
299    pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
300        set_receive_buffer_size(&self.socket, value)?;
301        Ok(())
302    }
303
304    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
305        let n = send_buffer_size(&self.socket)?;
306        Ok(n)
307    }
308
309    pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
310        set_send_buffer_size(&self.socket, value)?;
311        Ok(())
312    }
313
314    pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
315        &self.socket
316    }
317
318    pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
319        self.socket_addr_check.as_ref()
320    }
321
322    pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
323        self.socket_addr_check = check;
324    }
325}
326
/// Send `buf` as a single datagram to the socket's connected peer.
///
/// Anything short of a complete write is reported as `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let sent = socket.send(buf).await?;
    // A datagram must go out whole. The Rust standard library docs note that
    // the OS may refuse buffers larger than 65507 bytes, and that partial
    // writes only become possible for buffers above `i32::MAX` bytes; Windows,
    // for example, writes at most `i32::MAX` bytes per call.
    if sent != buf.len() {
        return Err(ErrorCode::Unknown);
    }
    Ok(())
}
341
/// Send `buf` as a single datagram to `addr`.
///
/// Anything short of a complete write is reported as `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let sent = socket.send_to(buf, addr).await?;
    // Same whole-datagram requirement as in `send`: the OS only performs
    // partial writes for buffers above `i32::MAX` bytes.
    if sent != buf.len() {
        return Err(ErrorCode::Unknown);
    }
    Ok(())
}