1use crate::runtime::with_ambient_tokio_runtime;
2use crate::sockets::util::{
3 ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
4 receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5 set_unicast_hop_limit, udp_bind, udp_connect, udp_disconnect, udp_socket,
6};
7use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
8use io_lifetimes::AsSocketlike as _;
9use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
10use rustix::io::Errno;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use tracing::debug;
14
/// Lifecycle state of a [`UdpSocket`].
///
/// Transitions performed by the methods on `UdpSocket`:
/// - `Default` -> `BindStarted` via `bind`, then `BindStarted` -> `Bound`
///   via `finish_bind`.
/// - `Bound` / `Connected` -> `Connected` on a successful connect; a failed
///   connect leaves the socket in `Bound`.
/// - `Connected` -> `Bound` via `disconnect`.
/// - With the `p3` feature, `Default` may also move to `Connected` (via
///   `connect_p3`) or to `Bound` (via the implicit bind in `send_p3`).
enum UdpState {
    /// Freshly created; neither bound nor connected.
    Default,

    /// The OS-level bind done by `bind` succeeded, but `finish_bind` has not
    /// been called yet.
    BindStarted,

    /// Bound to a local address, without a fixed remote peer.
    Bound,

    /// Connected to the contained remote address. Only the `p3` send/receive
    /// paths read the address, hence the lint expectation without `p3`.
    #[cfg_attr(
        not(feature = "p3"),
        expect(dead_code, reason = "p2 has its own way of managing sending/receiving")
    )]
    Connected(SocketAddr),
}
38
/// A WASI UDP socket backed by a tokio UDP socket.
///
/// Alongside the OS socket it tracks the WASI lifecycle state and the
/// address family, so state and argument checks can be made before any
/// system call is issued.
pub struct UdpSocket {
    /// The underlying OS socket. Held in an `Arc` so the futures returned by
    /// the `p3` send/receive methods can own a handle without borrowing
    /// `self`.
    socket: Arc<tokio::net::UdpSocket>,

    /// Current lifecycle state; see [`UdpState`].
    udp_state: UdpState,

    /// Address family chosen at creation; not modified by the methods in
    /// this file.
    family: SocketAddressFamily,

