1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3 ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4 receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5 set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use cap_net_ext::AddressFamily;
9use io_lifetimes::AsSocketlike as _;
10use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
11use rustix::io::Errno;
12use rustix::net::connect;
13use std::io;
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tracing::debug;
17
/// Lifecycle state of a [`UdpSocket`].
///
/// The state is only changed by the methods of [`UdpSocket`]: `new` starts in
/// `Default`, `bind` moves to `BindStarted`, `finish_bind` moves to `Bound`,
/// `connect` moves to `Connected`, and `disconnect` moves back to `Bound`.
enum UdpState {
    /// A freshly created socket on which no `bind` has succeeded yet.
    Default,

    /// `bind` succeeded but `finish_bind` has not been called yet.
    BindStarted,

    /// Binding was completed by `finish_bind`, or a previously connected
    /// socket was disconnected again.
    Bound,

    /// The socket is connected to the contained peer address.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
40
41pub struct UdpSocket {
46 socket: Arc<tokio::net::UdpSocket>,
47
48 udp_state: UdpState,
50
51 family: SocketAddressFamily,
53
54 socket_addr_check: Option<SocketAddrCheck>,
57}
58
59impl UdpSocket {
60 pub(crate) fn new(cx: &WasiSocketsCtx, family: AddressFamily) -> io::Result<Self> {
62 cx.allowed_network_uses.check_allowed_udp()?;
63
64 let fd = udp_socket(family)?;
70
71 let socket_address_family = match family {
72 AddressFamily::Ipv4 => SocketAddressFamily::Ipv4,
73 AddressFamily::Ipv6 => {
74 rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
75 SocketAddressFamily::Ipv6
76 }
77 };
78
79 let socket = with_ambient_tokio_runtime(|| {
80 tokio::net::UdpSocket::try_from(unsafe {
81 std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
82 })
83 })?;
84
85 Ok(Self {
86 socket: Arc::new(socket),
87 udp_state: UdpState::Default,
88 family: socket_address_family,
89 socket_addr_check: None,
90 })
91 }
92
93 pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
94 if !matches!(self.udp_state, UdpState::Default) {
95 return Err(ErrorCode::InvalidState);
96 }
97 if !is_valid_address_family(addr.ip(), self.family) {
98 return Err(ErrorCode::InvalidArgument);
99 }
100 udp_bind(&self.socket, addr)?;
101 self.udp_state = UdpState::BindStarted;
102 Ok(())
103 }
104
105 pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
106 match self.udp_state {
107 UdpState::BindStarted => {
108 self.udp_state = UdpState::Bound;
109 Ok(())
110 }
111 _ => Err(ErrorCode::NotInProgress),
112 }
113 }
114
115 pub(crate) fn is_connected(&self) -> bool {
116 matches!(self.udp_state, UdpState::Connected(..))
117 }
118
119 pub(crate) fn is_bound(&self) -> bool {
120 matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
121 }
122
123 pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
124 if !self.is_connected() {
125 return Err(ErrorCode::InvalidState);
126 }
127 udp_disconnect(&self.socket)?;
128 self.udp_state = UdpState::Bound;
129 Ok(())
130 }
131
132 pub(crate) fn connect(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
133 if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
134 return Err(ErrorCode::InvalidArgument);
135 }
136
137 match self.udp_state {
138 UdpState::Bound | UdpState::Connected(_) => {}
139 _ => return Err(ErrorCode::InvalidState),
140 }
141
142 if let UdpState::Connected(..) = self.udp_state {
150 udp_disconnect(&self.socket)?;
151 self.udp_state = UdpState::Bound;
152 }
153 connect(&self.socket, &addr).map_err(|error| match error {
155 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
157 debug!("UDP connect returned EINPROGRESS, which should never happen");
158 ErrorCode::Unknown
159 }
160 err => err.into(),
161 })?;
162 self.udp_state = UdpState::Connected(addr);
163 Ok(())
164 }
165
166 #[cfg(feature = "p3")]
167 pub(crate) fn send(&self, buf: Vec<u8>) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
168 let socket = if let UdpState::Connected(..) = self.udp_state {
169 Ok(Arc::clone(&self.socket))
170 } else {
171 Err(ErrorCode::InvalidArgument)
172 };
173 async move {
174 let socket = socket?;
175 send(&socket, &buf).await
176 }
177 }
178
179 #[cfg(feature = "p3")]
180 pub(crate) fn send_to(
181 &self,
182 buf: Vec<u8>,
183 addr: SocketAddr,
184 ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
185 enum Mode {
186 Send(Arc<tokio::net::UdpSocket>),
187 SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
188 }
189 let socket = match &self.udp_state {
190 UdpState::BindStarted => Err(ErrorCode::InvalidState),
191 UdpState::Default | UdpState::Bound => Ok(Mode::SendTo(Arc::clone(&self.socket), addr)),
192 UdpState::Connected(caddr) if addr == *caddr => {
193 Ok(Mode::Send(Arc::clone(&self.socket)))
194 }
195 UdpState::Connected(..) => Err(ErrorCode::InvalidArgument),
196 };
197 async move {
198 match socket? {
199 Mode::Send(socket) => send(&socket, &buf).await,
200 Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
201 }
202 }
203 }
204
205 #[cfg(feature = "p3")]
206 pub(crate) fn receive(
207 &self,
208 ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
209 enum Mode {
210 Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
211 RecvFrom(Arc<tokio::net::UdpSocket>),
212 }
213 let socket = match self.udp_state {
214 UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
215 UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
216 UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr.into())),
217 };
218 async move {
219 let socket = socket?;
220 let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
221 let (n, addr) = match socket {
222 Mode::Recv(socket, addr) => {
223 let n = socket.recv(&mut buf).await?;
224 (n, addr)
225 }
226 Mode::RecvFrom(socket) => {
227 let (n, addr) = socket.recv_from(&mut buf).await?;
228 (n, addr)
229 }
230 };
231 buf.truncate(n);
232 Ok((buf, addr))
233 }
234 }
235
236 pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
237 if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
238 return Err(ErrorCode::InvalidState);
239 }
240 let addr = self
241 .socket
242 .as_socketlike_view::<std::net::UdpSocket>()
243 .local_addr()?;
244 Ok(addr)
245 }
246
247 pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
248 if !matches!(self.udp_state, UdpState::Connected(..)) {
249 return Err(ErrorCode::InvalidState);
250 }
251 let addr = self
252 .socket
253 .as_socketlike_view::<std::net::UdpSocket>()
254 .peer_addr()?;
255 Ok(addr)
256 }
257
258 pub(crate) fn address_family(&self) -> SocketAddressFamily {
259 self.family
260 }
261
262 pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
263 let n = get_unicast_hop_limit(&self.socket, self.family)?;
264 Ok(n)
265 }
266
267 pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
268 set_unicast_hop_limit(&self.socket, self.family, value)?;
269 Ok(())
270 }
271
272 pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
273 let n = receive_buffer_size(&self.socket)?;
274 Ok(n)
275 }
276
277 pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
278 set_receive_buffer_size(&self.socket, value)?;
279 Ok(())
280 }
281
282 pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
283 let n = send_buffer_size(&self.socket)?;
284 Ok(n)
285 }
286
287 pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
288 set_send_buffer_size(&self.socket, value)?;
289 Ok(())
290 }
291
292 pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
293 &self.socket
294 }
295
296 pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
297 self.socket_addr_check.as_ref()
298 }
299
300 pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
301 self.socket_addr_check = check;
302 }
303}
304
/// Sends `buf` as one datagram on a connected `socket`.
///
/// # Errors
///
/// Propagates the converted I/O error of the send. If the send succeeds but
/// reports fewer bytes than `buf.len()`, `ErrorCode::Unknown` is returned,
/// because a datagram is expected to be written in full.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let written = socket.send(buf).await?;
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}
319
/// Sends `buf` as one datagram on `socket` to the destination `addr`.
///
/// # Errors
///
/// Propagates the converted I/O error of the send. If the send succeeds but
/// reports fewer bytes than `buf.len()`, `ErrorCode::Unknown` is returned,
/// because a datagram is expected to be written in full.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let written = socket.send_to(buf, addr).await?;
    if written == buf.len() {
        Ok(())
    } else {
        Err(ErrorCode::Unknown)
    }
}