// wasmtime_wasi/p2/host/udp.rs

1use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network};
2use crate::p2::bindings::sockets::udp;
3use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream, SendState};
4use crate::p2::{Pollable, SocketError, SocketResult};
5use crate::sockets::util::{is_valid_address_family, is_valid_remote_address};
6use crate::sockets::{
7    MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView,
8};
9use async_trait::async_trait;
10use std::net::SocketAddr;
11use tokio::io::Interest;
12use wasmtime::component::Resource;
13use wasmtime::format_err;
14use wasmtime_wasi_io::poll::DynPollable;
15
16impl udp::Host for WasiSocketsCtxView<'_> {}
17
18impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
19    async fn start_bind(
20        &mut self,
21        this: Resource<udp::UdpSocket>,
22        network: Resource<Network>,
23        local_address: IpSocketAddress,
24    ) -> SocketResult<()> {
25        let local_address = SocketAddr::from(local_address);
26        let check = self.table.get(&network)?.socket_addr_check.clone();
27        check.check(local_address, SocketAddrUse::UdpBind).await?;
28
29        let socket = self.table.get_mut(&this)?;
30        socket.bind(local_address)?;
31        socket.set_socket_addr_check(Some(check));
32
33        Ok(())
34    }
35
36    fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
37        self.table.get_mut(&this)?.finish_bind()?;
38        Ok(())
39    }
40
41    async fn stream(
42        &mut self,
43        this: Resource<udp::UdpSocket>,
44        remote_address: Option<IpSocketAddress>,
45    ) -> SocketResult<(
46        Resource<udp::IncomingDatagramStream>,
47        Resource<udp::OutgoingDatagramStream>,
48    )> {
49        let has_active_streams = self
50            .table
51            .iter_children(&this)?
52            .any(|c| c.is::<IncomingDatagramStream>() || c.is::<OutgoingDatagramStream>());
53
54        if has_active_streams {
55            return Err(SocketError::trap(format_err!(
56                "UDP streams not dropped yet"
57            )));
58        }
59
60        let socket = self.table.get_mut(&this)?;
61        let remote_address = remote_address.map(SocketAddr::from);
62
63        if !socket.is_bound() {
64            return Err(ErrorCode::InvalidState.into());
65        }
66
67        // We disconnect & (re)connect in two distinct steps for two reasons:
68        // - To leave our socket instance in a consistent state in case the
69        //   connect fails.
70        // - When reconnecting to a different address, Linux sometimes fails
71        //   if there isn't a disconnect in between.
72
73        // Step #1: Disconnect
74        if socket.is_connected() {
75            socket.disconnect()?;
76        }
77
78        // Step #2: (Re)connect
79        if let Some(connect_addr) = remote_address {
80            let Some(check) = socket.socket_addr_check() else {
81                return Err(ErrorCode::InvalidState.into());
82            };
83            check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
84            socket.connect_p2(connect_addr)?;
85        }
86
87        let incoming_stream = IncomingDatagramStream {
88            inner: socket.socket().clone(),
89            remote_address,
90        };
91        let outgoing_stream = OutgoingDatagramStream {
92            inner: socket.socket().clone(),
93            remote_address,
94            family: socket.address_family(),
95            send_state: SendState::Idle,
96            socket_addr_check: socket.socket_addr_check().cloned(),
97        };
98
99        Ok((
100            self.table.push_child(incoming_stream, &this)?,
101            self.table.push_child(outgoing_stream, &this)?,
102        ))
103    }
104
105    fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
106        let socket = self.table.get(&this)?;
107        Ok(socket.local_address()?.into())
108    }
109
110    fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
111        let socket = self.table.get(&this)?;
112        Ok(socket.remote_address()?.into())
113    }
114
115    fn address_family(
116        &mut self,
117        this: Resource<udp::UdpSocket>,
118    ) -> Result<IpAddressFamily, wasmtime::Error> {
119        let socket = self.table.get(&this)?;
120        Ok(socket.address_family().into())
121    }
122
123    fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
124        let socket = self.table.get(&this)?;
125        Ok(socket.unicast_hop_limit()?)
126    }
127
128    fn set_unicast_hop_limit(
129        &mut self,
130        this: Resource<udp::UdpSocket>,
131        value: u8,
132    ) -> SocketResult<()> {
133        let socket = self.table.get(&this)?;
134        socket.set_unicast_hop_limit(value)?;
135        Ok(())
136    }
137
138    fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
139        let socket = self.table.get(&this)?;
140        Ok(socket.receive_buffer_size()?)
141    }
142
143    fn set_receive_buffer_size(
144        &mut self,
145        this: Resource<udp::UdpSocket>,
146        value: u64,
147    ) -> SocketResult<()> {
148        let socket = self.table.get(&this)?;
149        socket.set_receive_buffer_size(value)?;
150        Ok(())
151    }
152
153    fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
154        let socket = self.table.get(&this)?;
155        Ok(socket.send_buffer_size()?)
156    }
157
158    fn set_send_buffer_size(&mut self, this: Resource<UdpSocket>, value: u64) -> SocketResult<()> {
159        let socket = self.table.get(&this)?;
160        socket.set_send_buffer_size(value)?;
161        Ok(())
162    }
163
164    fn subscribe(&mut self, this: Resource<UdpSocket>) -> wasmtime::Result<Resource<DynPollable>> {
165        wasmtime_wasi_io::poll::subscribe(self.table, this)
166    }
167
168    fn drop(&mut self, this: Resource<udp::UdpSocket>) -> Result<(), wasmtime::Error> {
169        // As in the filesystem implementation, we assume closing a socket
170        // doesn't block.
171        let dropped = self.table.delete(this)?;
172        drop(dropped);
173
174        Ok(())
175    }
176}
177
178#[async_trait]
179impl Pollable for UdpSocket {
180    async fn ready(&mut self) {
181        // None of the socket-level operations block natively
182    }
183}
184
185impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
186    fn receive(
187        &mut self,
188        this: Resource<udp::IncomingDatagramStream>,
189        max_results: u64,
190    ) -> SocketResult<Vec<udp::IncomingDatagram>> {
191        // Returns Ok(None) when the message was dropped.
192        fn recv_one(
193            stream: &IncomingDatagramStream,
194        ) -> SocketResult<Option<udp::IncomingDatagram>> {
195            let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
196            let (size, received_addr) = stream.inner.try_recv_from(&mut buf)?;
197            debug_assert!(size <= buf.len());
198
199            match stream.remote_address {
200                Some(connected_addr) if connected_addr != received_addr => {
201                    // Normally, this should have already been checked for us by the OS.
202                    return Ok(None);
203                }
204                _ => {}
205            }
206
207            Ok(Some(udp::IncomingDatagram {
208                data: buf[..size].into(),
209                remote_address: received_addr.into(),
210            }))
211        }
212
213        let stream = self.table.get(&this)?;
214        let max_results: usize = max_results.try_into().unwrap_or(usize::MAX);
215
216        if max_results == 0 {
217            return Ok(vec![]);
218        }
219
220        let mut datagrams = vec![];
221
222        while datagrams.len() < max_results {
223            match recv_one(stream) {
224                Ok(Some(datagram)) => {
225                    datagrams.push(datagram);
226                }
227                Ok(None) => {
228                    // Message was dropped
229                }
230                Err(_) if datagrams.len() > 0 => {
231                    return Ok(datagrams);
232                }
233                Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
234                    return Ok(datagrams);
235                }
236                Err(e) => {
237                    return Err(e);
238                }
239            }
240        }
241
242        Ok(datagrams)
243    }
244
245    fn subscribe(
246        &mut self,
247        this: Resource<udp::IncomingDatagramStream>,
248    ) -> wasmtime::Result<Resource<DynPollable>> {
249        wasmtime_wasi_io::poll::subscribe(self.table, this)
250    }
251
252    fn drop(&mut self, this: Resource<udp::IncomingDatagramStream>) -> Result<(), wasmtime::Error> {
253        // As in the filesystem implementation, we assume closing a socket
254        // doesn't block.
255        let dropped = self.table.delete(this)?;
256        drop(dropped);
257
258        Ok(())
259    }
260}
261
262#[async_trait]
263impl Pollable for IncomingDatagramStream {
264    async fn ready(&mut self) {
265        // FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
266        self.inner
267            .ready(Interest::READABLE)
268            .await
269            .expect("failed to await UDP socket readiness");
270    }
271}
272
273impl udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> {
274    fn check_send(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> SocketResult<u64> {
275        let stream = self.table.get_mut(&this)?;
276
277        let permit = match stream.send_state {
278            SendState::Idle => {
279                const PERMIT: usize = 16;
280                stream.send_state = SendState::Permitted(PERMIT);
281                PERMIT
282            }
283            SendState::Permitted(n) => n,
284            SendState::Waiting => 0,
285        };
286
287        Ok(permit.try_into().unwrap())
288    }
289
290    async fn send(
291        &mut self,
292        this: Resource<udp::OutgoingDatagramStream>,
293        datagrams: Vec<udp::OutgoingDatagram>,
294    ) -> SocketResult<u64> {
295        async fn send_one(
296            stream: &OutgoingDatagramStream,
297            datagram: &udp::OutgoingDatagram,
298        ) -> SocketResult<()> {
299            if datagram.data.len() > MAX_UDP_DATAGRAM_SIZE {
300                return Err(ErrorCode::DatagramTooLarge.into());
301            }
302
303            let provided_addr = datagram.remote_address.map(SocketAddr::from);
304            let addr = match (stream.remote_address, provided_addr) {
305                (None, Some(addr)) => {
306                    let Some(check) = stream.socket_addr_check.as_ref() else {
307                        return Err(ErrorCode::InvalidState.into());
308                    };
309                    check
310                        .check(addr, SocketAddrUse::UdpOutgoingDatagram)
311                        .await?;
312                    addr
313                }
314                (Some(addr), None) => addr,
315                (Some(connected_addr), Some(provided_addr)) if connected_addr == provided_addr => {
316                    connected_addr
317                }
318                _ => return Err(ErrorCode::InvalidArgument.into()),
319            };
320
321            if !is_valid_remote_address(addr) || !is_valid_address_family(addr.ip(), stream.family)
322            {
323                return Err(ErrorCode::InvalidArgument.into());
324            }
325
326            if stream.remote_address == Some(addr) {
327                stream.inner.try_send(&datagram.data)?;
328            } else {
329                stream.inner.try_send_to(&datagram.data, addr)?;
330            }
331
332            Ok(())
333        }
334
335        let stream = self.table.get_mut(&this)?;
336
337        match stream.send_state {
338            SendState::Permitted(n) if n >= datagrams.len() => {
339                stream.send_state = SendState::Idle;
340            }
341            SendState::Permitted(_) => {
342                return Err(SocketError::trap(wasmtime::format_err!(
343                    "unpermitted: argument exceeds permitted size"
344                )));
345            }
346            SendState::Idle | SendState::Waiting => {
347                return Err(SocketError::trap(wasmtime::format_err!(
348                    "unpermitted: must call check-send first"
349                )));
350            }
351        }
352
353        if datagrams.is_empty() {
354            return Ok(0);
355        }
356
357        let mut count = 0;
358
359        for datagram in datagrams {
360            match send_one(stream, &datagram).await {
361                Ok(_) => count += 1,
362                Err(_) if count > 0 => {
363                    // WIT: "If at least one datagram has been sent successfully, this function never returns an error."
364                    return Ok(count);
365                }
366                Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
367                    stream.send_state = SendState::Waiting;
368                    return Ok(count);
369                }
370                Err(e) => {
371                    return Err(e);
372                }
373            }
374        }
375
376        Ok(count)
377    }
378
379    fn subscribe(
380        &mut self,
381        this: Resource<udp::OutgoingDatagramStream>,
382    ) -> wasmtime::Result<Resource<DynPollable>> {
383        wasmtime_wasi_io::poll::subscribe(self.table, this)
384    }
385
386    fn drop(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> Result<(), wasmtime::Error> {
387        // As in the filesystem implementation, we assume closing a socket
388        // doesn't block.
389        let dropped = self.table.delete(this)?;
390        drop(dropped);
391
392        Ok(())
393    }
394}
395
396#[async_trait]
397impl Pollable for OutgoingDatagramStream {
398    async fn ready(&mut self) {
399        match self.send_state {
400            SendState::Idle | SendState::Permitted(_) => {}
401            SendState::Waiting => {
402                // FIXME: Add `Interest::ERROR` when we update to tokio 1.32.
403                self.inner
404                    .ready(Interest::WRITABLE)
405                    .await
406                    .expect("failed to await UDP socket readiness");
407                self.send_state = SendState::Idle;
408            }
409        }
410    }
411}
412
413impl From<SocketAddressFamily> for IpAddressFamily {
414    fn from(family: SocketAddressFamily) -> IpAddressFamily {
415        match family {
416            SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
417            SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
418        }
419    }
420}
421
422pub mod sync {
423    use wasmtime::component::Resource;
424
425    use crate::p2::{
426        SocketError,
427        bindings::{
428            sockets::{
429                network::Network,
430                udp::{
431                    self as async_udp,
432                    HostIncomingDatagramStream as AsyncHostIncomingDatagramStream,
433                    HostOutgoingDatagramStream as AsyncHostOutgoingDatagramStream,
434                    HostUdpSocket as AsyncHostUdpSocket, IncomingDatagramStream,
435                    OutgoingDatagramStream,
436                },
437            },
438            sync::sockets::udp::{
439                self, HostIncomingDatagramStream, HostOutgoingDatagramStream, HostUdpSocket,
440                IncomingDatagram, IpAddressFamily, IpSocketAddress, OutgoingDatagram, Pollable,
441                UdpSocket,
442            },
443        },
444    };
445    use crate::runtime::in_tokio;
446    use crate::sockets::WasiSocketsCtxView;
447
448    impl udp::Host for WasiSocketsCtxView<'_> {}
449
450    impl HostUdpSocket for WasiSocketsCtxView<'_> {
451        fn start_bind(
452            &mut self,
453            self_: Resource<UdpSocket>,
454            network: Resource<Network>,
455            local_address: IpSocketAddress,
456        ) -> Result<(), SocketError> {
457            in_tokio(async {
458                AsyncHostUdpSocket::start_bind(self, self_, network, local_address).await
459            })
460        }
461
462        fn finish_bind(&mut self, self_: Resource<UdpSocket>) -> Result<(), SocketError> {
463            AsyncHostUdpSocket::finish_bind(self, self_)
464        }
465
466        fn stream(
467            &mut self,
468            self_: Resource<UdpSocket>,
469            remote_address: Option<IpSocketAddress>,
470        ) -> Result<
471            (
472                Resource<IncomingDatagramStream>,
473                Resource<OutgoingDatagramStream>,
474            ),
475            SocketError,
476        > {
477            in_tokio(async { AsyncHostUdpSocket::stream(self, self_, remote_address).await })
478        }
479
480        fn local_address(
481            &mut self,
482            self_: Resource<UdpSocket>,
483        ) -> Result<IpSocketAddress, SocketError> {
484            AsyncHostUdpSocket::local_address(self, self_)
485        }
486
487        fn remote_address(
488            &mut self,
489            self_: Resource<UdpSocket>,
490        ) -> Result<IpSocketAddress, SocketError> {
491            AsyncHostUdpSocket::remote_address(self, self_)
492        }
493
494        fn address_family(
495            &mut self,
496            self_: Resource<UdpSocket>,
497        ) -> wasmtime::Result<IpAddressFamily> {
498            AsyncHostUdpSocket::address_family(self, self_)
499        }
500
501        fn unicast_hop_limit(&mut self, self_: Resource<UdpSocket>) -> Result<u8, SocketError> {
502            AsyncHostUdpSocket::unicast_hop_limit(self, self_)
503        }
504
505        fn set_unicast_hop_limit(
506            &mut self,
507            self_: Resource<UdpSocket>,
508            value: u8,
509        ) -> Result<(), SocketError> {
510            AsyncHostUdpSocket::set_unicast_hop_limit(self, self_, value)
511        }
512
513        fn receive_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
514            AsyncHostUdpSocket::receive_buffer_size(self, self_)
515        }
516
517        fn set_receive_buffer_size(
518            &mut self,
519            self_: Resource<UdpSocket>,
520            value: u64,
521        ) -> Result<(), SocketError> {
522            AsyncHostUdpSocket::set_receive_buffer_size(self, self_, value)
523        }
524
525        fn send_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
526            AsyncHostUdpSocket::send_buffer_size(self, self_)
527        }
528
529        fn set_send_buffer_size(
530            &mut self,
531            self_: Resource<UdpSocket>,
532            value: u64,
533        ) -> Result<(), SocketError> {
534            AsyncHostUdpSocket::set_send_buffer_size(self, self_, value)
535        }
536
537        fn subscribe(
538            &mut self,
539            self_: Resource<UdpSocket>,
540        ) -> wasmtime::Result<Resource<Pollable>> {
541            AsyncHostUdpSocket::subscribe(self, self_)
542        }
543
544        fn drop(&mut self, rep: Resource<UdpSocket>) -> wasmtime::Result<()> {
545            AsyncHostUdpSocket::drop(self, rep)
546        }
547    }
548
549    impl HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
550        fn receive(
551            &mut self,
552            self_: Resource<IncomingDatagramStream>,
553            max_results: u64,
554        ) -> Result<Vec<IncomingDatagram>, SocketError> {
555            Ok(
556                AsyncHostIncomingDatagramStream::receive(self, self_, max_results)?
557                    .into_iter()
558                    .map(Into::into)
559                    .collect(),
560            )
561        }
562
563        fn subscribe(
564            &mut self,
565            self_: Resource<IncomingDatagramStream>,
566        ) -> wasmtime::Result<Resource<Pollable>> {
567            AsyncHostIncomingDatagramStream::subscribe(self, self_)
568        }
569
570        fn drop(&mut self, rep: Resource<IncomingDatagramStream>) -> wasmtime::Result<()> {
571            AsyncHostIncomingDatagramStream::drop(self, rep)
572        }
573    }
574
575    impl From<async_udp::IncomingDatagram> for IncomingDatagram {
576        fn from(other: async_udp::IncomingDatagram) -> Self {
577            let async_udp::IncomingDatagram {
578                data,
579                remote_address,
580            } = other;
581            Self {
582                data,
583                remote_address,
584            }
585        }
586    }
587
588    impl HostOutgoingDatagramStream for WasiSocketsCtxView<'_> {
589        fn check_send(
590            &mut self,
591            self_: Resource<OutgoingDatagramStream>,
592        ) -> Result<u64, SocketError> {
593            AsyncHostOutgoingDatagramStream::check_send(self, self_)
594        }
595
596        fn send(
597            &mut self,
598            self_: Resource<OutgoingDatagramStream>,
599            datagrams: Vec<OutgoingDatagram>,
600        ) -> Result<u64, SocketError> {
601            let datagrams = datagrams.into_iter().map(Into::into).collect();
602            in_tokio(async { AsyncHostOutgoingDatagramStream::send(self, self_, datagrams).await })
603        }
604
605        fn subscribe(
606            &mut self,
607            self_: Resource<OutgoingDatagramStream>,
608        ) -> wasmtime::Result<Resource<Pollable>> {
609            AsyncHostOutgoingDatagramStream::subscribe(self, self_)
610        }
611
612        fn drop(&mut self, rep: Resource<OutgoingDatagramStream>) -> wasmtime::Result<()> {
613            AsyncHostOutgoingDatagramStream::drop(self, rep)
614        }
615    }
616
617    impl From<OutgoingDatagram> for async_udp::OutgoingDatagram {
618        fn from(other: OutgoingDatagram) -> Self {
619            let OutgoingDatagram {
620                data,
621                remote_address,
622            } = other;
623            Self {
624                data,
625                remote_address,
626            }
627        }
628    }
629}