Skip to main content

wasmtime_wasi/sockets/
tcp.rs

1use crate::p2::P2TcpStreamingState;
2use crate::runtime::with_ambient_tokio_runtime;
3use crate::sockets::util::{
4    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
5    is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count,
6    set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size,
7    set_send_buffer_size, set_unicast_hop_limit, tcp_bind,
8};
9use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx};
10use io_lifetimes::AsSocketlike as _;
11use io_lifetimes::views::SocketlikeView;
12use rustix::io::Errno;
13use rustix::net::sockopt;
14use std::fmt::Debug;
15use std::io;
16use std::mem;
17use std::net::SocketAddr;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll, Waker};
21use std::time::Duration;
22
23/// The state of a TCP socket.
24///
25/// This represents the various states a socket can be in during the
26/// activities of binding, listening, accepting, and connecting. Note that this
27/// state machine encompasses both WASIp2 and WASIp3.
28enum TcpState {
29    /// The initial state for a newly-created socket.
30    ///
31    /// From here a socket can transition to `BindStarted`, `ListenStarted`, or
32    /// `Connecting`.
33    Default(tokio::net::TcpSocket),
34
35    /// A state indicating that a bind has been started and must be finished
36    /// subsequently with `finish_bind`.
37    ///
38    /// From here a socket can transition to `Bound`.
39    BindStarted(tokio::net::TcpSocket),
40
41    /// Binding finished. The socket has an address but is not yet listening for
42    /// connections.
43    ///
44    /// From here a socket can transition to `ListenStarted`, or `Connecting`.
45    Bound(tokio::net::TcpSocket),
46
47    /// Listening on a socket has started and must be completed with
48    /// `finish_listen`.
49    ///
50    /// From here a socket can transition to `Listening`.
51    ListenStarted(tokio::net::TcpSocket),
52
53    /// The socket is now listening and waiting for an incoming connection.
54    ///
55    /// Sockets will not leave this state.
56    Listening {
57        /// The raw tokio-basd TCP listener managing the underlying socket.
58        listener: Arc<tokio::net::TcpListener>,
59
60        /// The last-accepted connection, set during the `ready` method and read
61        /// during the `accept` method. Note that this is only used for WASIp2
62        /// at this time.
63        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
64    },
65
66    /// An outgoing connection is started.
67    ///
68    /// This is created via the `start_connect` method. The payload here is an
69    /// optionally-specified owned future for the result of the connect. In
70    /// WASIp2 the future lives here, but in WASIp3 it lives on the event loop
71    /// so this is `None`.
72    ///
73    /// From here a socket can transition to `ConnectReady` or `Connected`.
74    Connecting(Option<Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>>),
75
76    /// A connection via `Connecting` has completed.
77    ///
78    /// This is present for WASIp2 where the `Connecting` state stores `Some` of
79    /// a future, and the result of that future is recorded here when it
80    /// finishes as part of the `ready` method.
81    ///
82    /// From here a socket can transition to `Connected`.
83    ConnectReady(io::Result<tokio::net::TcpStream>),
84
85    /// A connection has been established.
86    ///
87    /// This is created either via `finish_connect` or for freshly accepted
88    /// sockets from a TCP listener.
89    ///
90    /// A socket will not transition out of this state.
91    Connected {
92        stream: Arc<tokio::net::TcpStream>,
93        taken_streams: TakenStreams,
94        p2_state: Option<P2TcpStreamingState>,
95    },
96
97    /// This is not actually a socket but a deferred error.
98    ///
99    /// This error came out of `accept` and is deferred until the socket is
100    /// operated on.
101    #[cfg(feature = "p3")]
102    Error(io::Error),
103
104    /// The socket is closed and no more operations can be performed.
105    Closed,
106}
107impl TcpState {
108    fn connected(stream: tokio::net::TcpStream) -> Self {
109        TcpState::Connected {
110            stream: Arc::new(stream),
111            taken_streams: TakenStreams {
112                receive: false,
113                send: false,
114            },
115            p2_state: None,
116        }
117    }
118}
119impl Debug for TcpState {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            Self::Default(_) => f.debug_tuple("Default").finish(),
123            Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
124            Self::Bound(_) => f.debug_tuple("Bound").finish(),
125            Self::ListenStarted { .. } => f.debug_tuple("ListenStarted").finish(),
126            Self::Listening { .. } => f.debug_tuple("Listening").finish(),
127            Self::Connecting(..) => f.debug_tuple("Connecting").finish(),
128            Self::ConnectReady(..) => f.debug_tuple("ConnectReady").finish(),
129            Self::Connected { .. } => f.debug_tuple("Connected").finish(),
130            #[cfg(feature = "p3")]
131            Self::Error(..) => f.debug_tuple("Error").finish(),
132            Self::Closed => write!(f, "Closed"),
133        }
134    }
135}
136
/// Tracks which halves of a connected socket's stream have been handed out.
///
/// `take_stream` refuses to hand out the same direction twice.
struct TakenStreams {
    /// `true` once the receive half has been taken.
    receive: bool,
    /// `true` once the send half has been taken.
    send: bool,
}
141
142/// A host TCP socket, plus associated bookkeeping.
143pub struct TcpSocket {
144    /// The current state in the bind/listen/accept/connect progression.
145    tcp_state: TcpState,
146
147    /// The desired listen queue size.
148    listen_backlog_size: u32,
149
150    family: SocketAddressFamily,
151
152    options: NonInheritedOptions,
153}
154
155impl TcpSocket {
156    /// Create a new socket in the given family.
157    pub(crate) fn new(
158        ctx: &WasiSocketsCtx,
159        family: SocketAddressFamily,
160    ) -> Result<Self, ErrorCode> {
161        ctx.allowed_network_uses.check_allowed_tcp()?;
162
163        with_ambient_tokio_runtime(|| {
164            let socket = match family {
165                SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?,
166                SocketAddressFamily::Ipv6 => {
167                    let socket = tokio::net::TcpSocket::new_v6()?;
168                    sockopt::set_ipv6_v6only(&socket, true)?;
169                    socket
170                }
171            };
172
173            Ok(Self::from_state(TcpState::Default(socket), family))
174        })
175    }
176
177    #[cfg(feature = "p3")]
178    pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self {
179        TcpSocket::from_state(TcpState::Error(err), family)
180    }
181
182    /// Creates a new socket with the `result` of an accepted socket from a
183    /// `TcpListener`.
184    ///
185    /// This will handle the `result` internally and `result` should be the raw
186    /// result from a TCP listen operation.
187    pub(crate) fn new_accept(
188        result: io::Result<tokio::net::TcpStream>,
189        options: &NonInheritedOptions,
190        family: SocketAddressFamily,
191    ) -> io::Result<Self> {
192        let client = result.map_err(|err| match Errno::from_io_error(&err) {
193            // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
194            // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
195            // > or the service provider is still processing a callback function.
196            //
197            // wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
198            // because in POSIX this error is only returned by a non-blocking
199            // `connect` and wasi-sockets has a different solution for that.
200            #[cfg(windows)]
201            Some(Errno::INPROGRESS) => Errno::INTR.into(),
202
203            // Normalize Linux' non-standard behavior.
204            //
205            // From https://man7.org/linux/man-pages/man2/accept.2.html:
206            // > Linux accept() passes already-pending network errors on the
207            // > new socket as an error code from accept(). This behavior
208            // > differs from other BSD socket implementations. (...)
209            #[cfg(target_os = "linux")]
210            Some(
211                Errno::CONNRESET
212                | Errno::NETRESET
213                | Errno::HOSTUNREACH
214                | Errno::HOSTDOWN
215                | Errno::NETDOWN
216                | Errno::NETUNREACH
217                | Errno::PROTO
218                | Errno::NOPROTOOPT
219                | Errno::NONET
220                | Errno::OPNOTSUPP,
221            ) => Errno::CONNABORTED.into(),
222
223            _ => err,
224        })?;
225        options.apply(family, &client);
226        Ok(Self::from_state(TcpState::connected(client), family))
227    }
228
229    /// Create a `TcpSocket` from an existing socket.
230    fn from_state(state: TcpState, family: SocketAddressFamily) -> Self {
231        Self {
232            tcp_state: state,
233            listen_backlog_size: DEFAULT_TCP_BACKLOG,
234            family,
235            options: Default::default(),
236        }
237    }
238
239    pub(crate) fn as_std_view(&self) -> Result<SocketlikeView<'_, std::net::TcpStream>, ErrorCode> {
240        match &self.tcp_state {
241            TcpState::Default(socket)
242            | TcpState::BindStarted(socket)
243            | TcpState::Bound(socket)
244            | TcpState::ListenStarted(socket) => Ok(socket.as_socketlike_view()),
245            TcpState::Connected { stream, .. } => Ok(stream.as_socketlike_view()),
246            TcpState::Listening { listener, .. } => Ok(listener.as_socketlike_view()),
247            TcpState::Connecting(..) | TcpState::ConnectReady(_) | TcpState::Closed => {
248                Err(ErrorCode::InvalidState)
249            }
250            #[cfg(feature = "p3")]
251            TcpState::Error(err) => Err(err.into()),
252        }
253    }
254
255    pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
256        let ip = addr.ip();
257        if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) {
258            return Err(ErrorCode::InvalidArgument);
259        }
260        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
261            TcpState::Default(sock) => {
262                if let Err(err) = tcp_bind(&sock, addr) {
263                    self.tcp_state = TcpState::Default(sock);
264                    Err(err)
265                } else {
266                    self.tcp_state = TcpState::BindStarted(sock);
267                    Ok(())
268                }
269            }
270            tcp_state => {
271                self.tcp_state = tcp_state;
272                Err(ErrorCode::InvalidState)
273            }
274        }
275    }
276
277    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
278        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
279            TcpState::BindStarted(socket) => {
280                self.tcp_state = TcpState::Bound(socket);
281                Ok(())
282            }
283            current_state => {
284                // Reset the state so that the outside world doesn't see this socket as closed
285                self.tcp_state = current_state;
286                Err(ErrorCode::NotInProgress)
287            }
288        }
289    }
290
291    pub(crate) fn start_connect(
292        &mut self,
293        addr: &SocketAddr,
294    ) -> Result<tokio::net::TcpSocket, ErrorCode> {
295        match self.tcp_state {
296            TcpState::Default(..) | TcpState::Bound(..) => {}
297            TcpState::Connecting(..) => {
298                return Err(ErrorCode::ConcurrencyConflict);
299            }
300            _ => return Err(ErrorCode::InvalidState),
301        };
302
303        if !is_valid_unicast_address(addr.ip())
304            || !is_valid_remote_address(*addr)
305            || !is_valid_address_family(addr.ip(), self.family)
306        {
307            return Err(ErrorCode::InvalidArgument);
308        };
309
310        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
311            mem::replace(&mut self.tcp_state, TcpState::Connecting(None))
312        else {
313            unreachable!();
314        };
315
316        Ok(tokio_socket)
317    }
318
319    /// For WASIp2 this is used to record the actual connection future as part
320    /// of `start_connect` within this socket state.
321    pub(crate) fn set_pending_connect(
322        &mut self,
323        future: impl Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static,
324    ) -> Result<(), ErrorCode> {
325        match &mut self.tcp_state {
326            TcpState::Connecting(slot @ None) => {
327                *slot = Some(Box::pin(future));
328                Ok(())
329            }
330            _ => Err(ErrorCode::InvalidState),
331        }
332    }
333
334    /// For WASIp2 this retrieves the result from the future passed to
335    /// `set_pending_connect`.
336    ///
337    /// Return states here are:
338    ///
339    /// * `Ok(Some(res))` - where `res` is the result of the connect operation.
340    /// * `Ok(None)` - the connect operation isn't ready yet.
341    /// * `Err(e)` - a connect operation is not in progress.
342    pub(crate) fn take_pending_connect(
343        &mut self,
344    ) -> Result<Option<io::Result<tokio::net::TcpStream>>, ErrorCode> {
345        match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) {
346            TcpState::ConnectReady(result) => Ok(Some(result)),
347            TcpState::Connecting(Some(mut future)) => {
348                let mut cx = Context::from_waker(Waker::noop());
349                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
350                    Poll::Ready(result) => Ok(Some(result)),
351                    Poll::Pending => {
352                        self.tcp_state = TcpState::Connecting(Some(future));
353                        Ok(None)
354                    }
355                }
356            }
357            current_state => {
358                self.tcp_state = current_state;
359                Err(ErrorCode::NotInProgress)
360            }
361        }
362    }
363
364    pub(crate) fn finish_connect(
365        &mut self,
366        result: io::Result<tokio::net::TcpStream>,
367    ) -> Result<(), ErrorCode> {
368        if !matches!(self.tcp_state, TcpState::Connecting(None)) {
369            return Err(ErrorCode::InvalidState);
370        }
371        match result {
372            Ok(stream) => {
373                self.tcp_state = TcpState::connected(stream);
374                Ok(())
375            }
376            Err(err) => {
377                self.tcp_state = TcpState::Closed;
378                Err(ErrorCode::from(err))
379            }
380        }
381    }
382
383    /// Start listening using p2 semantics. (no implicit bind)
384    pub(crate) fn start_listen(&mut self) -> Result<(), ErrorCode> {
385        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
386            TcpState::Bound(tokio_socket) => {
387                self.tcp_state = TcpState::ListenStarted(tokio_socket);
388                Ok(())
389            }
390            previous_state => {
391                self.tcp_state = previous_state;
392                Err(ErrorCode::InvalidState)
393            }
394        }
395    }
396
397    pub(crate) fn finish_listen(&mut self) -> Result<(), ErrorCode> {
398        let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
399            TcpState::ListenStarted(tokio_socket) => tokio_socket,
400            previous_state => {
401                self.tcp_state = previous_state;
402                return Err(ErrorCode::NotInProgress);
403            }
404        };
405
406        self.listen_common(tokio_socket)
407    }
408
409    /// Returns whether this socket is in the bound state.
410    #[cfg(feature = "p3")]
411    pub(crate) fn is_bound(&self) -> bool {
412        match &self.tcp_state {
413            TcpState::Bound(_) => true,
414            _ => false,
415        }
416    }
417
418    fn listen_common(&mut self, tokio_socket: tokio::net::TcpSocket) -> Result<(), ErrorCode> {
419        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
420            Ok(listener) => {
421                self.tcp_state = TcpState::Listening {
422                    listener: Arc::new(listener),
423                    pending_accept: None,
424                };
425                Ok(())
426            }
427            Err(err) => {
428                self.tcp_state = TcpState::Closed;
429
430                Err(match Errno::from_io_error(&err) {
431                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
432                    // According to the docs, `listen` can return EMFILE on Windows.
433                    // This is odd, because we're not trying to create a new socket
434                    // or file descriptor of any kind. So we rewrite it to less
435                    // surprising error code.
436                    //
437                    // At the time of writing, this behavior has never been experimentally
438                    // observed by any of the wasmtime authors, so we're relying fully
439                    // on Microsoft's documentation here.
440                    #[cfg(windows)]
441                    Some(Errno::MFILE) => Errno::NOBUFS.into(),
442
443                    _ => err.into(),
444                })
445            }
446        }
447    }
448
449    pub(crate) fn accept(&mut self) -> Result<Option<Self>, ErrorCode> {
450        let TcpState::Listening {
451            listener,
452            pending_accept,
453        } = &mut self.tcp_state
454        else {
455            return Err(ErrorCode::InvalidState);
456        };
457
458        let result = match pending_accept.take() {
459            Some(result) => result,
460            None => {
461                let mut cx = std::task::Context::from_waker(Waker::noop());
462                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
463                    .map_ok(|(stream, _)| stream)
464                {
465                    Poll::Ready(result) => result,
466                    Poll::Pending => return Ok(None),
467                }
468            }
469        };
470
471        Ok(Some(Self::new_accept(result, &self.options, self.family)?))
472    }
473
474    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
475        match &self.tcp_state {
476            TcpState::Bound(socket) => Ok(socket.local_addr()?),
477            TcpState::Connected { stream, .. } => Ok(stream.local_addr()?),
478            TcpState::Listening { listener, .. } => Ok(listener.local_addr()?),
479            #[cfg(feature = "p3")]
480            TcpState::Error(err) => Err(err.into()),
481            _ => Err(ErrorCode::InvalidState),
482        }
483    }
484
485    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
486        match &self.tcp_state {
487            TcpState::Connected { stream, .. } => Ok(stream.peer_addr()?),
488            #[cfg(feature = "p3")]
489            TcpState::Error(err) => Err(err.into()),
490            _ => Err(ErrorCode::InvalidState),
491        }
492    }
493
494    pub(crate) fn is_listening(&self) -> bool {
495        matches!(self.tcp_state, TcpState::Listening { .. })
496    }
497
498    pub(crate) fn address_family(&self) -> SocketAddressFamily {
499        self.family
500    }
501
502    pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> {
503        const MIN_BACKLOG: u32 = 1;
504        const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
505
506        if value == 0 {
507            return Err(ErrorCode::InvalidArgument);
508        }
509        // Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
510        let value = value
511            .try_into()
512            .unwrap_or(MAX_BACKLOG)
513            .clamp(MIN_BACKLOG, MAX_BACKLOG);
514        match &self.tcp_state {
515            TcpState::Default(..) | TcpState::Bound(..) => {
516                // Socket not listening yet. Stash value for first invocation to `listen`.
517                self.listen_backlog_size = value;
518                Ok(())
519            }
520            TcpState::Listening { listener, .. } => {
521                // Try to update the backlog by calling `listen` again.
522                // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
523                if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
524                    return Err(ErrorCode::NotSupported);
525                }
526                self.listen_backlog_size = value;
527                Ok(())
528            }
529            #[cfg(feature = "p3")]
530            TcpState::Error(err) => Err(err.into()),
531            _ => Err(ErrorCode::InvalidState),
532        }
533    }
534
535    pub(crate) fn keep_alive_enabled(&self) -> Result<bool, ErrorCode> {
536        let fd = &*self.as_std_view()?;
537        let v = sockopt::socket_keepalive(fd)?;
538        Ok(v)
539    }
540
541    pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> {
542        let fd = &*self.as_std_view()?;
543        sockopt::set_socket_keepalive(fd, value)?;
544        Ok(())
545    }
546
547    pub(crate) fn keep_alive_idle_time(&self) -> Result<u64, ErrorCode> {
548        let fd = &*self.as_std_view()?;
549        let v = sockopt::tcp_keepidle(fd)?;
550        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
551    }
552
553    pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> {
554        let value = {
555            let fd = self.as_std_view()?;
556            set_keep_alive_idle_time(&*fd, value)?
557        };
558        self.options.set_keep_alive_idle_time(value);
559        Ok(())
560    }
561
562    pub(crate) fn keep_alive_interval(&self) -> Result<u64, ErrorCode> {
563        let fd = &*self.as_std_view()?;
564        let v = sockopt::tcp_keepintvl(fd)?;
565        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
566    }
567
568    pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> {
569        let fd = &*self.as_std_view()?;
570        set_keep_alive_interval(fd, Duration::from_nanos(value))?;
571        Ok(())
572    }
573
574    pub(crate) fn keep_alive_count(&self) -> Result<u32, ErrorCode> {
575        let fd = &*self.as_std_view()?;
576        let v = sockopt::tcp_keepcnt(fd)?;
577        Ok(v)
578    }
579
580    pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> {
581        let fd = &*self.as_std_view()?;
582        set_keep_alive_count(fd, value)?;
583        Ok(())
584    }
585
586    pub(crate) fn hop_limit(&self) -> Result<u8, ErrorCode> {
587        let fd = &*self.as_std_view()?;
588        let n = get_unicast_hop_limit(fd, self.family)?;
589        Ok(n)
590    }
591
592    pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> {
593        {
594            let fd = &*self.as_std_view()?;
595            set_unicast_hop_limit(fd, self.family, value)?;
596        }
597        self.options.set_hop_limit(value);
598        Ok(())
599    }
600
601    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
602        let fd = &*self.as_std_view()?;
603        let n = receive_buffer_size(fd)?;
604        Ok(n)
605    }
606
607    pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
608        let res = {
609            let fd = &*self.as_std_view()?;
610            set_receive_buffer_size(fd, value)?
611        };
612        self.options.set_receive_buffer_size(res);
613        Ok(())
614    }
615
616    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
617        let fd = &*self.as_std_view()?;
618        let n = send_buffer_size(fd)?;
619        Ok(n)
620    }
621
622    pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
623        let res = {
624            let fd = &*self.as_std_view()?;
625            set_send_buffer_size(fd, value)?
626        };
627        self.options.set_send_buffer_size(res);
628        Ok(())
629    }
630
631    #[cfg(feature = "p3")]
632    pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions {
633        &self.options
634    }
635
636    #[cfg(feature = "p3")]
637    pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc<tokio::net::TcpListener>, ErrorCode> {
638        match &self.tcp_state {
639            TcpState::Listening { listener, .. } => Ok(listener),
640            #[cfg(feature = "p3")]
641            TcpState::Error(err) => Err(err.into()),
642            _ => Err(ErrorCode::InvalidState),
643        }
644    }
645
646    pub(crate) fn take_receive_stream(&mut self) -> Result<Arc<tokio::net::TcpStream>, ErrorCode> {
647        self.take_stream(|s| &mut s.receive)
648    }
649
650    pub(crate) fn take_send_stream(&mut self) -> Result<Arc<tokio::net::TcpStream>, ErrorCode> {
651        self.take_stream(|s| &mut s.send)
652    }
653
654    fn take_stream(
655        &mut self,
656        direction: impl FnOnce(&mut TakenStreams) -> &mut bool,
657    ) -> Result<Arc<tokio::net::TcpStream>, ErrorCode> {
658        match &mut self.tcp_state {
659            TcpState::Connected {
660                stream,
661                taken_streams,
662                ..
663            } => {
664                let taken = direction(taken_streams);
665                if *taken {
666                    return Err(ErrorCode::InvalidState);
667                }
668                *taken = true;
669                Ok(stream.clone())
670            }
671            #[cfg(feature = "p3")]
672            TcpState::Error(err) => Err((&*err).into()),
673            _ => Err(ErrorCode::InvalidState),
674        }
675    }
676
677    pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> {
678        match &self.tcp_state {
679            TcpState::Connected {
680                p2_state: Some(state),
681                ..
682            } => Ok(state),
683            #[cfg(feature = "p3")]
684            TcpState::Error(err) => Err(err.into()),
685            _ => Err(ErrorCode::InvalidState),
686        }
687    }
688
689    pub(crate) fn set_p2_streaming_state(
690        &mut self,
691        state: P2TcpStreamingState,
692    ) -> Result<(), ErrorCode> {
693        if let TcpState::Connected { p2_state, .. } = &mut self.tcp_state {
694            *p2_state = Some(state);
695            Ok(())
696        } else {
697            Err(ErrorCode::InvalidState)
698        }
699    }
700
701    /// Used for `Pollable` in the WASIp2 implementation this awaits the socket
702    /// to be connected, if in the connecting state, or for a TCP accept to be
703    /// ready, if this is in the listening state.
704    ///
705    /// For all other states this method immediately returns.
706    pub(crate) async fn ready(&mut self) {
707        match &mut self.tcp_state {
708            TcpState::Default(..)
709            | TcpState::BindStarted(..)
710            | TcpState::Bound(..)
711            | TcpState::ListenStarted(..)
712            | TcpState::ConnectReady(..)
713            | TcpState::Closed
714            | TcpState::Connected { .. }
715            | TcpState::Connecting(None)
716            | TcpState::Listening {
717                pending_accept: Some(_),
718                ..
719            } => {}
720
721            #[cfg(feature = "p3")]
722            TcpState::Error(_) => {}
723
724            TcpState::Connecting(Some(future)) => {
725                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
726            }
727
728            TcpState::Listening {
729                listener,
730                pending_accept: slot @ None,
731            } => {
732                let result = futures::future::poll_fn(|cx| {
733                    listener.poll_accept(cx).map_ok(|(stream, _)| stream)
734                })
735                .await;
736                *slot = Some(result);
737            }
738        }
739    }
740}
741
742#[cfg(not(target_os = "macos"))]
743pub use inherits_option::*;
744#[cfg(not(target_os = "macos"))]
745mod inherits_option {
746    use crate::sockets::SocketAddressFamily;
747    use tokio::net::TcpStream;
748
749    #[derive(Default, Clone)]
750    pub struct NonInheritedOptions;
751
752    impl NonInheritedOptions {
753        pub fn set_keep_alive_idle_time(&mut self, _value: u64) {}
754
755        pub fn set_hop_limit(&mut self, _value: u8) {}
756
757        pub fn set_receive_buffer_size(&mut self, _value: usize) {}
758
759        pub fn set_send_buffer_size(&mut self, _value: usize) {}
760
761        pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {}
762    }
763}
764
#[cfg(target_os = "macos")]
pub use does_not_inherit_options::*;

