1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3 ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4 receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5 set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use cap_net_ext::AddressFamily;
9use io_lifetimes::AsSocketlike as _;
10use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
11use rustix::io::Errno;
12use rustix::net::connect;
13use std::net::SocketAddr;
14use std::sync::Arc;
15use tracing::debug;
16
/// Lifecycle state of a [`UdpSocket`], advanced by its bind/connect methods.
enum UdpState {
    /// Initial state assigned by `UdpSocket::new`; no bind has been recorded yet.
    Default,

    /// `bind` succeeded at the OS level; `finish_bind` must be called to move to `Bound`.
    BindStarted,

    /// Bound to a local address but not connected. Entered via `finish_bind`,
    /// `disconnect`, the implicit bind in `send_p3`, or while `connect_common`
    /// replaces an existing connection.
    Bound,

    /// Connected; the payload is the remote address that was passed to connect.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
40
/// A UDP socket together with the WASI-level state tracked for it.
pub struct UdpSocket {
    /// The underlying OS socket, registered with tokio. Kept in an `Arc` so the
    /// futures returned by the p3 send/receive methods can hold their own handle.
    socket: Arc<tokio::net::UdpSocket>,

    /// Current position in the bind/connect state machine.
    udp_state: UdpState,

    /// Address family chosen at creation; addresses given to bind/connect must match it.
    family: SocketAddressFamily,

