// wasmtime_wasi/p2/tcp.rs

1use crate::net::{SocketAddressFamily, DEFAULT_TCP_BACKLOG};
2use crate::p2::bindings::sockets::tcp::ErrorCode;
3use crate::p2::host::network;
4use crate::p2::{
5    DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
6    SocketResult, StreamError,
7};
8use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
9use anyhow::Result;
10use cap_net_ext::AddressFamily;
11use futures::Future;
12use io_lifetimes::views::SocketlikeView;
13use io_lifetimes::AsSocketlike;
14use rustix::io::Errno;
15use rustix::net::sockopt;
16use std::io;
17use std::mem;
18use std::net::{Shutdown, SocketAddr};
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::Poll;
22use tokio::sync::Mutex;
23
24/// The state of a TCP socket.
25///
26/// This represents the various states a socket can be in during the
27/// activities of binding, listening, accepting, and connecting.
28enum TcpState {
29    /// The initial state for a newly-created socket.
30    Default(tokio::net::TcpSocket),
31
32    /// Binding started via `start_bind`.
33    BindStarted(tokio::net::TcpSocket),
34
35    /// Binding finished via `finish_bind`. The socket has an address but
36    /// is not yet listening for connections.
37    Bound(tokio::net::TcpSocket),
38
39    /// Listening started via `listen_start`.
40    ListenStarted(tokio::net::TcpSocket),
41
42    /// The socket is now listening and waiting for an incoming connection.
43    Listening {
44        listener: tokio::net::TcpListener,
45        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
46    },
47
48    /// An outgoing connection is started via `start_connect`.
49    Connecting(Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>),
50
51    /// An outgoing connection is ready to be established.
52    ConnectReady(io::Result<tokio::net::TcpStream>),
53
54    /// An outgoing connection has been established.
55    Connected {
56        stream: Arc<tokio::net::TcpStream>,
57
58        // WASI is single threaded, so in practice these Mutexes should never be contended:
59        reader: Arc<Mutex<TcpReader>>,
60        writer: Arc<Mutex<TcpWriter>>,
61    },
62
63    Closed,
64}
65
66impl std::fmt::Debug for TcpState {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        match self {
69            Self::Default(_) => f.debug_tuple("Default").finish(),
70            Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
71            Self::Bound(_) => f.debug_tuple("Bound").finish(),
72            Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(),
73            Self::Listening { pending_accept, .. } => f
74                .debug_struct("Listening")
75                .field("pending_accept", pending_accept)
76                .finish(),
77            Self::Connecting(_) => f.debug_tuple("Connecting").finish(),
78            Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(),
79            Self::Connected { .. } => f.debug_tuple("Connected").finish(),
80            Self::Closed => write!(f, "Closed"),
81        }
82    }
83}
84
85/// A host TCP socket, plus associated bookkeeping.
86pub struct TcpSocket {
87    /// The current state in the bind/listen/accept/connect progression.
88    tcp_state: TcpState,
89
90    /// The desired listen queue size.
91    listen_backlog_size: u32,
92
93    family: SocketAddressFamily,
94
95    // The socket options below are not automatically inherited from the listener
96    // on all platforms. So we keep track of which options have been explicitly
97    // set and manually apply those values to newly accepted clients.
98    #[cfg(target_os = "macos")]
99    receive_buffer_size: Option<usize>,
100    #[cfg(target_os = "macos")]
101    send_buffer_size: Option<usize>,
102    #[cfg(target_os = "macos")]
103    hop_limit: Option<u8>,
104    #[cfg(target_os = "macos")]
105    keep_alive_idle_time: Option<std::time::Duration>,
106}
107
108impl TcpSocket {
109    /// Create a new socket in the given family.
110    pub fn new(family: AddressFamily) -> io::Result<Self> {
111        with_ambient_tokio_runtime(|| {
112            let (socket, family) = match family {
113                AddressFamily::Ipv4 => {
114                    let socket = tokio::net::TcpSocket::new_v4()?;
115                    (socket, SocketAddressFamily::Ipv4)
116                }
117                AddressFamily::Ipv6 => {
118                    let socket = tokio::net::TcpSocket::new_v6()?;
119                    sockopt::set_ipv6_v6only(&socket, true)?;
120                    (socket, SocketAddressFamily::Ipv6)
121                }
122            };
123
124            Self::from_state(TcpState::Default(socket), family)
125        })
126    }
127
128    /// Create a `TcpSocket` from an existing socket.
129    fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
130        Ok(Self {
131            tcp_state: state,
132            listen_backlog_size: DEFAULT_TCP_BACKLOG,
133            family,
134            #[cfg(target_os = "macos")]
135            receive_buffer_size: None,
136            #[cfg(target_os = "macos")]
137            send_buffer_size: None,
138            #[cfg(target_os = "macos")]
139            hop_limit: None,
140            #[cfg(target_os = "macos")]
141            keep_alive_idle_time: None,
142        })
143    }
144
145    fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
146        use crate::p2::bindings::sockets::network::ErrorCode;
147
148        match &self.tcp_state {
149            TcpState::Default(socket) | TcpState::Bound(socket) => {
150                Ok(socket.as_socketlike_view::<std::net::TcpStream>())
151            }
152            TcpState::Connected { stream, .. } => {
153                Ok(stream.as_socketlike_view::<std::net::TcpStream>())
154            }
155            TcpState::Listening { listener, .. } => {
156                Ok(listener.as_socketlike_view::<std::net::TcpStream>())
157            }
158
159            TcpState::BindStarted(..)
160            | TcpState::ListenStarted(..)
161            | TcpState::Connecting(..)
162            | TcpState::ConnectReady(..)
163            | TcpState::Closed => Err(ErrorCode::InvalidState.into()),
164        }
165    }
166}
167
impl TcpSocket {
    /// Starts binding the socket to `local_address` (the first half of the
    /// two-phase WASI bind).
    ///
    /// The OS `bind` call happens here. On success the state moves from
    /// `Default` to `BindStarted`, and `finish_bind` completes the transition.
    /// If any step fails, `tcp_state` is left unchanged.
    ///
    /// # Errors
    /// - `EALREADY` if a bind has already been started.
    /// - `EISCONN` in any other state that is not `Default`.
    /// - Whatever `network::util::validate_unicast` or
    ///   `network::util::validate_address_family` reject.
    /// - OS errors from setting `SO_REUSEADDR` or from `bind`, after the
    ///   rewrites in the `map_err` below.
    pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> {
        // Only a fresh socket may be bound. Borrow it without changing state yet.
        let tokio_socket = match &self.tcp_state {
            TcpState::Default(socket) => socket,
            TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()),
            _ => return Err(Errno::ISCONN.into()),
        };

