wasi_common/sync/
net.rs

1use crate::{
2    Error, ErrorExt,
3    file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
4};
5#[cfg(windows)]
6use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
7use io_lifetimes::AsSocketlike;
8#[cfg(unix)]
9use io_lifetimes::{AsFd, BorrowedFd};
10#[cfg(windows)]
11use io_lifetimes::{AsSocket, BorrowedSocket};
12use std::any::Any;
13use std::io;
14#[cfg(unix)]
15use system_interface::fs::GetSetFdFlags;
16use system_interface::io::IoExt;
17use system_interface::io::IsReadWrite;
18use system_interface::io::ReadReady;
19
/// A host socket that can be handed to a WASI guest.
///
/// Each variant wraps the corresponding `cap_std` socket type. The
/// Unix-domain variants are only compiled on Unix targets.
pub enum Socket {
    /// Wraps a `cap_std` TCP listener.
    TcpListener(cap_std::net::TcpListener),
    /// Wraps a `cap_std` TCP stream.
    TcpStream(cap_std::net::TcpStream),
    /// Wraps a `cap_std` Unix-domain stream (Unix only).
    #[cfg(unix)]
    UnixStream(cap_std::os::unix::net::UnixStream),
    /// Wraps a `cap_std` Unix-domain listener (Unix only).
    #[cfg(unix)]
    UnixListener(cap_std::os::unix::net::UnixListener),
}
28
29impl From<cap_std::net::TcpListener> for Socket {
30    fn from(listener: cap_std::net::TcpListener) -> Self {
31        Self::TcpListener(listener)
32    }
33}
34
35impl From<cap_std::net::TcpStream> for Socket {
36    fn from(stream: cap_std::net::TcpStream) -> Self {
37        Self::TcpStream(stream)
38    }
39}
40
41#[cfg(unix)]
42impl From<cap_std::os::unix::net::UnixListener> for Socket {
43    fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
44        Self::UnixListener(listener)
45    }
46}
47
48#[cfg(unix)]
49impl From<cap_std::os::unix::net::UnixStream> for Socket {
50    fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
51        Self::UnixStream(stream)
52    }
53}
54
55#[cfg(unix)]
56impl From<Socket> for Box<dyn WasiFile> {
57    fn from(listener: Socket) -> Self {
58        match listener {
59            Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
60            Socket::UnixListener(l) => Box::new(crate::sync::net::UnixListener::from_cap_std(l)),
61            Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
62            Socket::UnixStream(l) => Box::new(crate::sync::net::UnixStream::from_cap_std(l)),
63        }
64    }
65}
66
67#[cfg(windows)]
68impl From<Socket> for Box<dyn WasiFile> {
69    fn from(listener: Socket) -> Self {
70        match listener {
71            Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
72            Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
73        }
74    }
75}
76
/// Implements `WasiFile` for a listening-socket newtype, plus the platform
/// traits for borrowing its underlying handle (`AsFd` on Unix;
/// `AsSocket` and `AsRawHandleOrSocket` on Windows).
///
/// The first argument is the newtype, whose `.0` field holds the `cap_std`
/// listener. The second is the wrapper type (with a `from_cap_std`
/// constructor) used for accepted connections.
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }

            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }

            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }

            // Accept one connection (discarding the peer address), wrap it,
            // and apply the requested flags to the new connection.
            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                let (accepted, _) = self.0.accept()?;
                let mut connection = <$stream>::from_cap_std(accepted);
                connection.set_fdflags(fdflags).await?;
                Ok(Box::new(connection))
            }

            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }

            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }

            // Only exactly `NONBLOCK` (enable) or no flags (disable) are
            // accepted; any other combination is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                let nonblocking = if fdflags == crate::file::FdFlags::NONBLOCK {
                    true
                } else if fdflags.is_empty() {
                    false
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }

            // Listeners always report one ready byte.
            // NOTE(review): presumably so pollers treat a listener as
            // readable — confirm against the poll implementation.
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                Ok(1)
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }
    };
}
147
148pub struct TcpListener(cap_std::net::TcpListener);
149
150impl TcpListener {
151    pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
152        TcpListener(cap_std)
153    }
154}
155wasi_listen_write_impl!(TcpListener, TcpStream);
156
157#[cfg(unix)]
158pub struct UnixListener(cap_std::os::unix::net::UnixListener);
159
160#[cfg(unix)]
161impl UnixListener {
162    pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
163        UnixListener(cap_std)
164    }
165}
166
167#[cfg(unix)]
168wasi_listen_write_impl!(UnixListener, UnixStream);
169
/// Implements `WasiFile` for a connected-stream newtype, plus the platform
/// traits for borrowing its underlying handle (`AsFd` on Unix;
/// `AsSocket` and `AsRawHandleOrSocket` on Windows).
///
/// The first argument is the newtype, whose `.0` field holds the `cap_std`
/// stream. The second is the `std` stream type used to borrow the socket as a
/// view for `read_vectored`, `write_vectored` and `num_ready_bytes`.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            // Only exactly `NONBLOCK` (enable) or no flags (disable) are
            // accepted; any other combination is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                if fdflags == crate::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            // `is_read_write` returns `(readable, writable)`; a `false` for the
            // requested direction is reported as an I/O error.
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable { Ok(()) } else { Err(Error::io()) }
            }
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable { Ok(()) } else { Err(Error::io()) }
            }

            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                // Only PEEK and WAITALL are supported; any other flag is rejected.
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                // PEEK takes precedence over WAITALL and only fills the first
                // buffer; with no buffers, nothing is peeked.
                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    if let Some(first) = ri_data.first_mut() {
                        let n = self.0.peek(first)?;
                        return Ok((u64::try_from(n)?, RoFlags::empty()));
                    } else {
                        return Ok((0, RoFlags::empty()));
                    }
                }

                // WAITALL relies on `read_exact_vectored` filling every buffer,
                // so the byte count reported is the total buffer length.
                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((u64::try_from(n)?, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((u64::try_from(n)?, RoFlags::empty()))
            }

            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                // No send flags are supported.
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(u64::try_from(n)?)
            }

            // Maps the WASI shutdown flags onto `cap_std::net::Shutdown`;
            // an empty flag set is an invalid argument.
            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }
        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            /// Borrows the socket.
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Implemented for the macro's type parameter (previously hard-coded
        // to `TcpStream`), matching every other impl generated here.
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
314
315pub struct TcpStream(cap_std::net::TcpStream);
316
317impl TcpStream {
318    pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
319        TcpStream(socket)
320    }
321}
322
323wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
324
325#[cfg(unix)]
326pub struct UnixStream(cap_std::os::unix::net::UnixStream);
327
328#[cfg(unix)]
329impl UnixStream {
330    pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
331        UnixStream(socket)
332    }
333}
334
335#[cfg(unix)]
336wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
337
338pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
339    use cap_fs_ext::FileTypeExt;
340    if ft.is_block_device() {
341        FileType::SocketDgram
342    } else {
343        FileType::SocketStream
344    }
345}
346
347/// Return the file-descriptor flags for a given file-like object.
348///
349/// This returns the flags needed to implement [`WasiFile::get_fdflags`].
350pub fn get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags> {
351    // On Unix-family platforms, we can use the same system call that we'd use
352    // for files on sockets here.
353    #[cfg(not(windows))]
354    {
355        let mut out = crate::file::FdFlags::empty();
356        if f.get_fd_flags()?
357            .contains(system_interface::fs::FdFlags::NONBLOCK)
358        {
359            out |= crate::file::FdFlags::NONBLOCK;
360        }
361        Ok(out)
362    }
363
364    // On Windows, sockets are different, and there is no direct way to
365    // query for the non-blocking flag. We can get a sufficient approximation
366    // by testing whether a zero-length `recv` appears to block.
367    #[cfg(windows)]
368    let buf: &mut [u8] = &mut [];
369    #[cfg(windows)]
370    match rustix::net::recv(f, buf, rustix::net::RecvFlags::empty()) {
371        Ok(_) => Ok(crate::file::FdFlags::empty()),
372        Err(rustix::io::Errno::WOULDBLOCK) => Ok(crate::file::FdFlags::NONBLOCK),
373        Err(e) => Err(e.into()),
374    }
375}
376
377/// Return the file-descriptor flags for a given file-like object.
378///
379/// This returns the flags needed to implement [`WasiFile::get_fdflags`].
380pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
381    // On Unix-family platforms, we have an `IsReadWrite` impl.
382    #[cfg(not(windows))]
383    {
384        f.is_read_write()
385    }
386
387    // On Windows, we only have a `TcpStream` impl, so make a view first.
388    #[cfg(windows)]
389    {
390        f.as_socketlike_view::<std::net::TcpStream>()
391            .is_read_write()
392    }
393}