1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3 ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4 receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5 set_unicast_hop_limit, udp_bind, udp_connect, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use cap_net_ext::AddressFamily;
9use io_lifetimes::AsSocketlike as _;
10use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
11use rustix::io::Errno;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use tracing::debug;
15
/// Lifecycle state of a [`UdpSocket`].
///
/// The transitions visible in this file are:
/// `Default` -> `BindStarted` (`bind`) -> `Bound` (`finish_bind`),
/// `Bound`/`Connected` -> `Connected` (`connect_*`), and
/// `Connected` -> `Bound` (`disconnect`).
enum UdpState {
    /// Freshly created socket; nothing has been bound yet.
    Default,

    /// `bind` succeeded but `finish_bind` has not been called yet.
    BindStarted,

    /// The socket is bound and is not connected to a remote peer.
    Bound,

    /// The socket is connected to the contained remote address.
    ///
    /// The payload is only read by the p3 send/receive paths, hence the
    /// conditional `expect(dead_code)` below.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
39
/// A host UDP socket together with the bookkeeping tracked for it in this
/// module.
pub struct UdpSocket {
    /// The underlying tokio socket. It is reference-counted so that the
    /// futures returned by the p3 send/receive methods can hold their own
    /// handle without borrowing `self`.
    socket: Arc<tokio::net::UdpSocket>,

    /// Current position in the bind/connect lifecycle.
    udp_state: UdpState,

    /// Address family selected when the socket was created.
    family: SocketAddressFamily,

