1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3 ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4 receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5 set_unicast_hop_limit, udp_bind, udp_connect, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use io_lifetimes::AsSocketlike as _;
9use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
10use rustix::io::Errno;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tracing::debug;
14
/// Where a [`UdpSocket`] is in its bind/connect lifecycle.
///
/// Transitions performed by the `UdpSocket` methods:
/// - `Default` -> `BindStarted` via `bind`, then `BindStarted` -> `Bound` via `finish_bind`.
/// - `Bound`/`Connected` -> `Connected` via `connect_p2` / `connect_p3`.
/// - `Connected` -> `Bound` via `disconnect`.
/// - With the `p3` feature, `Default` -> `Connected` via `connect_p3`.
/// - A connect that fails in the OS leaves the socket recorded as `Bound`.
enum UdpState {
    /// Freshly created: neither bound nor connected.
    Default,

    /// The OS-level bind succeeded; waiting for `finish_bind`.
    BindStarted,

    /// Bound to a local address, not connected to any peer.
    Bound,

    /// Connected to the contained peer address.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
38
39pub struct UdpSocket {
44 socket: Arc<tokio::net::UdpSocket>,
45
46 udp_state: UdpState,
48
49 family: SocketAddressFamily,
51
52 socket_addr_check: Option<SocketAddrCheck>,
55}
56
57impl UdpSocket {
58 pub(crate) fn new(cx: &WasiSocketsCtx, family: SocketAddressFamily) -> Result<Self, ErrorCode> {
60 cx.allowed_network_uses.check_allowed_udp()?;
61
62 let fd = udp_socket(family)?;
63
64 if family == SocketAddressFamily::Ipv6 {
65 rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
66 }
67
68 let socket = with_ambient_tokio_runtime(|| {
69 tokio::net::UdpSocket::try_from(unsafe {
70 std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
71 })
72 })?;
73
74 Ok(Self {
75 socket: Arc::new(socket),
76 udp_state: UdpState::Default,
77 family,
78 socket_addr_check: None,
79 })
80 }
81
82 pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
83 if !matches!(self.udp_state, UdpState::Default) {
84 return Err(ErrorCode::InvalidState);
85 }
86 if !is_valid_address_family(addr.ip(), self.family) {
87 return Err(ErrorCode::InvalidArgument);
88 }
89 udp_bind(&self.socket, addr)?;
90 self.udp_state = UdpState::BindStarted;
91 Ok(())
92 }
93
94 pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
95 match self.udp_state {
96 UdpState::BindStarted => {
97 self.udp_state = UdpState::Bound;
98 Ok(())
99 }
100 _ => Err(ErrorCode::NotInProgress),
101 }
102 }
103
104 pub(crate) fn is_connected(&self) -> bool {
105 matches!(self.udp_state, UdpState::Connected(..))
106 }
107
108 pub(crate) fn is_bound(&self) -> bool {
109 matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
110 }
111
112 pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
113 if !self.is_connected() {
114 return Err(ErrorCode::InvalidState);
115 }
116 udp_disconnect(&self.socket)?;
117 self.udp_state = UdpState::Bound;
118 Ok(())
119 }
120
121 pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
123 match self.udp_state {
124 UdpState::Bound | UdpState::Connected(_) => {}
125 _ => return Err(ErrorCode::InvalidState),
126 }
127
128 self.connect_common(addr)
129 }
130
131 #[cfg(feature = "p3")]
133 pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
134 match self.udp_state {
135 UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
136 _ => return Err(ErrorCode::InvalidState),
137 }
138
139 self.connect_common(addr)
140 }
141
142 fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
143 if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
144 return Err(ErrorCode::InvalidArgument);
145 }
146
147 match udp_connect(&self.socket, addr) {
148 Ok(()) => {
149 self.udp_state = UdpState::Connected(addr);
150 Ok(())
151 }
152 Err(e) => {
153 _ = udp_disconnect(&self.socket);
155 self.udp_state = UdpState::Bound;
156
157 Err(match e {
158 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
160 debug!("UDP connect returned EINPROGRESS, which should never happen");
161 ErrorCode::Unknown
162 }
163 err => err.into(),
164 })
165 }
166 }
167 }
168
169 #[cfg(feature = "p3")]
171 pub(crate) fn send_p3(
172 &mut self,
173 buf: Vec<u8>,
174 addr: Option<SocketAddr>,
175 ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
176 enum Mode {
177 Send(Arc<tokio::net::UdpSocket>),
178 SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
179 }
180 let socket = match (&self.udp_state, addr) {
181 (UdpState::BindStarted, _) | (UdpState::Default, Some(_)) => {
182 Err(ErrorCode::InvalidState)
183 }
184 (UdpState::Bound | UdpState::Default, None) => Err(ErrorCode::InvalidArgument),
185 (UdpState::Bound, Some(addr)) => {
186 if is_valid_remote_address(addr) && is_valid_address_family(addr.ip(), self.family)
187 {
188 Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
189 } else {
190 Err(ErrorCode::InvalidArgument)
191 }
192 }
193 (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
194 (UdpState::Connected(caddr), Some(addr)) => {
195 if addr == *caddr {
196 Ok(Mode::Send(Arc::clone(&self.socket)))
197 } else {
198 Err(ErrorCode::InvalidArgument)
199 }
200 }
201 };
202
203 async move {
204 match socket? {
205 Mode::Send(socket) => send(&socket, &buf).await,
206 Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
207 }
208 }
209 }
210
211 #[cfg(feature = "p3")]
213 pub(crate) fn receive_p3(
214 &self,
215 ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
216 enum Mode {
217 Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
218 RecvFrom(Arc<tokio::net::UdpSocket>),
219 }
220 let socket = match self.udp_state {
221 UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
222 UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
223 UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
224 };
225 async move {
226 let socket = socket?;
227 let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
228 let (n, addr) = match socket {
229 Mode::Recv(socket, addr) => {
230 let n = socket.recv(&mut buf).await?;
231 (n, addr)
232 }
233 Mode::RecvFrom(socket) => {
234 let (n, addr) = socket.recv_from(&mut buf).await?;
235 (n, addr)
236 }
237 };
238 buf.truncate(n);
239 Ok((buf, addr))
240 }
241 }
242
243 pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
244 if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
245 return Err(ErrorCode::InvalidState);
246 }
247 let addr = self
248 .socket
249 .as_socketlike_view::<std::net::UdpSocket>()
250 .local_addr()?;
251 Ok(addr)
252 }
253
254 pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
255 if !matches!(self.udp_state, UdpState::Connected(..)) {
256 return Err(ErrorCode::InvalidState);
257 }
258 let addr = self
259 .socket
260 .as_socketlike_view::<std::net::UdpSocket>()
261 .peer_addr()?;
262 Ok(addr)
263 }
264
265 pub(crate) fn address_family(&self) -> SocketAddressFamily {
266 self.family
267 }
268
269 pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
270 let n = get_unicast_hop_limit(&self.socket, self.family)?;
271 Ok(n)
272 }
273
274 pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
275 set_unicast_hop_limit(&self.socket, self.family, value)?;
276 Ok(())
277 }
278
279 pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
280 let n = receive_buffer_size(&self.socket)?;
281 Ok(n)
282 }
283
284 pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
285 set_receive_buffer_size(&self.socket, value)?;
286 Ok(())
287 }
288
289 pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
290 let n = send_buffer_size(&self.socket)?;
291 Ok(n)
292 }
293
294 pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
295 set_send_buffer_size(&self.socket, value)?;
296 Ok(())
297 }
298
299 pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
300 &self.socket
301 }
302
303 pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
304 self.socket_addr_check.as_ref()
305 }
306
307 pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
308 self.socket_addr_check = check;
309 }
310}
311
/// Sends all of `buf` as one datagram on a connected socket.
///
/// # Errors
///
/// Propagates the I/O error from the send. Reports `ErrorCode::Unknown` if
/// the OS accepted fewer bytes than `buf` holds.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    let written = socket.send(buf).await?;
    if written == buf.len() {
        Ok(())
    } else {
        // A datagram is not expected to be split. A short count means the
        // payload did not go out intact, so surface it as an error instead of
        // silently dropping the tail.
        Err(ErrorCode::Unknown)
    }
}
326
/// Sends all of `buf` as one datagram addressed to `addr`.
///
/// # Errors
///
/// Propagates the I/O error from the send. Reports `ErrorCode::Unknown` if
/// the OS accepted fewer bytes than `buf` holds.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let written = socket.send_to(buf, addr).await?;
    if written == buf.len() {
        Ok(())
    } else {
        // A datagram is not expected to be split. Treat a short count as a
        // failed delivery rather than reporting success.
        Err(ErrorCode::Unknown)
    }
}