    /// Optional address check installed via `set_socket_addr_check` (`None` after `new`).
    /// NOTE(review): presumably consulted by callers before using an address; no method
    /// visible here reads it — confirm.
    socket_addr_check: Option<SocketAddrCheck>,
}
58
59impl UdpSocket {
60 pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {
62 cx.allowed_network_uses.check_allowed_udp()?;
63
64 let fd = udp_socket(family)?;
70
71 let socket_address_family = match family {
72 AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
73 AddressFamily::Ipv6 => {
74 rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
75 SocketAddressFamily::Ipv6
76 }
77 };
78
79 let socket = with_ambient_tokio_runtime(|| {
80 tokio::net::UdpSocket::try_from(unsafe {
81 std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
82 })
83 })?;
84
85 Ok(Self {
86 socket: Arc::new(socket),
87 udp_state: UdpState::Default,
88 family: socket_address_family,
89 socket_addr_check: None,
90 })
91 }
92
93 pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
94 if !matches!(self.udp_state, UdpState::Default) {
95 return Err(ErrorCode::InvalidState);
96 }
97 if !is_valid_address_family(addr.ip(), self.family) {
98 return Err(ErrorCode::InvalidArgument);
99 }
100 udp_bind(&self.socket, addr)?;
101 self.udp_state = UdpState::BindStarted;
102 Ok(())
103 }
104
105 pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
106 match self.udp_state {
107 UdpState::BindStarted => {
108 self.udp_state = UdpState::Bound;
109 Ok(())
110 }
111 _ => Err(ErrorCode::NotInProgress),
112 }
113 }
114
115 pub(crate) fn is_connected(&self) -> bool {
116 matches!(self.udp_state, UdpState::Connected(..))
117 }
118
119 pub(crate) fn is_bound(&self) -> bool {
120 matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
121 }
122
123 pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
124 if !self.is_connected() {
125 return Err(ErrorCode::InvalidState);
126 }
127 udp_disconnect(&self.socket)?;
128 self.udp_state = UdpState::Bound;
129 Ok(())
130 }
131
132 pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
134 match self.udp_state {
135 UdpState::Bound | UdpState::Connected(_) => {}
136 _ => return Err(ErrorCode::InvalidState),
137 }
138
139 self.connect_common(addr)
140 }
141
142 #[cfg(feature = "p3")]
144 pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
145 match self.udp_state {
146 UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
147 _ => return Err(ErrorCode::InvalidState),
148 }
149
150 self.connect_common(addr)
151 }
152
153 fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
154 if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
155 return Err(ErrorCode::InvalidArgument);
156 }
157
158 if let UdpState::Connected(..) = self.udp_state {
166 udp_disconnect(&self.socket)?;
167 self.udp_state = UdpState::Bound;
168 }
169 connect(&self.socket, &addr).map_err(|error| match error {
171 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
173 debug!("UDP connect returned EINPROGRESS, which should never happen");
174 ErrorCode::Unknown
175 }
176 err => err.into(),
177 })?;
178 self.udp_state = UdpState::Connected(addr);
179 Ok(())
180 }
181
182 #[cfg(feature = "p3")]
184 pub(crate) fn send_p3(
185 &mut self,
186 buf: Vec<u8>,
187 addr: Option<SocketAddr>,
188 ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
189 enum Mode {
190 Send(Arc<tokio::net::UdpSocket>),
191 SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
192 }
193 let mut socket = match (&self.udp_state, addr) {
194 (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
195 (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
196 (UdpState::Default | UdpState::Bound, Some(addr)) => {
197 Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
198 }
199 (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
200 (UdpState::Connected(caddr), Some(addr)) => {
201 if addr == *caddr {
202 Ok(Mode::Send(Arc::clone(&self.socket)))
203 } else {
204 Err(ErrorCode::InvalidArgument)
205 }
206 }
207 };
208
209 if socket.is_ok()
218 && let UdpState::Default = self.udp_state
219 {
220 let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
221 match udp_bind(&self.socket, implicit_addr) {
222 Ok(()) => {
223 self.udp_state = UdpState::Bound;
224 }
225 Err(e) => {
226 socket = Err(e);
227 }
228 }
229 }
230
231 async move {
232 match socket? {
233 Mode::Send(socket) => send(&socket, &buf).await,
234 Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
235 }
236 }
237 }
238
239 #[cfg(feature = "p3")]
241 pub(crate) fn receive_p3(
242 &self,
243 ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
244 enum Mode {
245 Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
246 RecvFrom(Arc<tokio::net::UdpSocket>),
247 }
248 let socket = match self.udp_state {
249 UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
250 UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
251 UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
252 };
253 async move {
254 let socket = socket?;
255 let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
256 let (n, addr) = match socket {
257 Mode::Recv(socket, addr) => {
258 let n = socket.recv(&mut buf).await?;
259 (n, addr)
260 }
261 Mode::RecvFrom(socket) => {
262 let (n, addr) = socket.recv_from(&mut buf).await?;
263 (n, addr)
264 }
265 };
266 buf.truncate(n);
267 Ok((buf, addr))
268 }
269 }
270
271 pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
272 if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
273 return Err(ErrorCode::InvalidState);
274 }
275 let addr = self
276 .socket
277 .as_socketlike_view::<std::net::UdpSocket>()
278 .local_addr()?;
279 Ok(addr)
280 }
281
282 pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
283 if !matches!(self.udp_state, UdpState::Connected(..)) {
284 return Err(ErrorCode::InvalidState);
285 }
286 let addr = self
287 .socket
288 .as_socketlike_view::<std::net::UdpSocket>()
289 .peer_addr()?;
290 Ok(addr)
291 }
292
293 pub(crate) fn address_family(&self) -> SocketAddressFamily {
294 self.family
295 }
296
297 pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
298 let n = get_unicast_hop_limit(&self.socket, self.family)?;
299 Ok(n)
300 }
301
302 pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
303 set_unicast_hop_limit(&self.socket, self.family, value)?;
304 Ok(())
305 }
306
307 pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
308 let n = receive_buffer_size(&self.socket)?;
309 Ok(n)
310 }
311
312 pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
313 set_receive_buffer_size(&self.socket, value)?;
314 Ok(())
315 }
316
317 pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
318 let n = send_buffer_size(&self.socket)?;
319 Ok(n)
320 }
321
322 pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
323 set_send_buffer_size(&self.socket, value)?;
324 Ok(())
325 }
326
327 pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
328 &self.socket
329 }
330
331 pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
332 self.socket_addr_check.as_ref()
333 }
334
335 pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
336 self.socket_addr_check = check;
337 }
338}
339
/// Sends `buf` as one datagram to the socket's connected peer.
///
/// A short write means the datagram would have been truncated. That is not
/// expected for UDP, so it is surfaced as `ErrorCode::Unknown` rather than
/// reported as success.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let sent = socket.send(buf).await?;
    if sent == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
354
/// Sends `buf` as one datagram to `addr`.
///
/// Like `send`, a short write is not expected for UDP and is reported as
/// `ErrorCode::Unknown` instead of success.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let written = socket.send_to(buf, addr).await?;
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}