    /// Optional address check. It starts out as `None` and can be replaced
    /// through `set_socket_addr_check`; how it is applied is decided by code
    /// outside this block.
    socket_addr_check: Option<SocketAddrCheck>,
}
57
58impl UdpSocket {
59 pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> Result<Self, ErrorCode> {
61 cx.allowed_network_uses.check_allowed_udp()?;
62
63 let fd = udp_socket(family)?;
69
70 let socket_address_family = match family {
71 AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
72 AddressFamily::Ipv6 => {
73 rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
74 SocketAddressFamily::Ipv6
75 }
76 };
77
78 let socket = with_ambient_tokio_runtime(|| {
79 tokio::net::UdpSocket::try_from(unsafe {
80 std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
81 })
82 })?;
83
84 Ok(Self {
85 socket: Arc::new(socket),
86 udp_state: UdpState::Default,
87 family: socket_address_family,
88 socket_addr_check: None,
89 })
90 }
91
92 pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
93 if !matches!(self.udp_state, UdpState::Default) {
94 return Err(ErrorCode::InvalidState);
95 }
96 if !is_valid_address_family(addr.ip(), self.family) {
97 return Err(ErrorCode::InvalidArgument);
98 }
99 udp_bind(&self.socket, addr)?;
100 self.udp_state = UdpState::BindStarted;
101 Ok(())
102 }
103
104 pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
105 match self.udp_state {
106 UdpState::BindStarted => {
107 self.udp_state = UdpState::Bound;
108 Ok(())
109 }
110 _ => Err(ErrorCode::NotInProgress),
111 }
112 }
113
114 pub(crate) fn is_connected(&self) -> bool {
115 matches!(self.udp_state, UdpState::Connected(..))
116 }
117
118 pub(crate) fn is_bound(&self) -> bool {
119 matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
120 }
121
122 pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
123 if !self.is_connected() {
124 return Err(ErrorCode::InvalidState);
125 }
126 udp_disconnect(&self.socket)?;
127 self.udp_state = UdpState::Bound;
128 Ok(())
129 }
130
131 pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
133 match self.udp_state {
134 UdpState::Bound | UdpState::Connected(_) => {}
135 _ => return Err(ErrorCode::InvalidState),
136 }
137
138 self.connect_common(addr)
139 }
140
141 #[cfg(feature = "p3")]
143 pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
144 match self.udp_state {
145 UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
146 _ => return Err(ErrorCode::InvalidState),
147 }
148
149 self.connect_common(addr)
150 }
151
152 fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
153 if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
154 return Err(ErrorCode::InvalidArgument);
155 }
156
157 match udp_connect(&self.socket, addr) {
158 Ok(()) => {
159 self.udp_state = UdpState::Connected(addr);
160 Ok(())
161 }
162 Err(e) => {
163 _ = udp_disconnect(&self.socket);
165 self.udp_state = UdpState::Bound;
166
167 Err(match e {
168 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
170 debug!("UDP connect returned EINPROGRESS, which should never happen");
171 ErrorCode::Unknown
172 }
173 err => err.into(),
174 })
175 }
176 }
177 }
178
179 #[cfg(feature = "p3")]
181 pub(crate) fn send_p3(
182 &mut self,
183 buf: Vec<u8>,
184 addr: Option<SocketAddr>,
185 ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
186 enum Mode {
187 Send(Arc<tokio::net::UdpSocket>),
188 SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
189 }
190 let mut socket = match (&self.udp_state, addr) {
191 (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
192 (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
193 (UdpState::Default | UdpState::Bound, Some(addr)) => {
194 Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
195 }
196 (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
197 (UdpState::Connected(caddr), Some(addr)) => {
198 if addr == *caddr {
199 Ok(Mode::Send(Arc::clone(&self.socket)))
200 } else {
201 Err(ErrorCode::InvalidArgument)
202 }
203 }
204 };
205
206 if socket.is_ok()
215 && let UdpState::Default = self.udp_state
216 {
217 let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
218 match udp_bind(&self.socket, implicit_addr) {
219 Ok(()) => {
220 self.udp_state = UdpState::Bound;
221 }
222 Err(e) => {
223 socket = Err(e);
224 }
225 }
226 }
227
228 async move {
229 match socket? {
230 Mode::Send(socket) => send(&socket, &buf).await,
231 Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
232 }
233 }
234 }
235
236 #[cfg(feature = "p3")]
238 pub(crate) fn receive_p3(
239 &self,
240 ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
241 enum Mode {
242 Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
243 RecvFrom(Arc<tokio::net::UdpSocket>),
244 }
245 let socket = match self.udp_state {
246 UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
247 UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
248 UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
249 };
250 async move {
251 let socket = socket?;
252 let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
253 let (n, addr) = match socket {
254 Mode::Recv(socket, addr) => {
255 let n = socket.recv(&mut buf).await?;
256 (n, addr)
257 }
258 Mode::RecvFrom(socket) => {
259 let (n, addr) = socket.recv_from(&mut buf).await?;
260 (n, addr)
261 }
262 };
263 buf.truncate(n);
264 Ok((buf, addr))
265 }
266 }
267
268 pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
269 if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
270 return Err(ErrorCode::InvalidState);
271 }
272 let addr = self
273 .socket
274 .as_socketlike_view::<std::net::UdpSocket>()
275 .local_addr()?;
276 Ok(addr)
277 }
278
279 pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
280 if !matches!(self.udp_state, UdpState::Connected(..)) {
281 return Err(ErrorCode::InvalidState);
282 }
283 let addr = self
284 .socket
285 .as_socketlike_view::<std::net::UdpSocket>()
286 .peer_addr()?;
287 Ok(addr)
288 }
289
290 pub(crate) fn address_family(&self) -> SocketAddressFamily {
291 self.family
292 }
293
294 pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
295 let n = get_unicast_hop_limit(&self.socket, self.family)?;
296 Ok(n)
297 }
298
299 pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
300 set_unicast_hop_limit(&self.socket, self.family, value)?;
301 Ok(())
302 }
303
304 pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
305 let n = receive_buffer_size(&self.socket)?;
306 Ok(n)
307 }
308
309 pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
310 set_receive_buffer_size(&self.socket, value)?;
311 Ok(())
312 }
313
314 pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
315 let n = send_buffer_size(&self.socket)?;
316 Ok(n)
317 }
318
319 pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
320 set_send_buffer_size(&self.socket, value)?;
321 Ok(())
322 }
323
324 pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
325 &self.socket
326 }
327
328 pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
329 self.socket_addr_check.as_ref()
330 }
331
332 pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
333 self.socket_addr_check = check;
334 }
335}
336
/// Sends `buf` as a single datagram on a connected socket.
///
/// # Errors
///
/// Returns the converted I/O error if the send fails, or
/// `ErrorCode::Unknown` if fewer bytes than `buf.len()` were reported as
/// sent.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let sent = socket.send(buf).await?;
    // A datagram is expected to go out whole; anything else is reported as
    // an error rather than retried.
    if sent == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
351
/// Sends `buf` as a single datagram to `addr`.
///
/// # Errors
///
/// Returns the converted I/O error if the send fails, or
/// `ErrorCode::Unknown` if fewer bytes than `buf.len()` were reported as
/// sent.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let sent = socket.send_to(buf, addr).await?;
    // A datagram is expected to go out whole; anything else is reported as
    // an error rather than retried.
    if sent == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}