    /// Optional address-permission check installed by
    /// `set_socket_addr_check`. This type only stores it.
    /// NOTE(review): presumably consulted by callers before using remote
    /// addresses — confirm.
    socket_addr_check: Option<SocketAddrCheck>,
}
56
57impl UdpSocket {
58 pub(crate) fn new(cx: &WasiSocketsCtx, family: SocketAddressFamily) -> Result<Self, ErrorCode> {
60 cx.allowed_network_uses.check_allowed_udp()?;
61
62 let fd = udp_socket(family)?;
63
64 if family == SocketAddressFamily::Ipv6 {
65 rustix::net::sockopt::set_ipv6_v6only(&fd, true)?;
66 }
67
68 let socket = with_ambient_tokio_runtime(|| {
69 tokio::net::UdpSocket::try_from(unsafe {
70 std::net::UdpSocket::from_raw_socketlike(fd.into_raw_socketlike())
71 })
72 })?;
73
74 Ok(Self {
75 socket: Arc::new(socket),
76 udp_state: UdpState::Default,
77 family,
78 socket_addr_check: None,
79 })
80 }
81
82 pub(crate) fn bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
83 if !matches!(self.udp_state, UdpState::Default) {
84 return Err(ErrorCode::InvalidState);
85 }
86 if !is_valid_address_family(addr.ip(), self.family) {
87 return Err(ErrorCode::InvalidArgument);
88 }
89 udp_bind(&self.socket, addr)?;
90 self.udp_state = UdpState::BindStarted;
91 Ok(())
92 }
93
94 pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
95 match self.udp_state {
96 UdpState::BindStarted => {
97 self.udp_state = UdpState::Bound;
98 Ok(())
99 }
100 _ => Err(ErrorCode::NotInProgress),
101 }
102 }
103
104 pub(crate) fn is_connected(&self) -> bool {
105 matches!(self.udp_state, UdpState::Connected(..))
106 }
107
108 pub(crate) fn is_bound(&self) -> bool {
109 matches!(self.udp_state, UdpState::Connected(..) | UdpState::Bound)
110 }
111
112 pub(crate) fn disconnect(&mut self) -> Result<(), ErrorCode> {
113 if !self.is_connected() {
114 return Err(ErrorCode::InvalidState);
115 }
116 udp_disconnect(&self.socket)?;
117 self.udp_state = UdpState::Bound;
118 Ok(())
119 }
120
121 pub(crate) fn connect_p2(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
123 match self.udp_state {
124 UdpState::Bound | UdpState::Connected(_) => {}
125 _ => return Err(ErrorCode::InvalidState),
126 }
127
128 self.connect_common(addr)
129 }
130
131 #[cfg(feature = "p3")]
133 pub(crate) fn connect_p3(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
134 match self.udp_state {
135 UdpState::Default | UdpState::Bound | UdpState::Connected(_) => {}
136 _ => return Err(ErrorCode::InvalidState),
137 }
138
139 self.connect_common(addr)
140 }
141
142 fn connect_common(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
143 if !is_valid_address_family(addr.ip(), self.family) || !is_valid_remote_address(addr) {
144 return Err(ErrorCode::InvalidArgument);
145 }
146
147 match udp_connect(&self.socket, addr) {
148 Ok(()) => {
149 self.udp_state = UdpState::Connected(addr);
150 Ok(())
151 }
152 Err(e) => {
153 _ = udp_disconnect(&self.socket);
155 self.udp_state = UdpState::Bound;
156
157 Err(match e {
158 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, Errno::INPROGRESS => {
160 debug!("UDP connect returned EINPROGRESS, which should never happen");
161 ErrorCode::Unknown
162 }
163 err => err.into(),
164 })
165 }
166 }
167 }
168
169 #[cfg(feature = "p3")]
171 pub(crate) fn send_p3(
172 &mut self,
173 buf: Vec<u8>,
174 addr: Option<SocketAddr>,
175 ) -> impl Future<Output = Result<(), ErrorCode>> + use<> {
176 enum Mode {
177 Send(Arc<tokio::net::UdpSocket>),
178 SendTo(Arc<tokio::net::UdpSocket>, SocketAddr),
179 }
180 let mut socket = match (&self.udp_state, addr) {
181 (UdpState::BindStarted, _) => Err(ErrorCode::InvalidState),
182 (UdpState::Default | UdpState::Bound, None) => Err(ErrorCode::InvalidArgument),
183 (UdpState::Default | UdpState::Bound, Some(addr)) => {
184 Ok(Mode::SendTo(Arc::clone(&self.socket), addr))
185 }
186 (UdpState::Connected(..), None) => Ok(Mode::Send(Arc::clone(&self.socket))),
187 (UdpState::Connected(caddr), Some(addr)) => {
188 if addr == *caddr {
189 Ok(Mode::Send(Arc::clone(&self.socket)))
190 } else {
191 Err(ErrorCode::InvalidArgument)
192 }
193 }
194 };
195
196 if socket.is_ok()
205 && let UdpState::Default = self.udp_state
206 {
207 let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
208 match udp_bind(&self.socket, implicit_addr) {
209 Ok(()) => {
210 self.udp_state = UdpState::Bound;
211 }
212 Err(e) => {
213 socket = Err(e);
214 }
215 }
216 }
217
218 async move {
219 match socket? {
220 Mode::Send(socket) => send(&socket, &buf).await,
221 Mode::SendTo(socket, addr) => send_to(&socket, &buf, addr).await,
222 }
223 }
224 }
225
226 #[cfg(feature = "p3")]
228 pub(crate) fn receive_p3(
229 &self,
230 ) -> impl Future<Output = Result<(Vec<u8>, SocketAddr), ErrorCode>> + use<> {
231 enum Mode {
232 Recv(Arc<tokio::net::UdpSocket>, SocketAddr),
233 RecvFrom(Arc<tokio::net::UdpSocket>),
234 }
235 let socket = match self.udp_state {
236 UdpState::Default | UdpState::BindStarted => Err(ErrorCode::InvalidState),
237 UdpState::Bound => Ok(Mode::RecvFrom(Arc::clone(&self.socket))),
238 UdpState::Connected(addr) => Ok(Mode::Recv(Arc::clone(&self.socket), addr)),
239 };
240 async move {
241 let socket = socket?;
242 let mut buf = vec![0; super::MAX_UDP_DATAGRAM_SIZE];
243 let (n, addr) = match socket {
244 Mode::Recv(socket, addr) => {
245 let n = socket.recv(&mut buf).await?;
246 (n, addr)
247 }
248 Mode::RecvFrom(socket) => {
249 let (n, addr) = socket.recv_from(&mut buf).await?;
250 (n, addr)
251 }
252 };
253 buf.truncate(n);
254 Ok((buf, addr))
255 }
256 }
257
258 pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
259 if matches!(self.udp_state, UdpState::Default | UdpState::BindStarted) {
260 return Err(ErrorCode::InvalidState);
261 }
262 let addr = self
263 .socket
264 .as_socketlike_view::<std::net::UdpSocket>()
265 .local_addr()?;
266 Ok(addr)
267 }
268
269 pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
270 if !matches!(self.udp_state, UdpState::Connected(..)) {
271 return Err(ErrorCode::InvalidState);
272 }
273 let addr = self
274 .socket
275 .as_socketlike_view::<std::net::UdpSocket>()
276 .peer_addr()?;
277 Ok(addr)
278 }
279
280 pub(crate) fn address_family(&self) -> SocketAddressFamily {
281 self.family
282 }
283
284 pub(crate) fn unicast_hop_limit(&self) -> Result<u8, ErrorCode> {
285 let n = get_unicast_hop_limit(&self.socket, self.family)?;
286 Ok(n)
287 }
288
289 pub(crate) fn set_unicast_hop_limit(&self, value: u8) -> Result<(), ErrorCode> {
290 set_unicast_hop_limit(&self.socket, self.family, value)?;
291 Ok(())
292 }
293
294 pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
295 let n = receive_buffer_size(&self.socket)?;
296 Ok(n)
297 }
298
299 pub(crate) fn set_receive_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
300 set_receive_buffer_size(&self.socket, value)?;
301 Ok(())
302 }
303
304 pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
305 let n = send_buffer_size(&self.socket)?;
306 Ok(n)
307 }
308
309 pub(crate) fn set_send_buffer_size(&self, value: u64) -> Result<(), ErrorCode> {
310 set_send_buffer_size(&self.socket, value)?;
311 Ok(())
312 }
313
314 pub(crate) fn socket(&self) -> &Arc<tokio::net::UdpSocket> {
315 &self.socket
316 }
317
318 pub(crate) fn socket_addr_check(&self) -> Option<&SocketAddrCheck> {
319 self.socket_addr_check.as_ref()
320 }
321
322 pub(crate) fn set_socket_addr_check(&mut self, check: Option<SocketAddrCheck>) {
323 self.socket_addr_check = check;
324 }
325}
326
/// Sends `buf` as a single datagram to the socket's connected peer.
///
/// # Errors
/// Propagates the I/O error from the send. If the OS reports fewer bytes
/// than `buf.len()`, returns `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send(socket: &tokio::net::UdpSocket, buf: &[u8]) -> Result<(), ErrorCode> {
    match socket.send(buf).await? {
        written if written == buf.len() => Ok(()),
        // NOTE(review): a short count is presumably never produced for a UDP
        // send; it is treated as an unexpected failure.
        _ => Err(ErrorCode::Unknown),
    }
}
341
/// Sends `buf` as a single datagram to `addr`.
///
/// # Errors
/// Propagates the I/O error from the send. If the OS reports fewer bytes
/// than `buf.len()`, returns `ErrorCode::Unknown`.
#[cfg(feature = "p3")]
async fn send_to(
    socket: &tokio::net::UdpSocket,
    buf: &[u8],
    addr: SocketAddr,
) -> Result<(), ErrorCode> {
    let sent = socket.send_to(buf, addr).await?;
    if sent == buf.len() {
        Ok(())
    } else {
        // NOTE(review): a partial datagram write is presumably impossible for
        // UDP; report it as an unexpected failure.
        Err(ErrorCode::Unknown)
    }
}
355}