        network::util::validate_unicast(&local_address)?;
        network::util::validate_address_family(&local_address, &self.family)?;

        {
            // Automatically bypass the TIME_WAIT state when the user is trying
            // to bind to a specific port:
            let reuse_addr = local_address.port() > 0;

            // Unconditionally (re)set SO_REUSEADDR, even when the value is false.
            // This ensures we're not accidentally affected by any socket option
            // state left behind by a previous failed call to this method (start_bind).
            network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;

            // Perform the OS bind call.
            tokio_socket.bind(local_address).map_err(|error| {
                match Errno::from_io_error(&error) {
                    // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
                    // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
                    //
                    // The most common reasons for this error should have already
                    // been handled by our own validation slightly higher up in this
                    // function. This error mapping is here just in case there is
                    // an edge case we didn't catch.
                    Some(Errno::AFNOSUPPORT) =>  io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "The specified address is not a valid address for the address family of the specified socket",
                    ),

                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
                    // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
                    #[cfg(windows)]
                    Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"),

                    _ => error,
                }
            })?;

            // The bind succeeded. The match at the top guarantees the state is
            // still `Default`, so move the socket into `BindStarted`.
            self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
                TcpState::Default(socket) => TcpState::BindStarted(socket),
                _ => unreachable!(),
            };

            Ok(())
        }
    }

    /// Completes a bind started by `start_bind`, moving `BindStarted` to
    /// `Bound`.
    ///
    /// # Errors
    /// `not-in-progress` if no bind was started. The state is left as it was.
    pub fn finish_bind(&mut self) -> SocketResult<()> {
        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::BindStarted(socket) => {
                self.tcp_state = TcpState::Bound(socket);
                Ok(())
            }
            current_state => {
                // Reset the state so that the outside world doesn't see this socket as closed
                self.tcp_state = current_state;
                Err(ErrorCode::NotInProgress.into())
            }
        }
    }

    /// Starts an outgoing connection to `remote_address` (the first half of
    /// the two-phase WASI connect). Allowed only from `Default` or `Bound`.
    ///
    /// The connect future is created and stored in `Connecting` but not polled
    /// here. `finish_connect` or `Pollable::ready` drives it to completion.
    ///
    /// # Errors
    /// - `concurrency-conflict` if a connect is already in progress.
    /// - `invalid-state` in every other state besides `Default` and `Bound`.
    /// - Whatever the `network::util::validate_*` helpers reject for
    ///   `remote_address`.
    ///
    /// In every error case the state is left unchanged.
    pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> {
        match self.tcp_state {
            TcpState::Default(..) | TcpState::Bound(..) => {}

            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
                return Err(ErrorCode::ConcurrencyConflict.into())
            }

            _ => return Err(ErrorCode::InvalidState.into()),
        };

        network::util::validate_unicast(&remote_address)?;
        network::util::validate_remote_address(&remote_address)?;
        network::util::validate_address_family(&remote_address, &self.family)?;

        // The state check above ensures only these two variants reach here.
        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
            std::mem::replace(&mut self.tcp_state, TcpState::Closed)
        else {
            unreachable!();
        };

        let future = tokio_socket.connect(remote_address);

        self.tcp_state = TcpState::Connecting(Box::pin(future));
        Ok(())
    }

    /// Completes a connect started by `start_connect`.
    ///
    /// If `Pollable::ready` already cached the result (`ConnectReady`), that
    /// result is used. Otherwise the connect future is polled once.
    ///
    /// On success the socket becomes `Connected` and the function returns the
    /// guest-facing input and output streams. Both share the connected stream.
    ///
    /// # Errors
    /// - `would-block` if the connect has not finished. The socket stays
    ///   `Connecting`.
    /// - `not-in-progress` if no connect was started. The state is unchanged.
    /// - The connect error itself. The socket is then `Closed`.
    pub fn finish_connect(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
        let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed);
        let result = match previous_state {
            TcpState::ConnectReady(result) => result,
            TcpState::Connecting(mut future) => {
                // One non-blocking poll. Wake-ups go to a no-op waker and are discarded.
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
                    Poll::Ready(result) => result,
                    Poll::Pending => {
                        // Put the unfinished future back so a later call can retry.
                        self.tcp_state = TcpState::Connecting(future);
                        return Err(ErrorCode::WouldBlock.into());
                    }
                }
            }
            previous_state => {
                self.tcp_state = previous_state;
                return Err(ErrorCode::NotInProgress.into());
            }
        };

        match result {
            Ok(stream) => {
                let stream = Arc::new(stream);
                let reader = Arc::new(Mutex::new(TcpReader::new(stream.clone())));
                let writer = Arc::new(Mutex::new(TcpWriter::new(stream.clone())));
                self.tcp_state = TcpState::Connected {
                    stream,
                    reader: reader.clone(),
                    writer: writer.clone(),
                };
                let input: DynInputStream = Box::new(TcpReadStream(reader));
                let output: DynOutputStream = Box::new(TcpWriteStream(writer));
                Ok((input, output))
            }
            Err(err) => {
                self.tcp_state = TcpState::Closed;
                Err(err.into())
            }
        }
    }

    /// Starts listening (the first half of the two-phase WASI listen), moving
    /// `Bound` to `ListenStarted`. The OS `listen` call happens in
    /// `finish_listen`.
    ///
    /// # Errors
    /// - `concurrency-conflict` if a listen has already been started.
    /// - `invalid-state` in any other state besides `Bound`.
    ///
    /// In both cases the state is left unchanged.
    pub fn start_listen(&mut self) -> SocketResult<()> {
        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::Bound(tokio_socket) => {
                self.tcp_state = TcpState::ListenStarted(tokio_socket);
                Ok(())
            }
            TcpState::ListenStarted(tokio_socket) => {
                self.tcp_state = TcpState::ListenStarted(tokio_socket);
                Err(ErrorCode::ConcurrencyConflict.into())
            }
            previous_state => {
                self.tcp_state = previous_state;
                Err(ErrorCode::InvalidState.into())
            }
        }
    }

    /// Completes a listen started by `start_listen`. It calls the OS `listen`
    /// with `listen_backlog_size`; on success the socket becomes `Listening`.
    ///
    /// # Errors
    /// - `not-in-progress` if no listen was started. The state is unchanged.
    /// - The OS `listen` error. On Windows, `EMFILE` is reported as `ENOBUFS`.
    ///   The socket is then `Closed`.
    pub fn finish_listen(&mut self) -> SocketResult<()> {
        let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::ListenStarted(tokio_socket) => tokio_socket,
            previous_state => {
                self.tcp_state = previous_state;
                return Err(ErrorCode::NotInProgress.into());
            }
        };

        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
            Ok(listener) => {
                self.tcp_state = TcpState::Listening {
                    listener,
                    pending_accept: None,
                };
                Ok(())
            }
            Err(err) => {
                self.tcp_state = TcpState::Closed;

                Err(match Errno::from_io_error(&err) {
                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
                    // According to the docs, `listen` can return EMFILE on Windows.
                    // This is odd, because we're not trying to create a new socket
                    // or file descriptor of any kind. So we rewrite it to less
                    // surprising error code.
                    //
                    // At the time of writing, this behavior has never been experimentally
                    // observed by any of the wasmtime authors, so we're relying fully
                    // on Microsoft's documentation here.
                    #[cfg(windows)]
                    Some(Errno::MFILE) => Errno::NOBUFS.into(),

                    _ => err.into(),
                })
            }
        }
    }

    /// Accepts one incoming connection on a listening socket.
    ///
    /// If `Pollable::ready` already cached a connection, that one is used.
    /// Otherwise the listener is polled once.
    ///
    /// Returns the new socket (already `Connected`, with the listener's
    /// address family) together with its input and output streams.
    ///
    /// # Errors
    /// - `invalid-state` if the socket is not listening.
    /// - `EWOULDBLOCK` if no connection is pending.
    /// - Accept errors, after the platform normalizations below.
    pub fn accept(&mut self) -> SocketResult<(Self, DynInputStream, DynOutputStream)> {
        let TcpState::Listening {
            listener,
            pending_accept,
        } = &mut self.tcp_state
        else {
            return Err(ErrorCode::InvalidState.into());
        };

        // Prefer a connection already accepted during readiness polling.
        let result = match pending_accept.take() {
            Some(result) => result,
            None => {
                // One non-blocking poll. Wake-ups go to a no-op waker and are discarded.
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
                    .map_ok(|(stream, _)| stream)
                {
                    Poll::Ready(result) => result,
                    Poll::Pending => Err(Errno::WOULDBLOCK.into()),
                }
            }
        };

        let client = result.map_err(|err| match Errno::from_io_error(&err) {
            // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
            // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
            // > or the service provider is still processing a callback function.
            //
            // wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
            // because in POSIX this error is only returned by a non-blocking
            // `connect` and wasi-sockets has a different solution for that.
            #[cfg(windows)]
            Some(Errno::INPROGRESS) => Errno::INTR.into(),

            // Normalize Linux' non-standard behavior.
            //
            // From https://man7.org/linux/man-pages/man2/accept.2.html:
            // > Linux accept() passes already-pending network errors on the
            // > new socket as an error code from accept(). This behavior
            // > differs from other BSD socket implementations. (...)
            #[cfg(target_os = "linux")]
            Some(
                Errno::CONNRESET
                | Errno::NETRESET
                | Errno::HOSTUNREACH
                | Errno::HOSTDOWN
                | Errno::NETDOWN
                | Errno::NETUNREACH
                | Errno::PROTO
                | Errno::NOPROTOOPT
                | Errno::NONET
                | Errno::OPNOTSUPP,
            ) => Errno::CONNABORTED.into(),

            _ => err,
        })?;

        #[cfg(target_os = "macos")]
        {
            // Manually inherit socket options from listener. We only have to
            // do this on platforms that don't already do this automatically
            // and only if a specific value was explicitly set on the listener.

            if let Some(size) = self.receive_buffer_size {
                _ = network::util::set_socket_recv_buffer_size(&client, size); // Ignore potential error.
            }

            if let Some(size) = self.send_buffer_size {
                _ = network::util::set_socket_send_buffer_size(&client, size); // Ignore potential error.
            }

            // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
            if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) {
                _ = network::util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error.
            }

            if let Some(value) = self.keep_alive_idle_time {
                _ = network::util::set_tcp_keepidle(&client, value); // Ignore potential error.
            }
        }

        // Wrap the accepted stream the same way `finish_connect` wraps an
        // outgoing one: one shared stream, plus a locked reader and writer.
        let client = Arc::new(client);

        let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
        let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));

        let input: DynInputStream = Box::new(TcpReadStream(reader.clone()));
        let output: DynOutputStream = Box::new(TcpWriteStream(writer.clone()));
        let tcp_socket = TcpSocket::from_state(
            TcpState::Connected {
                stream: client,
                reader,
                writer,
            },
            self.family,
        )?;

        Ok((tcp_socket, input, output))
    }

    /// Returns the socket's local address.
    ///
    /// # Errors
    /// - `invalid-state` before any bind was started.
    /// - `concurrency-conflict` while a bind is in progress.
    /// - Otherwise, errors from `as_std_view` or from the OS address lookup.
    pub fn local_address(&self) -> SocketResult<SocketAddr> {
        let view = match self.tcp_state {
            TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
            TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
            _ => self.as_std_view()?,
        };

        Ok(view.local_addr()?)
    }

    /// Returns the connected peer's address.
    ///
    /// # Errors
    /// - `concurrency-conflict` while a connect is in progress.
    /// - `invalid-state` if the socket is not connected.
    /// - Errors from the OS address lookup.
    pub fn remote_address(&self) -> SocketResult<SocketAddr> {
        let view = match self.tcp_state {
            TcpState::Connected { .. } => self.as_std_view()?,
            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
                return Err(ErrorCode::ConcurrencyConflict.into())
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        };

        Ok(view.peer_addr()?)
    }

    /// Returns `true` once `finish_listen` has succeeded.
    pub fn is_listening(&self) -> bool {
        matches!(self.tcp_state, TcpState::Listening { .. })
    }

    /// Returns the address family the socket was created with.
    pub fn address_family(&self) -> SocketAddressFamily {
        self.family
    }

    /// Sets the backlog used for `listen`.
    ///
    /// Non-zero values are clamped to `1..=i32::MAX`.
    /// - Before listening, the value is only stored.
    /// - On a listening socket, `listen` is re-invoked with the new value, and
    ///   the stored value changes only if that succeeds.
    ///
    /// # Errors
    /// - `invalid-argument` for zero.
    /// - `not-supported` if re-invoking `listen` fails.
    /// - `invalid-state` in any state other than `Default`, `Bound` or
    ///   `Listening`.
    pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> {
        const MIN_BACKLOG: u32 = 1;
        const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.

        if value == 0 {
            return Err(ErrorCode::InvalidArgument.into());
        }

        // Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
        let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG);

        match &self.tcp_state {
            TcpState::Default(..) | TcpState::Bound(..) => {
                // Socket not listening yet. Stash value for first invocation to `listen`.
            }
            TcpState::Listening { listener, .. } => {
                // Try to update the backlog by calling `listen` again.
                // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.

                // `value` is at most `i32::MAX` after the clamp above, so this
                // conversion is not expected to fail.
                rustix::net::listen(&listener, value.try_into().unwrap())
                    .map_err(|_| ErrorCode::NotSupported)?;
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        }
        self.listen_backlog_size = value;

        Ok(())
    }

    /// Returns whether `SO_KEEPALIVE` is enabled.
    pub fn keep_alive_enabled(&self) -> SocketResult<bool> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::socket_keepalive(view)?)
    }

    /// Enables or disables `SO_KEEPALIVE`.
    pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::set_socket_keepalive(view, value)?)
    }

    /// Returns the keep-alive idle time (`TCP_KEEPIDLE`).
    pub fn keep_alive_idle_time(&self) -> SocketResult<std::time::Duration> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::tcp_keepidle(view)?)
    }

    /// Sets the keep-alive idle time. On macOS the value is also remembered,
    /// so `accept` can apply it to accepted clients.
    pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;
            network::util::set_tcp_keepidle(view, duration)?;
        }

        #[cfg(target_os = "macos")]
        {
            self.keep_alive_idle_time = Some(duration);
        }

        Ok(())
    }

    /// Returns the keep-alive probe interval (`TCP_KEEPINTVL`).
    pub fn keep_alive_interval(&self) -> SocketResult<std::time::Duration> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::tcp_keepintvl(view)?)
    }

    /// Sets the keep-alive probe interval.
    pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(network::util::set_tcp_keepintvl(view, duration)?)
    }

    /// Returns the keep-alive probe count (`TCP_KEEPCNT`).
    pub fn keep_alive_count(&self) -> SocketResult<u32> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::tcp_keepcnt(view)?)
    }

    /// Sets the keep-alive probe count.
    pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(network::util::set_tcp_keepcnt(view, value)?)
    }

    /// Returns the hop limit. It is read with `get_ip_ttl` for IPv4 sockets
    /// and with `get_ipv6_unicast_hops` for IPv6 sockets.
    pub fn hop_limit(&self) -> SocketResult<u8> {
        let view = &*self.as_std_view()?;

        let ttl = match self.family {
            SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?,
            SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?,
        };

        Ok(ttl)
    }

    /// Sets the hop limit for the socket's address family. On macOS the value
    /// is also remembered, so `accept` can re-apply it to accepted IPv6
    /// clients.
    pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;

            match self.family {
                SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?,
                SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?,
            }
        }

        #[cfg(target_os = "macos")]
        {
            self.hop_limit = Some(value);
        }

        Ok(())
    }

    /// Returns the socket's receive buffer size.
    pub fn receive_buffer_size(&self) -> SocketResult<usize> {
        let view = &*self.as_std_view()?;

        Ok(network::util::get_socket_recv_buffer_size(view)?)
    }

    /// Sets the receive buffer size. On macOS the value is also remembered,
    /// so `accept` can apply it to accepted clients.
    pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;

            network::util::set_socket_recv_buffer_size(view, value)?;
        }

        #[cfg(target_os = "macos")]
        {
            self.receive_buffer_size = Some(value);
        }

        Ok(())
    }

    /// Returns the socket's send buffer size.
    pub fn send_buffer_size(&self) -> SocketResult<usize> {
        let view = &*self.as_std_view()?;

        Ok(network::util::get_socket_send_buffer_size(view)?)
    }

    /// Sets the send buffer size. On macOS the value is also remembered, so
    /// `accept` can apply it to accepted clients.
    pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;

            network::util::set_socket_send_buffer_size(view, value)?;
        }

        #[cfg(target_os = "macos")]
        {
            self.send_buffer_size = Some(value);
        }

        Ok(())
    }

    /// Shuts down the read half, the write half, or both, of a connected socket.
    ///
    /// # Errors
    /// - `invalid-state` if the socket is not connected.
    /// - Whatever `try_lock_for_socket` returns when it cannot lock the
    ///   reader or writer. Presumably this happens when that half is busy;
    ///   see that helper for the exact error.
    pub fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
        let TcpState::Connected { reader, writer, .. } = &self.tcp_state else {
            return Err(ErrorCode::InvalidState.into());
        };

        if let Shutdown::Both | Shutdown::Read = how {
            try_lock_for_socket(reader)?.shutdown();
        }

        if let Shutdown::Both | Shutdown::Write = how {
            try_lock_for_socket(writer)?.shutdown();
        }

        Ok(())
    }
}
654
#[async_trait::async_trait]
impl Pollable for TcpSocket {
    /// Waits until the socket's in-flight asynchronous operation, if any, has
    /// a result ready. That result is then cached in the state for the
    /// matching synchronous call:
    ///
    /// - `Connecting`: awaits the connect future and stores its outcome as
    ///   `ConnectReady`, which `finish_connect` collects.
    /// - `Listening` with no cached connection: waits for one incoming
    ///   connection (or error) and stores it in `pending_accept`, which
    ///   `accept` collects.
    /// - Every other state (including `Listening` with a cached result) has
    ///   nothing to wait for and returns immediately.
    async fn ready(&mut self) {
        match &mut self.tcp_state {
            TcpState::Default(..)
            | TcpState::BindStarted(..)
            | TcpState::Bound(..)
            | TcpState::ListenStarted(..)
            | TcpState::ConnectReady(..)
            | TcpState::Closed
            | TcpState::Connected { .. } => {
                // No async operation in progress.
            }
            TcpState::Connecting(future) => {
                // Park the connect outcome for `finish_connect`.
                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
            }
            TcpState::Listening {
                listener,
                pending_accept,
            } => match pending_accept {
                Some(_) => {}
                None => {
                    // Accept one connection and keep only the stream (drop the peer address).
                    let result = futures::future::poll_fn(|cx| {
                        listener.poll_accept(cx).map_ok(|(stream, _)| stream)
                    })
                    .await;
                    *pending_accept = Some(result);
                }
            },
        }
    }
}
687
688struct TcpReader {
689    stream: Arc<tokio::net::TcpStream>,
690    closed: bool,
691}
692
693impl TcpReader {
694    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
695        Self {
696            stream,
697            closed: false,
698        }
699    }
700    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
701        if self.closed {
702            return Err(StreamError::Closed);
703        }
704        if size == 0 {
705            return Ok(bytes::Bytes::new());
706        }
707
708        let mut buf = bytes::BytesMut::with_capacity(size);
709        let n = match self.stream.try_read_buf(&mut buf) {
710            // A 0-byte read indicates that the stream has closed.
711            Ok(0) => {
712                self.closed = true;
713                return Err(StreamError::Closed);
714            }
715            Ok(n) => n,
716
717            // Failing with `EWOULDBLOCK` is how we differentiate between a closed channel and no
718            // data to read right now.
719            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
720
721            Err(e) => {
722                self.closed = true;
723                return Err(StreamError::LastOperationFailed(e.into()));
724            }
725        };
726
727        buf.truncate(n);
728        Ok(buf.freeze())
729    }
730
731    fn shutdown(&mut self) {
732        native_shutdown(&self.stream, Shutdown::Read);
733        self.closed = true;
734    }
735
736    async fn ready(&mut self) {
737        if self.closed {
738            return;
739        }
740
741        self.stream.readable().await.unwrap();
742    }
743}
744
745struct TcpReadStream(Arc<Mutex<TcpReader>>);
746
747#[async_trait::async_trait]
748impl InputStream for TcpReadStream {
749    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
750        try_lock_for_stream(&self.0)?.read(size)
751    }
752}
753
754#[async_trait::async_trait]
755impl Pollable for TcpReadStream {
756    async fn ready(&mut self) {
757        self.0.lock().await.ready().await
758    }
759}
760
/// Large write budget of 1 GiB (`2^30` bytes). Presumably this is the amount
/// advertised to writers when the socket is writable; its users are outside
/// this view.
const SOCKET_READY_SIZE: usize = 1 << 30;
762
763struct TcpWriter {
764    stream: Arc<tokio::net::TcpStream>,
765    state: WriteState,
766}
767
768enum WriteState {
769    Ready,
770    Writing(AbortOnDropJoinHandle<io::Result<()>>),
771    Closing(AbortOnDropJoinHandle<io::Result<()>>),
772    Closed,
773    Error(io::Error),
774}
775
776impl TcpWriter {
777    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
778        Self {
779            stream,
780            state: WriteState::Ready,
781        }
782    }
783
784    fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
785        stream.try_write(buf).map_err(|error| {
786            match Errno::from_io_error(&error) {
787                // Windows returns `WSAESHUTDOWN` when writing to a shut down socket.
788                // We normalize this to EPIPE, because that is what the other platforms return.
789                // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN
790                #[cfg(windows)]
791                Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
792
793                _ => error,
794            }
795        })
796    }
797
798    /// Write `bytes` in a background task, remembering the task handle for use in a future call to
799    /// `write_ready`
800    fn background_write(&mut self, mut bytes: bytes::Bytes) {
801        assert!(matches!(self.state, WriteState::Ready));
802
803        let stream = self.stream.clone();
804        self.state = WriteState::Writing(crate::runtime::spawn(async move {
805            // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream
806            // primitive try_write, which goes directly to attempt a write with mio. This has
807            // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream
808            // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need
809            // to flush.
810            while !bytes.is_empty() {
811                stream.writable().await?;
812                match Self::try_write_portable(&stream, &bytes) {
813                    Ok(n) => {
814                        let _ = bytes.split_to(n);
815                    }
816                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
817                    Err(e) => return Err(e.into()),
818                }
819            }
820
821            Ok(())
822        }));
823    }
824
825    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
826        match self.state {
827            WriteState::Ready => {}
828            WriteState::Closed => return Err(StreamError::Closed),
829            WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
830                return Err(StreamError::Trap(anyhow::anyhow!(
831                    "unpermitted: must call check_write first"
832                )));
833            }
834        }
835        while !bytes.is_empty() {
836            match Self::try_write_portable(&self.stream, &bytes) {
837                Ok(n) => {
838                    let _ = bytes.split_to(n);
839                }
840
841                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
842                    // As `try_write` indicated that it would have blocked, we'll perform the write
843                    // in the background to allow us to return immediately.
844                    self.background_write(bytes);
845
846                    return Ok(());
847                }
848
849                Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
850                    self.state = WriteState::Closed;
851                    return Err(StreamError::Closed);
852                }
853
854                Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
855            }
856        }
857
858        Ok(())
859    }
860
861    fn flush(&mut self) -> Result<(), StreamError> {
862        // `flush` is a no-op here, as we're not managing any internal buffer. Additionally,
863        // `write_ready` will join the background write task if it's active, so following `flush`
864        // with `write_ready` will have the desired effect.
865        match self.state {
866            WriteState::Ready
867            | WriteState::Writing(_)
868            | WriteState::Closing(_)
869            | WriteState::Error(_) => Ok(()),
870            WriteState::Closed => Err(StreamError::Closed),
871        }
872    }
873
874    fn check_write(&mut self) -> Result<usize, StreamError> {
875        match mem::replace(&mut self.state, WriteState::Closed) {
876            WriteState::Writing(task) => {
877                self.state = WriteState::Writing(task);
878                return Ok(0);
879            }
880            WriteState::Closing(task) => {
881                self.state = WriteState::Closing(task);
882                return Ok(0);
883            }
884            WriteState::Ready => {
885                self.state = WriteState::Ready;
886            }
887            WriteState::Closed => return Err(StreamError::Closed),
888            WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
889        }
890
891        let writable = self.stream.writable();
892        futures::pin_mut!(writable);
893        if crate::runtime::poll_noop(writable).is_none() {
894            return Ok(0);
895        }
896        Ok(SOCKET_READY_SIZE)
897    }
898
899    fn shutdown(&mut self) {
900        self.state = match mem::replace(&mut self.state, WriteState::Closed) {
901            // No write in progress, immediately shut down:
902            WriteState::Ready => {
903                native_shutdown(&self.stream, Shutdown::Write);
904                WriteState::Closed
905            }
906
907            // Schedule the shutdown after the current write has finished:
908            WriteState::Writing(write) => {
909                let stream = self.stream.clone();
910                WriteState::Closing(crate::runtime::spawn(async move {
911                    let result = write.await;
912                    native_shutdown(&stream, Shutdown::Write);
913                    result
914                }))
915            }
916
917            s => s,
918        };
919    }
920
921    async fn cancel(&mut self) {
922        match mem::replace(&mut self.state, WriteState::Closed) {
923            WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
924            _ => {}
925        }
926    }
927
928    async fn ready(&mut self) {
929        match &mut self.state {
930            WriteState::Writing(task) => {
931                self.state = match task.await {
932                    Ok(()) => WriteState::Ready,
933                    Err(e) => WriteState::Error(e),
934                }
935            }
936            WriteState::Closing(task) => {
937                self.state = match task.await {
938                    Ok(()) => WriteState::Closed,
939                    Err(e) => WriteState::Error(e),
940                }
941            }
942            _ => {}
943        }
944
945        if let WriteState::Ready = self.state {
946            self.stream.writable().await.unwrap();
947        }
948    }
949}
950
/// Guest-facing output stream for the write half of a TCP connection.
///
/// Wraps the shared `TcpWriter` in a `tokio::sync::Mutex`. The synchronous
/// `OutputStream` methods take the lock with `try_lock` (via
/// `try_lock_for_stream`), so overlapping access traps instead of blocking.
struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
952
953#[async_trait::async_trait]
954impl OutputStream for TcpWriteStream {
955    fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
956        try_lock_for_stream(&self.0)?.write(bytes)
957    }
958
959    fn flush(&mut self) -> Result<(), StreamError> {
960        try_lock_for_stream(&self.0)?.flush()
961    }
962
963    fn check_write(&mut self) -> Result<usize, StreamError> {
964        try_lock_for_stream(&self.0)?.check_write()
965    }
966
967    async fn cancel(&mut self) {
968        self.0.lock().await.cancel().await
969    }
970}
971
972#[async_trait::async_trait]
973impl Pollable for TcpWriteStream {
974    async fn ready(&mut self) {
975        self.0.lock().await.ready().await
976    }
977}
978
979fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
980    _ = stream
981        .as_socketlike_view::<std::net::TcpStream>()
982        .shutdown(how);
983}
984
985fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
986    mutex
987        .try_lock()
988        .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
989}
990
991fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, SocketError> {
992    mutex.try_lock().map_err(|_| {
993        SocketError::trap(anyhow::anyhow!(
994            "concurrent access to resource not supported"
995        ))
996    })
997}