1use crate::{
2 Error, ErrorExt,
3 file::{FdFlags, FileType, RiFlags, RoFlags, SdFlags, SiFlags, WasiFile},
4};
5#[cfg(windows)]
6use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket};
7use io_lifetimes::AsSocketlike;
8#[cfg(unix)]
9use io_lifetimes::{AsFd, BorrowedFd};
10#[cfg(windows)]
11use io_lifetimes::{AsSocket, BorrowedSocket};
12use std::any::Any;
13use std::io;
14#[cfg(unix)]
15use system_interface::fs::GetSetFdFlags;
16use system_interface::io::IoExt;
17use system_interface::io::IsReadWrite;
18use system_interface::io::ReadReady;
19
/// A host socket of any kind handled by this module, before it is wrapped
/// as a WASI file via the `From<Socket> for Box<dyn WasiFile>` impls.
pub enum Socket {
    /// A listening TCP socket.
    TcpListener(cap_std::net::TcpListener),
    /// A connected TCP stream.
    TcpStream(cap_std::net::TcpStream),
    /// A connected Unix-domain stream socket (Unix only).
    #[cfg(unix)]
    UnixStream(cap_std::os::unix::net::UnixStream),
    /// A listening Unix-domain socket (Unix only).
    #[cfg(unix)]
    UnixListener(cap_std::os::unix::net::UnixListener),
}
28
29impl From<cap_std::net::TcpListener> for Socket {
30 fn from(listener: cap_std::net::TcpListener) -> Self {
31 Self::TcpListener(listener)
32 }
33}
34
35impl From<cap_std::net::TcpStream> for Socket {
36 fn from(stream: cap_std::net::TcpStream) -> Self {
37 Self::TcpStream(stream)
38 }
39}
40
41#[cfg(unix)]
42impl From<cap_std::os::unix::net::UnixListener> for Socket {
43 fn from(listener: cap_std::os::unix::net::UnixListener) -> Self {
44 Self::UnixListener(listener)
45 }
46}
47
48#[cfg(unix)]
49impl From<cap_std::os::unix::net::UnixStream> for Socket {
50 fn from(stream: cap_std::os::unix::net::UnixStream) -> Self {
51 Self::UnixStream(stream)
52 }
53}
54
55#[cfg(unix)]
56impl From<Socket> for Box<dyn WasiFile> {
57 fn from(listener: Socket) -> Self {
58 match listener {
59 Socket::TcpListener(l) => Box::new(crate::sync::net::TcpListener::from_cap_std(l)),
60 Socket::UnixListener(l) => Box::new(crate::sync::net::UnixListener::from_cap_std(l)),
61 Socket::TcpStream(l) => Box::new(crate::sync::net::TcpStream::from_cap_std(l)),
62 Socket::UnixStream(l) => Box::new(crate::sync::net::UnixStream::from_cap_std(l)),
63 }
64 }
65}
66
#[cfg(windows)]
impl From<Socket> for Box<dyn WasiFile> {
    /// Wraps each [`Socket`] variant (only TCP variants exist on Windows)
    /// in the matching `crate::sync::net` WASI file type.
    fn from(socket: Socket) -> Self {
        use crate::sync::net;
        match socket {
            Socket::TcpStream(inner) => Box::new(net::TcpStream::from_cap_std(inner)),
            Socket::TcpListener(inner) => Box::new(net::TcpListener::from_cap_std(inner)),
        }
    }
}
76
/// Implements `WasiFile` plus the platform handle traits for a newtype
/// wrapping a listening socket.
///
/// * `$ty` is the wrapper type; its field `.0` holds the `cap_std` listener.
/// * `$stream` is the wrapper type used for connections produced by
///   `sock_accept`; it must provide `from_cap_std` and implement `WasiFile`.
macro_rules! wasi_listen_write_impl {
    ($ty:ty, $stream:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }

            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                let fd = self.0.as_fd();
                Some(fd)
            }

            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                let raw = self.0.as_raw_handle_or_socket();
                Some(raw)
            }

            /// Accepts one pending connection, wraps it as `$stream`, and
            /// applies the requested `fdflags` before handing it out.
            async fn sock_accept(&self, fdflags: FdFlags) -> Result<Box<dyn WasiFile>, Error> {
                // The peer address returned by `accept` is not used.
                let (connection, _peer) = self.0.accept()?;
                let mut accepted = <$stream>::from_cap_std(connection);
                accepted.set_fdflags(fdflags).await?;
                Ok(Box::new(accepted))
            }

            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }

            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                Ok(get_fd_flags(&self.0)?)
            }

            /// Accepts exactly `NONBLOCK` (switch to non-blocking) or no
            /// flags (switch to blocking); anything else is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                let nonblocking = if fdflags == crate::file::FdFlags::NONBLOCK {
                    true
                } else if fdflags.is_empty() {
                    false
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                };
                self.0.set_nonblocking(nonblocking)?;
                Ok(())
            }

            /// A listener carries no byte stream; it always reports 1.
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                Ok(1)
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }

        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }
    };
}
147
148pub struct TcpListener(cap_std::net::TcpListener);
149
150impl TcpListener {
151 pub fn from_cap_std(cap_std: cap_std::net::TcpListener) -> Self {
152 TcpListener(cap_std)
153 }
154}
155wasi_listen_write_impl!(TcpListener, TcpStream);
156
157#[cfg(unix)]
158pub struct UnixListener(cap_std::os::unix::net::UnixListener);
159
160#[cfg(unix)]
161impl UnixListener {
162 pub fn from_cap_std(cap_std: cap_std::os::unix::net::UnixListener) -> Self {
163 UnixListener(cap_std)
164 }
165}
166
167#[cfg(unix)]
168wasi_listen_write_impl!(UnixListener, UnixStream);
169
/// Implements `WasiFile` plus the platform handle traits for a newtype
/// wrapping a connected stream socket.
///
/// * `$ty` is the wrapper type; its field `.0` holds the `cap_std` stream.
/// * `$std_ty` is the matching `std` stream type, used to borrow a
///   socketlike view for `Read`/`Write` and `num_ready_bytes`.
macro_rules! wasi_stream_write_impl {
    ($ty:ty, $std_ty:ty) => {
        #[wiggle::async_trait]
        impl WasiFile for $ty {
            fn as_any(&self) -> &dyn Any {
                self
            }
            #[cfg(unix)]
            fn pollable(&self) -> Option<rustix::fd::BorrowedFd> {
                Some(self.0.as_fd())
            }
            #[cfg(windows)]
            fn pollable(&self) -> Option<io_extras::os::windows::RawHandleOrSocket> {
                Some(self.0.as_raw_handle_or_socket())
            }
            async fn get_filetype(&self) -> Result<FileType, Error> {
                Ok(FileType::SocketStream)
            }
            #[cfg(unix)]
            async fn get_fdflags(&self) -> Result<FdFlags, Error> {
                let fdflags = get_fd_flags(&self.0)?;
                Ok(fdflags)
            }
            /// Accepts exactly `NONBLOCK` (switch to non-blocking) or no
            /// flags (switch to blocking); anything else is rejected.
            async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> {
                if fdflags == crate::file::FdFlags::NONBLOCK {
                    self.0.set_nonblocking(true)?;
                } else if fdflags.is_empty() {
                    self.0.set_nonblocking(false)?;
                } else {
                    return Err(
                        Error::invalid_argument().context("cannot set anything else than NONBLOCK")
                    );
                }
                Ok(())
            }
            async fn read_vectored<'a>(
                &self,
                bufs: &mut [io::IoSliceMut<'a>],
            ) -> Result<u64, Error> {
                use std::io::Read;
                // Read through a shared `&$std_ty` view so `&self` suffices.
                let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result<u64, Error> {
                use std::io::Write;
                // Write through a shared `&$std_ty` view so `&self` suffices.
                let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?;
                Ok(n.try_into()?)
            }
            async fn peek(&self, buf: &mut [u8]) -> Result<u64, Error> {
                let n = self.0.peek(buf)?;
                Ok(n.try_into()?)
            }
            fn num_ready_bytes(&self) -> Result<u64, Error> {
                let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?;
                Ok(val)
            }
            /// Succeeds if the socket reports itself readable, otherwise
            /// fails with `Error::io()`.
            async fn readable(&self) -> Result<(), Error> {
                let (readable, _writeable) = is_read_write(&self.0)?;
                if readable { Ok(()) } else { Err(Error::io()) }
            }
            /// Succeeds if the socket reports itself writable, otherwise
            /// fails with `Error::io()`.
            async fn writable(&self) -> Result<(), Error> {
                let (_readable, writeable) = is_read_write(&self.0)?;
                if writeable { Ok(()) } else { Err(Error::io()) }
            }

            /// Receives data into `ri_data`.
            ///
            /// Only `RECV_PEEK` and `RECV_WAITALL` are supported; any other
            /// flag yields `not_supported`. The returned `RoFlags` are
            /// always empty.
            async fn sock_recv<'a>(
                &self,
                ri_data: &mut [std::io::IoSliceMut<'a>],
                ri_flags: RiFlags,
            ) -> Result<(u64, RoFlags), Error> {
                if (ri_flags & !(RiFlags::RECV_PEEK | RiFlags::RECV_WAITALL)) != RiFlags::empty() {
                    return Err(Error::not_supported());
                }

                if ri_flags.contains(RiFlags::RECV_PEEK) {
                    // Peek into the first *non-empty* buffer, mirroring std's
                    // default `read_vectored`. Peeking into a zero-length
                    // leading buffer would report 0 bytes, which recv callers
                    // conventionally treat as end-of-stream even when data is
                    // pending.
                    return match ri_data.iter_mut().find(|buf| !buf.is_empty()) {
                        Some(buf) => {
                            let n = self.0.peek(buf)?;
                            Ok((n as u64, RoFlags::empty()))
                        }
                        None => Ok((0, RoFlags::empty())),
                    };
                }

                if ri_flags.contains(RiFlags::RECV_WAITALL) {
                    // Fill every buffer completely; on success the byte count
                    // is their combined length. A short read surfaces as an
                    // error from `read_exact_vectored`.
                    let n: usize = ri_data.iter().map(|buf| buf.len()).sum();
                    self.0.read_exact_vectored(ri_data)?;
                    return Ok((n as u64, RoFlags::empty()));
                }

                let n = self.0.read_vectored(ri_data)?;
                Ok((n as u64, RoFlags::empty()))
            }

            /// Sends `si_data`; no `SiFlags` are supported.
            async fn sock_send<'a>(
                &self,
                si_data: &[std::io::IoSlice<'a>],
                si_flags: SiFlags,
            ) -> Result<u64, Error> {
                if si_flags != SiFlags::empty() {
                    return Err(Error::not_supported());
                }

                let n = self.0.write_vectored(si_data)?;
                Ok(n as u64)
            }

            /// Shuts down the read half, the write half, or both.
            ///
            /// `RD`, `WR`, and `RD | WR` are accepted. Any other value,
            /// including no flags, is `invalid_argument`.
            async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> {
                let how = if how == SdFlags::RD | SdFlags::WR {
                    cap_std::net::Shutdown::Both
                } else if how == SdFlags::RD {
                    cap_std::net::Shutdown::Read
                } else if how == SdFlags::WR {
                    cap_std::net::Shutdown::Write
                } else {
                    return Err(Error::invalid_argument());
                };
                self.0.shutdown(how)?;
                Ok(())
            }
        }
        #[cfg(unix)]
        impl AsFd for $ty {
            fn as_fd(&self) -> BorrowedFd<'_> {
                self.0.as_fd()
            }
        }

        #[cfg(windows)]
        impl AsSocket for $ty {
            #[inline]
            fn as_socket(&self) -> BorrowedSocket<'_> {
                self.0.as_socket()
            }
        }

        // Implemented for `$ty` (not a hard-coded `TcpStream`) so every
        // instantiation of this macro gets its own impl.
        #[cfg(windows)]
        impl AsRawHandleOrSocket for $ty {
            #[inline]
            fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket {
                self.0.as_raw_handle_or_socket()
            }
        }
    };
}
314
315pub struct TcpStream(cap_std::net::TcpStream);
316
317impl TcpStream {
318 pub fn from_cap_std(socket: cap_std::net::TcpStream) -> Self {
319 TcpStream(socket)
320 }
321}
322
323wasi_stream_write_impl!(TcpStream, std::net::TcpStream);
324
325#[cfg(unix)]
326pub struct UnixStream(cap_std::os::unix::net::UnixStream);
327
328#[cfg(unix)]
329impl UnixStream {
330 pub fn from_cap_std(socket: cap_std::os::unix::net::UnixStream) -> Self {
331 UnixStream(socket)
332 }
333}
334
335#[cfg(unix)]
336wasi_stream_write_impl!(UnixStream, std::os::unix::net::UnixStream);
337
338pub fn filetype_from(ft: &cap_std::fs::FileType) -> FileType {
339 use cap_fs_ext::FileTypeExt;
340 if ft.is_block_device() {
341 FileType::SocketDgram
342 } else {
343 FileType::SocketStream
344 }
345}
346
/// Queries the WASI fd flags of a socket.
///
/// Only `NONBLOCK` is ever reported; every other flag is left clear.
///
/// # Errors
///
/// Returns the underlying I/O error if the query (or, on Windows, the probe)
/// fails.
pub fn get_fd_flags<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<crate::file::FdFlags> {
    // Non-Windows: read the descriptor's flags via `GetSetFdFlags` and
    // translate only the `NONBLOCK` bit into the WASI flag set.
    #[cfg(not(windows))]
    {
        let mut out = crate::file::FdFlags::empty();
        if f.get_fd_flags()?
            .contains(system_interface::fs::FdFlags::NONBLOCK)
        {
            out |= crate::file::FdFlags::NONBLOCK;
        }
        Ok(out)
    }

    // Windows: the mode is probed rather than queried (presumably because no
    // direct getter is available). A zero-length `recv` that fails with
    // `WOULDBLOCK` implies non-blocking mode; success implies blocking mode.
    // The buffer is empty, so no payload bytes are consumed.
    // NOTE(review): confirm a zero-length `recv` cannot itself block on a
    // blocking socket.
    #[cfg(windows)]
    let buf: &mut [u8] = &mut [];
    #[cfg(windows)]
    match rustix::net::recv(f, buf, rustix::net::RecvFlags::empty()) {
        Ok(_) => Ok(crate::file::FdFlags::empty()),
        Err(rustix::io::Errno::WOULDBLOCK) => Ok(crate::file::FdFlags::NONBLOCK),
        Err(e) => Err(e.into()),
    }
}
376
/// Reports whether a socket is currently readable and writable, as a
/// `(readable, writable)` pair.
///
/// # Errors
///
/// Returns the underlying I/O error if the readiness query fails.
pub fn is_read_write<Socketlike: AsSocketlike>(f: Socketlike) -> io::Result<(bool, bool)> {
    // Non-Windows: `IsReadWrite` is usable directly on the socket-like value.
    #[cfg(not(windows))]
    {
        f.is_read_write()
    }

    // Windows: query through a borrowed `std::net::TcpStream` view.
    // NOTE(review): this assumes the socket really is a TCP stream; the only
    // stream socket kind this module defines on Windows is TCP.
    #[cfg(windows)]
    {
        f.as_socketlike_view::<std::net::TcpStream>()
            .is_read_write()
    }
}