// wasmtime_wasi/p2/host/udp.rs

1use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network};
2use crate::p2::bindings::sockets::udp;
3use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState};
4use crate::p2::{Pollable, SocketError, SocketResult};
5use crate::sockets::util::{is_valid_address_family, is_valid_remote_address};
6use crate::sockets::{
7    MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView,
8};
9use anyhow::anyhow;
10use async_trait::async_trait;
11use std::net::SocketAddr;
12use tokio::io::Interest;
13use wasmtime::component::Resource;
14use wasmtime_wasi_io::poll::DynPollable;
15
16impl udp::Host for WasiSocketsCtxView<'_> {}
17
18impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
19    async fn start_bind(
20        &mut self,
21        this: Resource<udp::UdpSocket>,
22        network: Resource<Network>,
23        local_address: IpSocketAddress,
24    ) -> SocketResult<()> {
25        let local_address = SocketAddr::from(local_address);
26        let check = self.table.get(&network)?.socket_addr_check.clone();
27        check.check(local_address, SocketAddrUse::UdpBind).await?;
28
29        let socket = self.table.get_mut(&this)?;
30        socket.bind(local_address)?;
31        socket.set_socket_addr_check(Some(check));
32
33        Ok(())
34    }
35
36    fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
37        self.table.get_mut(&this)?.finish_bind()?;
38        Ok(())
39    }
40
41    async fn stream(
42        &mut self,
43        this: Resource<udp::UdpSocket>,
44        remote_address: Option<IpSocketAddress>,
45    ) -> SocketResult<(
46        Resource<udp::IncomingDatagramStream>,
47        Resource<udp::OutgoingDatagramStream>,
48    )> {
49        let has_active_streams = self
50            .table
51            .iter_children(&this)?
52            .any(|c| c.is::<IncomingDatagramStream>() || c.is::<OutgoingDatagramStream>());
53
54        if has_active_streams {
55            return Err(SocketError::trap(anyhow!("UDP streams not dropped yet")));
56        }
57
58        let socket = self.table.get_mut(&this)?;
59        let remote_address = remote_address.map(SocketAddr::from);
60
61        if !socket.is_bound() {
62            return Err(ErrorCode::InvalidState.into());
63        }
64
65        // We disconnect & (re)connect in two distinct steps for two reasons:
66        // - To leave our socket instance in a consistent state in case the
67        //   connect fails.
68        // - When reconnecting to a different address, Linux sometimes fails
69        //   if there isn't a disconnect in between.
70
71        // Step #1: Disconnect
72        if socket.is_connected() {
73            socket.disconnect()?;
74        }
75
76        // Step #2: (Re)connect
77        if let Some(connect_addr) = remote_address {
78            let Some(check) = socket.socket_addr_check() else {
79                return Err(ErrorCode::InvalidState.into());
80            };
81            check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
82            socket.connect(connect_addr)?;
83        }
84
85        let incoming_stream = IncomingDatagramStream {
86            inner: socket.socket().clone(),
87            remote_address,
88        };
89        let outgoing_stream = OutgoingDatagramStream {
90            inner: socket.socket().clone(),
91            remote_address,
92            family: socket.address_family(),
93            send_state: SendState::Idle,
94            socket_addr_check: socket.socket_addr_check().cloned(),
95        };
96
97        Ok((
98            self.table.push_child(incoming_stream, &this)?,
99            self.table.push_child(outgoing_stream, &this)?,
100        ))
101    }
102
103    fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
104        let socket = self.table.get(&this)?;
105        Ok(socket.local_address()?.into())
106    }
107
108    fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
109        let socket = self.table.get(&this)?;
110        Ok(socket.remote_address()?.into())
111    }
112
113    fn address_family(
114        &mut self,
115        this: Resource<udp::UdpSocket>,
116    ) -> Result<IpAddressFamily, anyhow::Error> {
117        let socket = self.table.get(&this)?;
118        Ok(socket.address_family().into())
119    }
120
121    fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
122        let socket = self.table.get(&this)?;
123        Ok(socket.unicast_hop_limit()?)
124    }
125
126    fn set_unicast_hop_limit(
127        &mut self,
128        this: Resource<udp::UdpSocket>,
129        value: u8,
130    ) -> SocketResult<()> {
131        let socket = self.table.get(&this)?;
132        socket.set_unicast_hop_limit(value)?;
133        Ok(())
134    }
135
136    fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
137        let socket = self.table.get(&this)?;
138        Ok(socket.receive_buffer_size()?)
139    }
140
141    fn set_receive_buffer_size(
142        &mut self,
143        this: Resource<udp::UdpSocket>,
144        value: u64,
145    ) -> SocketResult<()> {
146        let socket = self.table.get(&this)?;
147        socket.set_receive_buffer_size(value)?;
148        Ok(())
149    }
150
151    fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
152        let socket = self.table.get(&this)?;
153        Ok(socket.send_buffer_size()?)
154    }
155
156    fn set_send_buffer_size(&mut self, this: Resource<UdpSocket>, value: u64) -> SocketResult<()> {
157        let socket = self.table.get(&this)?;
158        socket.set_send_buffer_size(value)?;
159        Ok(())
160    }
161
162    fn subscribe(&mut self, this: Resource<UdpSocket>) -> anyhow::Result<Resource<DynPollable>> {
163        wasmtime_wasi_io::poll::subscribe(self.table, this)
164    }
165
166    fn drop(&mut self, this: Resource<udp::UdpSocket>) -> Result<(), anyhow::Error> {
167        // As in the filesystem implementation, we assume closing a socket
168        // doesn't block.
169        let dropped = self.table.delete(this)?;
170        drop(dropped);
171
172        Ok(())
173    }
174}
175
176#[async_trait]
177impl Pollable for UdpSocket {
178    async fn ready(&mut self) {
179        // None of the socket-level operations block natively
180    }
181}
182
183impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
184    fn receive(
185        &mut self,
186        this: Resource<udp::IncomingDatagramStream>,
187        max_results: u64,
188    ) -> SocketResult<Vec<udp::IncomingDatagram>> {
189        // Returns Ok(None) when the message was dropped.
190        fn recv_one(
191            stream: &IncomingDatagramStream,
192        ) -> SocketResult<Option<udp::IncomingDatagram>> {
193            let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
194            let (size, received_addr) = stream.inner.try_recv_from(&mut buf)?;
195            debug_assert!(size <= buf.len());
196
197            match stream.remote_address {
198                Some(connected_addr) if connected_addr != received_addr => {
199                    // Normally, this should have already been checked for us by the OS.
200                    return Ok(None);
201                }
202                _ => {}
203            }
204
205            Ok(Some(udp::IncomingDatagram {
206                data: buf[..size].into(),
207                remote_address: received_addr.into(),
208            }))
209        }
210
211        let stream = self.table.get(&this)?;
212        let max_results: usize = max_results.try_into().unwrap_or(usize::MAX);
213
214        if max_results == 0 {
215            return Ok(vec![]);
216        }
217
218        let mut datagrams = vec![];
219
220        while datagrams.len() < max_results {
221            match recv_one(stream) {
222                Ok(Some(datagram)) => {
223                    datagrams.push(datagram);
224                }
225                Ok(None) => {
226                    // Message was dropped
227                }
228                Err(_) if datagrams.len() > 0 => {
229                    return Ok(datagrams);
230                }
231                Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
232                    return Ok(datagrams);
233                }
234                Err(e) => {
235                    return Err(e);
236                }
237            }
238        }
239
240        Ok(datagrams)
241    }
242
243    fn subscribe(
244        &mut self,
245        this: Resource<udp::IncomingDatagramStream>,
246    ) -> anyhow::Result<Resource<DynPollable>> {
247        wasmtime_wasi_io::poll::subscribe(self.table, this)
248    }
249
250    fn drop(&mut self, this: Resource<udp::IncomingDatagramStream>) -> Result<(), anyhow::Error> {
251        // As in the filesystem implementation, we assume closing a socket
252        // doesn't block.
253        let dropped = self.table.delete(this)?;
254        drop(dropped);
255
256        Ok(())
257    }
258}
259
260#[async_trait]
261impl Pollable for IncomingDatagramStream {
262    async fn ready(&mut self) {
263        // FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
264        self.inner
265            .ready(Interest::READABLE)
266            .await
267            .expect("failed to await UDP socket readiness");
268    }
269}
270
271impl udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> {
272    fn check_send(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> SocketResult<u64> {
273        let stream = self.table.get_mut(&this)?;
274
275        let permit = match stream.send_state {
276            SendState::Idle => {
277                const PERMIT: usize = 16;
278                stream.send_state = SendState::Permitted(PERMIT);
279                PERMIT
280            }
281            SendState::Permitted(n) => n,
282            SendState::Waiting => 0,
283        };
284
285        Ok(permit.try_into().unwrap())
286    }
287
288    async fn send(
289        &mut self,
290        this: Resource<udp::OutgoingDatagramStream>,
291        datagrams: Vec<udp::OutgoingDatagram>,
292    ) -> SocketResult<u64> {
293        async fn send_one(
294            stream: &OutgoingDatagramStream,
295            datagram: &udp::OutgoingDatagram,
296        ) -> SocketResult<()> {
297            if datagram.data.len() > MAX_UDP_DATAGRAM_SIZE {
298                return Err(ErrorCode::DatagramTooLarge.into());
299            }
300
301            let provided_addr = datagram.remote_address.map(SocketAddr::from);
302            let addr = match (stream.remote_address, provided_addr) {
303                (None, Some(addr)) => {
304                    let Some(check) = stream.socket_addr_check.as_ref() else {
305                        return Err(ErrorCode::InvalidState.into());
306                    };
307                    check
308                        .check(addr, SocketAddrUse::UdpOutgoingDatagram)
309                        .await?;
310                    addr
311                }
312                (Some(addr), None) => addr,
313                (Some(connected_addr), Some(provided_addr)) if connected_addr == provided_addr => {
314                    connected_addr
315                }
316                _ => return Err(ErrorCode::InvalidArgument.into()),
317            };
318
319            if !is_valid_remote_address(addr) || !is_valid_address_family(addr.ip(), stream.family)
320            {
321                return Err(ErrorCode::InvalidArgument.into());
322            }
323
324            if stream.remote_address == Some(addr) {
325                stream.inner.try_send(&datagram.data)?;
326            } else {
327                stream.inner.try_send_to(&datagram.data, addr)?;
328            }
329
330            Ok(())
331        }
332
333        let stream = self.table.get_mut(&this)?;
334
335        match stream.send_state {
336            SendState::Permitted(n) if n >= datagrams.len() => {
337                stream.send_state = SendState::Idle;
338            }
339            SendState::Permitted(_) => {
340                return Err(SocketError::trap(anyhow::anyhow!(
341                    "unpermitted: argument exceeds permitted size"
342                )));
343            }
344            SendState::Idle | SendState::Waiting => {
345                return Err(SocketError::trap(anyhow::anyhow!(
346                    "unpermitted: must call check-send first"
347                )));
348            }
349        }
350
351        if datagrams.is_empty() {
352            return Ok(0);
353        }
354
355        let mut count = 0;
356
357        for datagram in datagrams {
358            match send_one(stream, &datagram).await {
359                Ok(_) => count += 1,
360                Err(_) if count > 0 => {
361                    // WIT: "If at least one datagram has been sent successfully, this function never returns an error."
362                    return Ok(count);
363                }
364                Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
365                    stream.send_state = SendState::Waiting;
366                    return Ok(count);
367                }
368                Err(e) => {
369                    return Err(e);
370                }
371            }
372        }
373
374        Ok(count)
375    }
376
377    fn subscribe(
378        &mut self,
379        this: Resource<udp::OutgoingDatagramStream>,
380    ) -> anyhow::Result<Resource<DynPollable>> {
381        wasmtime_wasi_io::poll::subscribe(self.table, this)
382    }
383
384    fn drop(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> Result<(), anyhow::Error> {
385        // As in the filesystem implementation, we assume closing a socket
386        // doesn't block.
387        let dropped = self.table.delete(this)?;
388        drop(dropped);
389
390        Ok(())
391    }
392}
393
394#[async_trait]
395impl Pollable for OutgoingDatagramStream {
396    async fn ready(&mut self) {
397        match self.send_state {
398            SendState::Idle | SendState::Permitted(_) => {}
399            SendState::Waiting => {
400                // FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
401                self.inner
402                    .ready(Interest::WRITABLE)
403                    .await
404                    .expect("failed to await UDP socket readiness");
405                self.send_state = SendState::Idle;
406            }
407        }
408    }
409}
410
411impl From<SocketAddressFamily> for IpAddressFamily {
412    fn from(family: SocketAddressFamily) -> IpAddressFamily {
413        match family {
414            SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
415            SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
416        }
417    }
418}
419
420pub mod sync {
421    use wasmtime::component::Resource;
422
423    use crate::p2::{
424        SocketError,
425        bindings::{
426            sockets::{
427                network::Network,
428                udp::{
429                    self as async_udp,
430                    HostIncomingDatagramStream as AsyncHostIncomingDatagramStream,
431                    HostOutgoingDatagramStream as AsyncHostOutgoingDatagramStream,
432                    HostUdpSocket as AsyncHostUdpSocket, IncomingDatagramStream,
433                    OutgoingDatagramStream,
434                },
435            },
436            sync::sockets::udp::{
437                self, HostIncomingDatagramStream, HostOutgoingDatagramStream, HostUdpSocket,
438                IncomingDatagram, IpAddressFamily, IpSocketAddress, OutgoingDatagram, Pollable,
439                UdpSocket,
440            },
441        },
442    };
443    use crate::runtime::in_tokio;
444    use crate::sockets::WasiSocketsCtxView;
445
446    impl udp::Host for WasiSocketsCtxView<'_> {}
447
448    impl HostUdpSocket for WasiSocketsCtxView<'_> {
449        fn start_bind(
450            &mut self,
451            self_: Resource<UdpSocket>,
452            network: Resource<Network>,
453            local_address: IpSocketAddress,
454        ) -> Result<(), SocketError> {
455            in_tokio(async {
456                AsyncHostUdpSocket::start_bind(self, self_, network, local_address).await
457            })
458        }
459
460        fn finish_bind(&mut self, self_: Resource<UdpSocket>) -> Result<(), SocketError> {
461            AsyncHostUdpSocket::finish_bind(self, self_)
462        }
463
464        fn stream(
465            &mut self,
466            self_: Resource<UdpSocket>,
467            remote_address: Option<IpSocketAddress>,
468        ) -> Result<
469            (
470                Resource<IncomingDatagramStream>,
471                Resource<OutgoingDatagramStream>,
472            ),
473            SocketError,
474        > {
475            in_tokio(async { AsyncHostUdpSocket::stream(self, self_, remote_address).await })
476        }
477
478        fn local_address(
479            &mut self,
480            self_: Resource<UdpSocket>,
481        ) -> Result<IpSocketAddress, SocketError> {
482            AsyncHostUdpSocket::local_address(self, self_)
483        }
484
485        fn remote_address(
486            &mut self,
487            self_: Resource<UdpSocket>,
488        ) -> Result<IpSocketAddress, SocketError> {
489            AsyncHostUdpSocket::remote_address(self, self_)
490        }
491
492        fn address_family(
493            &mut self,
494            self_: Resource<UdpSocket>,
495        ) -> wasmtime::Result<IpAddressFamily> {
496            AsyncHostUdpSocket::address_family(self, self_)
497        }
498
499        fn unicast_hop_limit(&mut self, self_: Resource<UdpSocket>) -> Result<u8, SocketError> {
500            AsyncHostUdpSocket::unicast_hop_limit(self, self_)
501        }
502
503        fn set_unicast_hop_limit(
504            &mut self,
505            self_: Resource<UdpSocket>,
506            value: u8,
507        ) -> Result<(), SocketError> {
508            AsyncHostUdpSocket::set_unicast_hop_limit(self, self_, value)
509        }
510
511        fn receive_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
512            AsyncHostUdpSocket::receive_buffer_size(self, self_)
513        }
514
515        fn set_receive_buffer_size(
516            &mut self,
517            self_: Resource<UdpSocket>,
518            value: u64,
519        ) -> Result<(), SocketError> {
520            AsyncHostUdpSocket::set_receive_buffer_size(self, self_, value)
521        }
522
523        fn send_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
524            AsyncHostUdpSocket::send_buffer_size(self, self_)
525        }
526
527        fn set_send_buffer_size(
528            &mut self,
529            self_: Resource<UdpSocket>,
530            value: u64,
531        ) -> Result<(), SocketError> {
532            AsyncHostUdpSocket::set_send_buffer_size(self, self_, value)
533        }
534
535        fn subscribe(
536            &mut self,
537            self_: Resource<UdpSocket>,
538        ) -> wasmtime::Result<Resource<Pollable>> {
539            AsyncHostUdpSocket::subscribe(self, self_)
540        }
541
542        fn drop(&mut self, rep: Resource<UdpSocket>) -> wasmtime::Result<()> {
543            AsyncHostUdpSocket::drop(self, rep)
544        }
545    }
546
547    impl HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
548        fn receive(
549            &mut self,
550            self_: Resource<IncomingDatagramStream>,
551            max_results: u64,
552        ) -> Result<Vec<IncomingDatagram>, SocketError> {
553            Ok(
554                AsyncHostIncomingDatagramStream::receive(self, self_, max_results)?
555                    .into_iter()
556                    .map(Into::into)
557                    .collect(),
558            )
559        }
560
561        fn subscribe(
562            &mut self,
563            self_: Resource<IncomingDatagramStream>,
564        ) -> wasmtime::Result<Resource<Pollable>> {
565            AsyncHostIncomingDatagramStream::subscribe(self, self_)
566        }
567
568        fn drop(&mut self, rep: Resource<IncomingDatagramStream>) -> wasmtime::Result<()> {
569            AsyncHostIncomingDatagramStream::drop(self, rep)
570        }
571    }
572
573    impl From<async_udp::IncomingDatagram> for IncomingDatagram {
574        fn from(other: async_udp::IncomingDatagram) -> Self {
575            let async_udp::IncomingDatagram {
576                data,
577                remote_address,
578            } = other;
579            Self {
580                data,
581                remote_address,
582            }
583        }
584    }
585
586    impl HostOutgoingDatagramStream for WasiSocketsCtxView<'_> {
587        fn check_send(
588            &mut self,
589            self_: Resource<OutgoingDatagramStream>,
590        ) -> Result<u64, SocketError> {
591            AsyncHostOutgoingDatagramStream::check_send(self, self_)
592        }
593
594        fn send(
595            &mut self,
596            self_: Resource<OutgoingDatagramStream>,
597            datagrams: Vec<OutgoingDatagram>,
598        ) -> Result<u64, SocketError> {
599            let datagrams = datagrams.into_iter().map(Into::into).collect();
600            in_tokio(async { AsyncHostOutgoingDatagramStream::send(self, self_, datagrams).await })
601        }
602
603        fn subscribe(
604            &mut self,
605            self_: Resource<OutgoingDatagramStream>,
606        ) -> wasmtime::Result<Resource<Pollable>> {
607            AsyncHostOutgoingDatagramStream::subscribe(self, self_)
608        }
609
610        fn drop(&mut self, rep: Resource<OutgoingDatagramStream>) -> wasmtime::Result<()> {
611            AsyncHostOutgoingDatagramStream::drop(self, rep)
612        }
613    }
614
615    impl From<OutgoingDatagram> for async_udp::OutgoingDatagram {
616        fn from(other: OutgoingDatagram) -> Self {
617            let OutgoingDatagram {
618                data,
619                remote_address,
620            } = other;
621            Self {
622                data,
623                remote_address,
624            }
625        }
626    }
627}