Skip to main content

wasmtime_wasi/p2/host/
udp.rs

1use crate::p2::bindings::sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network};
2use crate::p2::bindings::sockets::udp;
3use crate::p2::udp::{IncomingDatagramStream, OutgoingDatagramStream};
4use crate::p2::{Pollable, SocketError, SocketResult};
5use crate::sockets::util::{is_valid_address_family, is_valid_remote_address};
6use crate::sockets::{
7    MAX_UDP_DATAGRAM_SIZE, SocketAddrUse, SocketAddressFamily, UdpSocket, WasiSocketsCtxView,
8};
9use async_trait::async_trait;
10use std::net::SocketAddr;
11use std::pin::pin;
12use std::task::{Context, Poll, Waker};
13use tokio::io::Interest;
14use wasmtime::component::Resource;
15use wasmtime::format_err;
16use wasmtime_wasi_io::poll::DynPollable;
17
18impl udp::Host for WasiSocketsCtxView<'_> {}
19
20impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
21    async fn start_bind(
22        &mut self,
23        this: Resource<udp::UdpSocket>,
24        network: Resource<Network>,
25        local_address: IpSocketAddress,
26    ) -> SocketResult<()> {
27        let local_address = SocketAddr::from(local_address);
28        let check = self.table.get(&network)?.socket_addr_check.clone();
29        check.check(local_address, SocketAddrUse::UdpBind).await?;
30
31        let socket = self.table.get_mut(&this)?;
32        socket.bind(local_address)?;
33        socket.set_socket_addr_check(Some(check));
34
35        Ok(())
36    }
37
38    fn finish_bind(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<()> {
39        self.table.get_mut(&this)?.finish_bind()?;
40        Ok(())
41    }
42
43    async fn stream(
44        &mut self,
45        this: Resource<udp::UdpSocket>,
46        remote_address: Option<IpSocketAddress>,
47    ) -> SocketResult<(
48        Resource<udp::IncomingDatagramStream>,
49        Resource<udp::OutgoingDatagramStream>,
50    )> {
51        let has_active_streams = self
52            .table
53            .iter_children(&this)?
54            .any(|c| c.is::<IncomingDatagramStream>() || c.is::<OutgoingDatagramStream>());
55
56        if has_active_streams {
57            return Err(SocketError::trap(format_err!(
58                "UDP streams not dropped yet"
59            )));
60        }
61
62        let socket = self.table.get_mut(&this)?;
63        let remote_address = remote_address.map(SocketAddr::from);
64
65        if !socket.is_bound() {
66            return Err(ErrorCode::InvalidState.into());
67        }
68
69        if let Some(connect_addr) = remote_address {
70            let Some(check) = socket.socket_addr_check() else {
71                return Err(ErrorCode::InvalidState.into());
72            };
73            check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
74            socket.connect_p2(connect_addr)?;
75        } else if socket.is_connected() {
76            socket.disconnect()?;
77        }
78
79        let incoming_stream = IncomingDatagramStream {
80            inner: socket.socket().clone(),
81            remote_address,
82        };
83        let outgoing_stream = OutgoingDatagramStream {
84            inner: socket.socket().clone(),
85            remote_address,
86            family: socket.address_family(),
87            socket_addr_check: socket.socket_addr_check().cloned(),
88            check_send_permit_count: 0,
89        };
90
91        Ok((
92            self.table.push_child(incoming_stream, &this)?,
93            self.table.push_child(outgoing_stream, &this)?,
94        ))
95    }
96
97    fn local_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
98        let socket = self.table.get(&this)?;
99        Ok(socket.local_address()?.into())
100    }
101
102    fn remote_address(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<IpSocketAddress> {
103        let socket = self.table.get(&this)?;
104        Ok(socket.remote_address()?.into())
105    }
106
107    fn address_family(
108        &mut self,
109        this: Resource<udp::UdpSocket>,
110    ) -> Result<IpAddressFamily, wasmtime::Error> {
111        let socket = self.table.get(&this)?;
112        Ok(socket.address_family().into())
113    }
114
115    fn unicast_hop_limit(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u8> {
116        let socket = self.table.get(&this)?;
117        Ok(socket.unicast_hop_limit()?)
118    }
119
120    fn set_unicast_hop_limit(
121        &mut self,
122        this: Resource<udp::UdpSocket>,
123        value: u8,
124    ) -> SocketResult<()> {
125        let socket = self.table.get(&this)?;
126        socket.set_unicast_hop_limit(value)?;
127        Ok(())
128    }
129
130    fn receive_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
131        let socket = self.table.get(&this)?;
132        Ok(socket.receive_buffer_size()?)
133    }
134
135    fn set_receive_buffer_size(
136        &mut self,
137        this: Resource<udp::UdpSocket>,
138        value: u64,
139    ) -> SocketResult<()> {
140        let socket = self.table.get(&this)?;
141        socket.set_receive_buffer_size(value)?;
142        Ok(())
143    }
144
145    fn send_buffer_size(&mut self, this: Resource<udp::UdpSocket>) -> SocketResult<u64> {
146        let socket = self.table.get(&this)?;
147        Ok(socket.send_buffer_size()?)
148    }
149
150    fn set_send_buffer_size(&mut self, this: Resource<UdpSocket>, value: u64) -> SocketResult<()> {
151        let socket = self.table.get(&this)?;
152        socket.set_send_buffer_size(value)?;
153        Ok(())
154    }
155
156    fn subscribe(&mut self, this: Resource<UdpSocket>) -> wasmtime::Result<Resource<DynPollable>> {
157        wasmtime_wasi_io::poll::subscribe(self.table, this)
158    }
159
160    fn drop(&mut self, this: Resource<udp::UdpSocket>) -> Result<(), wasmtime::Error> {
161        // As in the filesystem implementation, we assume closing a socket
162        // doesn't block.
163        let dropped = self.table.delete(this)?;
164        drop(dropped);
165
166        Ok(())
167    }
168}
169
170#[async_trait]
171impl Pollable for UdpSocket {
172    async fn ready(&mut self) {
173        // None of the socket-level operations block natively
174    }
175}
176
177impl udp::HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
178    fn receive(
179        &mut self,
180        this: Resource<udp::IncomingDatagramStream>,
181        max_results: u64,
182    ) -> SocketResult<Vec<udp::IncomingDatagram>> {
183        // Returns Ok(None) when the message was dropped.
184        fn recv_one(
185            stream: &IncomingDatagramStream,
186        ) -> SocketResult<Option<udp::IncomingDatagram>> {
187            let mut buf = [0; MAX_UDP_DATAGRAM_SIZE];
188            let (size, received_addr) = stream.inner.try_recv_from(&mut buf)?;
189            debug_assert!(size <= buf.len());
190
191            match stream.remote_address {
192                Some(connected_addr) if connected_addr != received_addr => {
193                    // Normally, this should have already been checked for us by the OS.
194                    return Ok(None);
195                }
196                _ => {}
197            }
198
199            Ok(Some(udp::IncomingDatagram {
200                data: buf[..size].into(),
201                remote_address: received_addr.into(),
202            }))
203        }
204
205        let stream = self.table.get(&this)?;
206        let max_results: usize = max_results.try_into().unwrap_or(usize::MAX);
207
208        if max_results == 0 {
209            return Ok(vec![]);
210        }
211
212        let mut datagrams = vec![];
213        let mut sum = 0;
214
215        while datagrams.len() < max_results && sum < crate::MAX_READ_SIZE_ALLOC {
216            match recv_one(stream) {
217                Ok(Some(datagram)) => {
218                    sum += 1 + datagram.data.len();
219                    datagrams.push(datagram);
220                }
221                Ok(None) => {
222                    // Message was dropped
223                }
224                Err(_) if datagrams.len() > 0 => {
225                    return Ok(datagrams);
226                }
227                Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
228                    return Ok(datagrams);
229                }
230                Err(e) => {
231                    return Err(e);
232                }
233            }
234        }
235
236        Ok(datagrams)
237    }
238
239    fn subscribe(
240        &mut self,
241        this: Resource<udp::IncomingDatagramStream>,
242    ) -> wasmtime::Result<Resource<DynPollable>> {
243        wasmtime_wasi_io::poll::subscribe(self.table, this)
244    }
245
246    fn drop(&mut self, this: Resource<udp::IncomingDatagramStream>) -> Result<(), wasmtime::Error> {
247        // As in the filesystem implementation, we assume closing a socket
248        // doesn't block.
249        let dropped = self.table.delete(this)?;
250        drop(dropped);
251
252        Ok(())
253    }
254}
255
256#[async_trait]
257impl Pollable for IncomingDatagramStream {
258    async fn ready(&mut self) {
259        self.inner
260            .ready(Interest::READABLE.add(Interest::ERROR))
261            .await
262            .expect("failed to await UDP socket readiness");
263    }
264}
265
266impl udp::HostOutgoingDatagramStream for WasiSocketsCtxView<'_> {
267    fn check_send(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> SocketResult<u64> {
268        let stream = self.table.get_mut(&this)?;
269
270        let count = if let Poll::Ready(_) =
271            pin!(stream.inner.ready(Interest::WRITABLE.add(Interest::ERROR)))
272                .poll(&mut Context::from_waker(Waker::noop()))
273        {
274            // We don't know how many Tokio will accept, so we make up a
275            // reasonable number here.  If we're wrong and `send` returns
276            // `Ok(0)`, the guest will just have to deal with that, e.g. by
277            // looping or returning `EWOULDBLOCK`.
278            16
279        } else {
280            0
281        };
282
283        stream.check_send_permit_count = count;
284
285        Ok(count.try_into().unwrap())
286    }
287
288    async fn send(
289        &mut self,
290        this: Resource<udp::OutgoingDatagramStream>,
291        datagrams: Vec<udp::OutgoingDatagram>,
292    ) -> SocketResult<u64> {
293        async fn send_one(
294            stream: &OutgoingDatagramStream,
295            datagram: &udp::OutgoingDatagram,
296        ) -> SocketResult<()> {
297            if datagram.data.len() > MAX_UDP_DATAGRAM_SIZE {
298                return Err(ErrorCode::DatagramTooLarge.into());
299            }
300
301            let provided_addr = datagram.remote_address.map(SocketAddr::from);
302            let addr = match (stream.remote_address, provided_addr) {
303                (None, Some(addr)) => {
304                    let Some(check) = stream.socket_addr_check.as_ref() else {
305                        return Err(ErrorCode::InvalidState.into());
306                    };
307                    check
308                        .check(addr, SocketAddrUse::UdpOutgoingDatagram)
309                        .await?;
310                    addr
311                }
312                (Some(addr), None) => addr,
313                (Some(connected_addr), Some(provided_addr)) if connected_addr == provided_addr => {
314                    connected_addr
315                }
316                _ => return Err(ErrorCode::InvalidArgument.into()),
317            };
318
319            if !is_valid_remote_address(addr) || !is_valid_address_family(addr.ip(), stream.family)
320            {
321                return Err(ErrorCode::InvalidArgument.into());
322            }
323
324            if stream.remote_address == Some(addr) {
325                stream.inner.try_send(&datagram.data)?;
326            } else {
327                stream.inner.try_send_to(&datagram.data, addr)?;
328            }
329
330            Ok(())
331        }
332
333        let stream = self.table.get_mut(&this)?;
334
335        if datagrams.is_empty() {
336            return Ok(0);
337        }
338
339        if datagrams.len() > stream.check_send_permit_count {
340            return Err(SocketError::trap(wasmtime::format_err!(
341                "unpermitted: argument exceeds permitted size"
342            )));
343        }
344
345        stream.check_send_permit_count -= datagrams.len();
346
347        let mut count = 0;
348
349        for datagram in datagrams {
350            match send_one(stream, &datagram).await {
351                Ok(_) => count += 1,
352                Err(_) if count > 0 => {
353                    // WIT: "If at least one datagram has been sent successfully, this function never returns an error."
354                    return Ok(count);
355                }
356                Err(e) if matches!(e.downcast_ref(), Some(ErrorCode::WouldBlock)) => {
357                    debug_assert!(count == 0);
358                    return Ok(0);
359                }
360                Err(e) => {
361                    return Err(e);
362                }
363            }
364        }
365
366        Ok(count)
367    }
368
369    fn subscribe(
370        &mut self,
371        this: Resource<udp::OutgoingDatagramStream>,
372    ) -> wasmtime::Result<Resource<DynPollable>> {
373        wasmtime_wasi_io::poll::subscribe(self.table, this)
374    }
375
376    fn drop(&mut self, this: Resource<udp::OutgoingDatagramStream>) -> Result<(), wasmtime::Error> {
377        // As in the filesystem implementation, we assume closing a socket
378        // doesn't block.
379        let dropped = self.table.delete(this)?;
380        drop(dropped);
381
382        Ok(())
383    }
384}
385
386#[async_trait]
387impl Pollable for OutgoingDatagramStream {
388    async fn ready(&mut self) {
389        self.inner
390            .ready(Interest::WRITABLE.add(Interest::ERROR))
391            .await
392            .expect("failed to await UDP socket readiness");
393    }
394}
395
396impl From<SocketAddressFamily> for IpAddressFamily {
397    fn from(family: SocketAddressFamily) -> IpAddressFamily {
398        match family {
399            SocketAddressFamily::Ipv4 => IpAddressFamily::Ipv4,
400            SocketAddressFamily::Ipv6 => IpAddressFamily::Ipv6,
401        }
402    }
403}
404
405pub mod sync {
406    use wasmtime::component::Resource;
407
408    use crate::p2::{
409        SocketError,
410        bindings::{
411            sockets::{
412                network::Network,
413                udp::{
414                    self as async_udp,
415                    HostIncomingDatagramStream as AsyncHostIncomingDatagramStream,
416                    HostOutgoingDatagramStream as AsyncHostOutgoingDatagramStream,
417                    HostUdpSocket as AsyncHostUdpSocket, IncomingDatagramStream,
418                    OutgoingDatagramStream,
419                },
420            },
421            sync::sockets::udp::{
422                self, HostIncomingDatagramStream, HostOutgoingDatagramStream, HostUdpSocket,
423                IncomingDatagram, IpAddressFamily, IpSocketAddress, OutgoingDatagram, Pollable,
424                UdpSocket,
425            },
426        },
427    };
428    use crate::runtime::in_tokio;
429    use crate::sockets::WasiSocketsCtxView;
430
431    impl udp::Host for WasiSocketsCtxView<'_> {}
432
433    impl HostUdpSocket for WasiSocketsCtxView<'_> {
434        fn start_bind(
435            &mut self,
436            self_: Resource<UdpSocket>,
437            network: Resource<Network>,
438            local_address: IpSocketAddress,
439        ) -> Result<(), SocketError> {
440            in_tokio(async {
441                AsyncHostUdpSocket::start_bind(self, self_, network, local_address).await
442            })
443        }
444
445        fn finish_bind(&mut self, self_: Resource<UdpSocket>) -> Result<(), SocketError> {
446            AsyncHostUdpSocket::finish_bind(self, self_)
447        }
448
449        fn stream(
450            &mut self,
451            self_: Resource<UdpSocket>,
452            remote_address: Option<IpSocketAddress>,
453        ) -> Result<
454            (
455                Resource<IncomingDatagramStream>,
456                Resource<OutgoingDatagramStream>,
457            ),
458            SocketError,
459        > {
460            in_tokio(async { AsyncHostUdpSocket::stream(self, self_, remote_address).await })
461        }
462
463        fn local_address(
464            &mut self,
465            self_: Resource<UdpSocket>,
466        ) -> Result<IpSocketAddress, SocketError> {
467            AsyncHostUdpSocket::local_address(self, self_)
468        }
469
470        fn remote_address(
471            &mut self,
472            self_: Resource<UdpSocket>,
473        ) -> Result<IpSocketAddress, SocketError> {
474            AsyncHostUdpSocket::remote_address(self, self_)
475        }
476
477        fn address_family(
478            &mut self,
479            self_: Resource<UdpSocket>,
480        ) -> wasmtime::Result<IpAddressFamily> {
481            AsyncHostUdpSocket::address_family(self, self_)
482        }
483
484        fn unicast_hop_limit(&mut self, self_: Resource<UdpSocket>) -> Result<u8, SocketError> {
485            AsyncHostUdpSocket::unicast_hop_limit(self, self_)
486        }
487
488        fn set_unicast_hop_limit(
489            &mut self,
490            self_: Resource<UdpSocket>,
491            value: u8,
492        ) -> Result<(), SocketError> {
493            AsyncHostUdpSocket::set_unicast_hop_limit(self, self_, value)
494        }
495
496        fn receive_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
497            AsyncHostUdpSocket::receive_buffer_size(self, self_)
498        }
499
500        fn set_receive_buffer_size(
501            &mut self,
502            self_: Resource<UdpSocket>,
503            value: u64,
504        ) -> Result<(), SocketError> {
505            AsyncHostUdpSocket::set_receive_buffer_size(self, self_, value)
506        }
507
508        fn send_buffer_size(&mut self, self_: Resource<UdpSocket>) -> Result<u64, SocketError> {
509            AsyncHostUdpSocket::send_buffer_size(self, self_)
510        }
511
512        fn set_send_buffer_size(
513            &mut self,
514            self_: Resource<UdpSocket>,
515            value: u64,
516        ) -> Result<(), SocketError> {
517            AsyncHostUdpSocket::set_send_buffer_size(self, self_, value)
518        }
519
520        fn subscribe(
521            &mut self,
522            self_: Resource<UdpSocket>,
523        ) -> wasmtime::Result<Resource<Pollable>> {
524            AsyncHostUdpSocket::subscribe(self, self_)
525        }
526
527        fn drop(&mut self, rep: Resource<UdpSocket>) -> wasmtime::Result<()> {
528            AsyncHostUdpSocket::drop(self, rep)
529        }
530    }
531
532    impl HostIncomingDatagramStream for WasiSocketsCtxView<'_> {
533        fn receive(
534            &mut self,
535            self_: Resource<IncomingDatagramStream>,
536            max_results: u64,
537        ) -> Result<Vec<IncomingDatagram>, SocketError> {
538            Ok(
539                AsyncHostIncomingDatagramStream::receive(self, self_, max_results)?
540                    .into_iter()
541                    .map(Into::into)
542                    .collect(),
543            )
544        }
545
546        fn subscribe(
547            &mut self,
548            self_: Resource<IncomingDatagramStream>,
549        ) -> wasmtime::Result<Resource<Pollable>> {
550            AsyncHostIncomingDatagramStream::subscribe(self, self_)
551        }
552
553        fn drop(&mut self, rep: Resource<IncomingDatagramStream>) -> wasmtime::Result<()> {
554            AsyncHostIncomingDatagramStream::drop(self, rep)
555        }
556    }
557
558    impl From<async_udp::IncomingDatagram> for IncomingDatagram {
559        fn from(other: async_udp::IncomingDatagram) -> Self {
560            let async_udp::IncomingDatagram {
561                data,
562                remote_address,
563            } = other;
564            Self {
565                data,
566                remote_address,
567            }
568        }
569    }
570
571    impl HostOutgoingDatagramStream for WasiSocketsCtxView<'_> {
572        fn check_send(
573            &mut self,
574            self_: Resource<OutgoingDatagramStream>,
575        ) -> Result<u64, SocketError> {
576            AsyncHostOutgoingDatagramStream::check_send(self, self_)
577        }
578
579        fn send(
580            &mut self,
581            self_: Resource<OutgoingDatagramStream>,
582            datagrams: Vec<OutgoingDatagram>,
583        ) -> Result<u64, SocketError> {
584            let datagrams = datagrams.into_iter().map(Into::into).collect();
585            in_tokio(async { AsyncHostOutgoingDatagramStream::send(self, self_, datagrams).await })
586        }
587
588        fn subscribe(
589            &mut self,
590            self_: Resource<OutgoingDatagramStream>,
591        ) -> wasmtime::Result<Resource<Pollable>> {
592            AsyncHostOutgoingDatagramStream::subscribe(self, self_)
593        }
594
595        fn drop(&mut self, rep: Resource<OutgoingDatagramStream>) -> wasmtime::Result<()> {
596            AsyncHostOutgoingDatagramStream::drop(self, rep)
597        }
598    }
599
600    impl From<OutgoingDatagram> for async_udp::OutgoingDatagram {
601        fn from(other: OutgoingDatagram) -> Self {
602            let OutgoingDatagram {
603                data,
604                remote_address,
605            } = other;
606            Self {
607                data,
608                remote_address,
609            }
610        }
611    }
612}