/// macOS implementation of `NonInheritedOptions`.
///
/// Options explicitly set on a listener are recorded so they can be copied by
/// hand onto every connection it accepts.
#[cfg(target_os = "macos")]
mod does_not_inherit_options {
    use crate::sockets::SocketAddressFamily;
    use rustix::net::sockopt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed};
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Recorded socket options that are not automatically inherited by
    /// accepted sockets on this platform.
    ///
    /// Clones share the same `Inner` through an `Arc`, so a value stored
    /// through one clone is seen by all of them. A stored value of zero means
    /// the option was never explicitly set.
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions(Arc<Inner>);

    /// Shared storage behind `NonInheritedOptions`; zero means "unset".
    #[derive(Default)]
    struct Inner {
        receive_buffer_size: AtomicUsize,
        send_buffer_size: AtomicUsize,
        hop_limit: AtomicU8,
        /// Keep-alive idle time, in nanoseconds.
        keep_alive_idle_time: AtomicU64,
    }

    impl NonInheritedOptions {
        /// Records the keep-alive idle time (nanoseconds) for accepted sockets.
        pub fn set_keep_alive_idle_time(&mut self, value: u64) {
            let Self(inner) = self;
            inner.keep_alive_idle_time.store(value, Relaxed);
        }

        /// Records the IPv6 unicast hop limit for accepted sockets.
        pub fn set_hop_limit(&mut self, value: u8) {
            let Self(inner) = self;
            inner.hop_limit.store(value, Relaxed);
        }

        /// Records the receive buffer size for accepted sockets.
        pub fn set_receive_buffer_size(&mut self, value: usize) {
            let Self(inner) = self;
            inner.receive_buffer_size.store(value, Relaxed);
        }

        /// Records the send buffer size for accepted sockets.
        pub fn set_send_buffer_size(&mut self, value: usize) {
            let Self(inner) = self;
            inner.send_buffer_size.store(value, Relaxed);
        }

        /// Copies every explicitly-set (nonzero) option onto `stream`.
        ///
        /// This emulates inheritance from the listener. Failures are ignored
        /// because the emulation is best-effort.
        pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) {
            let inner = &*self.0;

            let recv_size = inner.receive_buffer_size.load(Relaxed);
            if recv_size != 0 {
                let _ = sockopt::set_socket_recv_buffer_size(&stream, recv_size);
            }

            let send_size = inner.send_buffer_size.load(Relaxed);
            if send_size != 0 {
                let _ = sockopt::set_socket_send_buffer_size(&stream, send_size);
            }

            // IP_TTL is inherited on this platform, but IPV6_UNICAST_HOPS is
            // not, so only IPv6 sockets need the hop limit re-applied.
            if family == SocketAddressFamily::Ipv6 {
                let hops = inner.hop_limit.load(Relaxed);
                if hops != 0 {
                    let _ = sockopt::set_ipv6_unicast_hops(&stream, Some(hops));
                }
            }

            let idle_nanos = inner.keep_alive_idle_time.load(Relaxed);
            if idle_nanos != 0 {
                let _ = sockopt::set_tcp_keepidle(&stream, Duration::from_nanos(idle_nanos));
            }
        }
    }
}