// File: wasmtime_wasi/sockets/tcp.rs

1use crate::p2::P2TcpStreamingState;
2use crate::runtime::with_ambient_tokio_runtime;
3use crate::sockets::util::{
4    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
5    is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count,
6    set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size,
7    set_send_buffer_size, set_unicast_hop_limit, tcp_bind,
8};
9use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx};
10use io_lifetimes::AsSocketlike as _;
11use io_lifetimes::views::SocketlikeView;
12use rustix::io::Errno;
13use rustix::net::sockopt;
14use std::fmt::Debug;
15use std::io;
16use std::mem;
17use std::net::SocketAddr;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll, Waker};
21use std::time::Duration;
22
/// The state of a TCP socket.
///
/// This represents the various states a socket can be in during the
/// activities of binding, listening, accepting, and connecting. Note that this
/// state machine encompasses both WASIp2 and WASIp3.
///
/// All transitions are performed by the methods on `TcpSocket`; the notes on
/// each variant list the states those methods can move it to.
enum TcpState {
    /// The initial state for a newly-created socket.
    ///
    /// From here a socket can transition to `BindStarted` (via `start_bind`)
    /// or `Connecting` (via `start_connect`).
    Default(tokio::net::TcpSocket),

    /// A state indicating that a bind has been started and must be finished
    /// subsequently with `finish_bind`.
    ///
    /// From here a socket can transition to `Bound`.
    BindStarted(tokio::net::TcpSocket),

    /// Binding finished. The socket has an address but is not yet listening for
    /// connections.
    ///
    /// From here a socket can transition to `ListenStarted`, or `Connecting`.
    Bound(tokio::net::TcpSocket),

    /// Listening on a socket has started and must be completed with
    /// `finish_listen`.
    ///
    /// From here a socket can transition to `Listening`, or to `Closed` if the
    /// OS-level `listen` call fails.
    ListenStarted(tokio::net::TcpSocket),

    /// The socket is now listening and waiting for an incoming connection.
    ///
    /// Sockets will not leave this state.
    Listening {
        /// The raw tokio-based TCP listener managing the underlying socket.
        listener: Arc<tokio::net::TcpListener>,

        /// The last-accepted connection, set during the `ready` method and read
        /// during the `accept` method. Note that this is only used for WASIp2
        /// at this time.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },

    /// An outgoing connection is started.
    ///
    /// This is created via the `start_connect` method. The payload here is an
    /// optionally-specified owned future for the result of the connect. In
    /// WASIp2 the future lives here, but in WASIp3 it lives on the event loop
    /// so this is `None`.
    ///
    /// From here a socket can transition to `ConnectReady`, `Connected`, or
    /// `Closed` (when `finish_connect` is handed a failed result).
    Connecting(Option<Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>>),

    /// A connection via `Connecting` has completed.
    ///
    /// This is present for WASIp2 where the `Connecting` state stores `Some` of
    /// a future, and the result of that future is recorded here when it
    /// finishes as part of the `ready` method.
    ///
    /// From here `take_pending_connect` moves the socket back to
    /// `Connecting(None)`, after which `finish_connect` moves it to
    /// `Connected` (or `Closed` on failure).
    ConnectReady(io::Result<tokio::net::TcpStream>),

    /// A connection has been established.
    ///
    /// This is created either via `finish_connect` or for freshly accepted
    /// sockets from a TCP listener.
    ///
    /// From here a socket can transition to `Receiving` or `P2Streaming`.
    Connected(Arc<tokio::net::TcpStream>),

    /// A connection has been established and `receive` has been called.
    ///
    /// A socket will not transition out of this state.
    #[cfg(feature = "p3")]
    Receiving(Arc<tokio::net::TcpStream>),

    /// This is a WASIp2-bound socket which stores some extra state for
    /// read/write streams to handle TCP shutdown.
    ///
    /// A socket will not transition out of this state.
    P2Streaming(Box<P2TcpStreamingState>),

    /// This is not actually a socket but a deferred error.
    ///
    /// This error came out of `accept` and is deferred until the socket is
    /// operated on.
    #[cfg(feature = "p3")]
    Error(io::Error),

    /// The socket is closed and no more operations can be performed.
    Closed,
}
115
116impl Debug for TcpState {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            Self::Default(_) => f.debug_tuple("Default").finish(),
120            Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
121            Self::Bound(_) => f.debug_tuple("Bound").finish(),
122            Self::ListenStarted { .. } => f.debug_tuple("ListenStarted").finish(),
123            Self::Listening { .. } => f.debug_tuple("Listening").finish(),
124            Self::Connecting(..) => f.debug_tuple("Connecting").finish(),
125            Self::ConnectReady(..) => f.debug_tuple("ConnectReady").finish(),
126            Self::Connected { .. } => f.debug_tuple("Connected").finish(),
127            #[cfg(feature = "p3")]
128            Self::Receiving { .. } => f.debug_tuple("Receiving").finish(),
129            Self::P2Streaming(_) => f.debug_tuple("P2Streaming").finish(),
130            #[cfg(feature = "p3")]
131            Self::Error(..) => f.debug_tuple("Error").finish(),
132            Self::Closed => write!(f, "Closed"),
133        }
134    }
135}
136
/// A host TCP socket, plus associated bookkeeping.
pub struct TcpSocket {
    /// The current state in the bind/listen/accept/connect progression.
    tcp_state: TcpState,

    /// The desired listen queue size.
    ///
    /// Initialized to `DEFAULT_TCP_BACKLOG`, updated by
    /// `set_listen_backlog_size`, and handed to the OS by `finish_listen`.
    listen_backlog_size: u32,

    /// The address family (IPv4 or IPv6) the socket was created for. No
    /// method in this file changes it after construction.
    family: SocketAddressFamily,

    /// Options explicitly set through this socket's setters that have to be
    /// re-applied by hand to accepted connections on platforms where they are
    /// not inherited from the listener.
    options: NonInheritedOptions,
}
149
150impl TcpSocket {
151    /// Create a new socket in the given family.
152    pub(crate) fn new(
153        ctx: &WasiSocketsCtx,
154        family: SocketAddressFamily,
155    ) -> Result<Self, ErrorCode> {
156        ctx.allowed_network_uses.check_allowed_tcp()?;
157
158        with_ambient_tokio_runtime(|| {
159            let socket = match family {
160                SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?,
161                SocketAddressFamily::Ipv6 => {
162                    let socket = tokio::net::TcpSocket::new_v6()?;
163                    sockopt::set_ipv6_v6only(&socket, true)?;
164                    socket
165                }
166            };
167
168            Ok(Self::from_state(TcpState::Default(socket), family))
169        })
170    }
171
172    #[cfg(feature = "p3")]
173    pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self {
174        TcpSocket::from_state(TcpState::Error(err), family)
175    }
176
177    /// Creates a new socket with the `result` of an accepted socket from a
178    /// `TcpListener`.
179    ///
180    /// This will handle the `result` internally and `result` should be the raw
181    /// result from a TCP listen operation.
182    pub(crate) fn new_accept(
183        result: io::Result<tokio::net::TcpStream>,
184        options: &NonInheritedOptions,
185        family: SocketAddressFamily,
186    ) -> io::Result<Self> {
187        let client = result.map_err(|err| match Errno::from_io_error(&err) {
188            // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
189            // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
190            // > or the service provider is still processing a callback function.
191            //
192            // wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
193            // because in POSIX this error is only returned by a non-blocking
194            // `connect` and wasi-sockets has a different solution for that.
195            #[cfg(windows)]
196            Some(Errno::INPROGRESS) => Errno::INTR.into(),
197
198            // Normalize Linux' non-standard behavior.
199            //
200            // From https://man7.org/linux/man-pages/man2/accept.2.html:
201            // > Linux accept() passes already-pending network errors on the
202            // > new socket as an error code from accept(). This behavior
203            // > differs from other BSD socket implementations. (...)
204            #[cfg(target_os = "linux")]
205            Some(
206                Errno::CONNRESET
207                | Errno::NETRESET
208                | Errno::HOSTUNREACH
209                | Errno::HOSTDOWN
210                | Errno::NETDOWN
211                | Errno::NETUNREACH
212                | Errno::PROTO
213                | Errno::NOPROTOOPT
214                | Errno::NONET
215                | Errno::OPNOTSUPP,
216            ) => Errno::CONNABORTED.into(),
217
218            _ => err,
219        })?;
220        options.apply(family, &client);
221        Ok(Self::from_state(
222            TcpState::Connected(Arc::new(client)),
223            family,
224        ))
225    }
226
227    /// Create a `TcpSocket` from an existing socket.
228    fn from_state(state: TcpState, family: SocketAddressFamily) -> Self {
229        Self {
230            tcp_state: state,
231            listen_backlog_size: DEFAULT_TCP_BACKLOG,
232            family,
233            options: Default::default(),
234        }
235    }
236
237    pub(crate) fn as_std_view(&self) -> Result<SocketlikeView<'_, std::net::TcpStream>, ErrorCode> {
238        match &self.tcp_state {
239            TcpState::Default(socket)
240            | TcpState::BindStarted(socket)
241            | TcpState::Bound(socket)
242            | TcpState::ListenStarted(socket) => Ok(socket.as_socketlike_view()),
243            TcpState::Connected(stream) => Ok(stream.as_socketlike_view()),
244            #[cfg(feature = "p3")]
245            TcpState::Receiving(stream) => Ok(stream.as_socketlike_view()),
246            TcpState::Listening { listener, .. } => Ok(listener.as_socketlike_view()),
247            TcpState::P2Streaming(state) => Ok(state.stream.as_socketlike_view()),
248            TcpState::Connecting(..) | TcpState::ConnectReady(_) | TcpState::Closed => {
249                Err(ErrorCode::InvalidState)
250            }
251            #[cfg(feature = "p3")]
252            TcpState::Error(err) => Err(err.into()),
253        }
254    }
255
256    pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
257        let ip = addr.ip();
258        if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) {
259            return Err(ErrorCode::InvalidArgument);
260        }
261        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
262            TcpState::Default(sock) => {
263                if let Err(err) = tcp_bind(&sock, addr) {
264                    self.tcp_state = TcpState::Default(sock);
265                    Err(err)
266                } else {
267                    self.tcp_state = TcpState::BindStarted(sock);
268                    Ok(())
269                }
270            }
271            tcp_state => {
272                self.tcp_state = tcp_state;
273                Err(ErrorCode::InvalidState)
274            }
275        }
276    }
277
278    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
279        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
280            TcpState::BindStarted(socket) => {
281                self.tcp_state = TcpState::Bound(socket);
282                Ok(())
283            }
284            current_state => {
285                // Reset the state so that the outside world doesn't see this socket as closed
286                self.tcp_state = current_state;
287                Err(ErrorCode::NotInProgress)
288            }
289        }
290    }
291
292    pub(crate) fn start_connect(
293        &mut self,
294        addr: &SocketAddr,
295    ) -> Result<tokio::net::TcpSocket, ErrorCode> {
296        match self.tcp_state {
297            TcpState::Default(..) | TcpState::Bound(..) => {}
298            TcpState::Connecting(..) => {
299                return Err(ErrorCode::ConcurrencyConflict);
300            }
301            _ => return Err(ErrorCode::InvalidState),
302        };
303
304        if !is_valid_unicast_address(addr.ip())
305            || !is_valid_remote_address(*addr)
306            || !is_valid_address_family(addr.ip(), self.family)
307        {
308            return Err(ErrorCode::InvalidArgument);
309        };
310
311        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
312            mem::replace(&mut self.tcp_state, TcpState::Connecting(None))
313        else {
314            unreachable!();
315        };
316
317        Ok(tokio_socket)
318    }
319
320    /// For WASIp2 this is used to record the actual connection future as part
321    /// of `start_connect` within this socket state.
322    pub(crate) fn set_pending_connect(
323        &mut self,
324        future: impl Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static,
325    ) -> Result<(), ErrorCode> {
326        match &mut self.tcp_state {
327            TcpState::Connecting(slot @ None) => {
328                *slot = Some(Box::pin(future));
329                Ok(())
330            }
331            _ => Err(ErrorCode::InvalidState),
332        }
333    }
334
335    /// For WASIp2 this retreives the result from the future passed to
336    /// `set_pending_connect`.
337    ///
338    /// Return states here are:
339    ///
340    /// * `Ok(Some(res))` - where `res` is the result of the connect operation.
341    /// * `Ok(None)` - the connect operation isn't ready yet.
342    /// * `Err(e)` - a connect operation is not in progress.
343    pub(crate) fn take_pending_connect(
344        &mut self,
345    ) -> Result<Option<io::Result<tokio::net::TcpStream>>, ErrorCode> {
346        match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) {
347            TcpState::ConnectReady(result) => Ok(Some(result)),
348            TcpState::Connecting(Some(mut future)) => {
349                let mut cx = Context::from_waker(Waker::noop());
350                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
351                    Poll::Ready(result) => Ok(Some(result)),
352                    Poll::Pending => {
353                        self.tcp_state = TcpState::Connecting(Some(future));
354                        Ok(None)
355                    }
356                }
357            }
358            current_state => {
359                self.tcp_state = current_state;
360                Err(ErrorCode::NotInProgress)
361            }
362        }
363    }
364
365    pub(crate) fn finish_connect(
366        &mut self,
367        result: io::Result<tokio::net::TcpStream>,
368    ) -> Result<(), ErrorCode> {
369        if !matches!(self.tcp_state, TcpState::Connecting(None)) {
370            return Err(ErrorCode::InvalidState);
371        }
372        match result {
373            Ok(stream) => {
374                self.tcp_state = TcpState::Connected(Arc::new(stream));
375                Ok(())
376            }
377            Err(err) => {
378                self.tcp_state = TcpState::Closed;
379                Err(ErrorCode::from(err))
380            }
381        }
382    }
383
384    pub(crate) fn start_listen(&mut self) -> Result<(), ErrorCode> {
385        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
386            TcpState::Bound(tokio_socket) => {
387                self.tcp_state = TcpState::ListenStarted(tokio_socket);
388                Ok(())
389            }
390            previous_state => {
391                self.tcp_state = previous_state;
392                Err(ErrorCode::InvalidState)
393            }
394        }
395    }
396
397    pub(crate) fn finish_listen(&mut self) -> Result<(), ErrorCode> {
398        let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
399            TcpState::ListenStarted(tokio_socket) => tokio_socket,
400            previous_state => {
401                self.tcp_state = previous_state;
402                return Err(ErrorCode::NotInProgress);
403            }
404        };
405
406        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
407            Ok(listener) => {
408                self.tcp_state = TcpState::Listening {
409                    listener: Arc::new(listener),
410                    pending_accept: None,
411                };
412                Ok(())
413            }
414            Err(err) => {
415                self.tcp_state = TcpState::Closed;
416
417                Err(match Errno::from_io_error(&err) {
418                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
419                    // According to the docs, `listen` can return EMFILE on Windows.
420                    // This is odd, because we're not trying to create a new socket
421                    // or file descriptor of any kind. So we rewrite it to less
422                    // surprising error code.
423                    //
424                    // At the time of writing, this behavior has never been experimentally
425                    // observed by any of the wasmtime authors, so we're relying fully
426                    // on Microsoft's documentation here.
427                    #[cfg(windows)]
428                    Some(Errno::MFILE) => Errno::NOBUFS.into(),
429
430                    _ => err.into(),
431                })
432            }
433        }
434    }
435
436    pub(crate) fn accept(&mut self) -> Result<Option<Self>, ErrorCode> {
437        let TcpState::Listening {
438            listener,
439            pending_accept,
440        } = &mut self.tcp_state
441        else {
442            return Err(ErrorCode::InvalidState);
443        };
444
445        let result = match pending_accept.take() {
446            Some(result) => result,
447            None => {
448                let mut cx = std::task::Context::from_waker(Waker::noop());
449                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
450                    .map_ok(|(stream, _)| stream)
451                {
452                    Poll::Ready(result) => result,
453                    Poll::Pending => return Ok(None),
454                }
455            }
456        };
457
458        Ok(Some(Self::new_accept(result, &self.options, self.family)?))
459    }
460
461    #[cfg(feature = "p3")]
462    pub(crate) fn start_receive(&mut self) -> Option<&Arc<tokio::net::TcpStream>> {
463        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
464            TcpState::Connected(stream) => {
465                self.tcp_state = TcpState::Receiving(stream);
466                Some(self.tcp_stream_arc().unwrap())
467            }
468            prev => {
469                self.tcp_state = prev;
470                None
471            }
472        }
473    }
474
475    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
476        match &self.tcp_state {
477            TcpState::Bound(socket) => Ok(socket.local_addr()?),
478            TcpState::Connected(stream) => Ok(stream.local_addr()?),
479            #[cfg(feature = "p3")]
480            TcpState::Receiving(stream) => Ok(stream.local_addr()?),
481            TcpState::P2Streaming(state) => Ok(state.stream.local_addr()?),
482            TcpState::Listening { listener, .. } => Ok(listener.local_addr()?),
483            #[cfg(feature = "p3")]
484            TcpState::Error(err) => Err(err.into()),
485            _ => Err(ErrorCode::InvalidState),
486        }
487    }
488
489    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
490        let stream = self.tcp_stream_arc()?;
491        let addr = stream.peer_addr()?;
492        Ok(addr)
493    }
494
495    pub(crate) fn is_listening(&self) -> bool {
496        matches!(self.tcp_state, TcpState::Listening { .. })
497    }
498
499    pub(crate) fn address_family(&self) -> SocketAddressFamily {
500        self.family
501    }
502
503    pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> {
504        const MIN_BACKLOG: u32 = 1;
505        const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
506
507        if value == 0 {
508            return Err(ErrorCode::InvalidArgument);
509        }
510        // Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
511        let value = value
512            .try_into()
513            .unwrap_or(MAX_BACKLOG)
514            .clamp(MIN_BACKLOG, MAX_BACKLOG);
515        match &self.tcp_state {
516            TcpState::Default(..) | TcpState::Bound(..) => {
517                // Socket not listening yet. Stash value for first invocation to `listen`.
518                self.listen_backlog_size = value;
519                Ok(())
520            }
521            TcpState::Listening { listener, .. } => {
522                // Try to update the backlog by calling `listen` again.
523                // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
524                if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
525                    return Err(ErrorCode::NotSupported);
526                }
527                self.listen_backlog_size = value;
528                Ok(())
529            }
530            #[cfg(feature = "p3")]
531            TcpState::Error(err) => Err(err.into()),
532            _ => Err(ErrorCode::InvalidState),
533        }
534    }
535
536    pub(crate) fn keep_alive_enabled(&self) -> Result<bool, ErrorCode> {
537        let fd = &*self.as_std_view()?;
538        let v = sockopt::socket_keepalive(fd)?;
539        Ok(v)
540    }
541
542    pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> {
543        let fd = &*self.as_std_view()?;
544        sockopt::set_socket_keepalive(fd, value)?;
545        Ok(())
546    }
547
548    pub(crate) fn keep_alive_idle_time(&self) -> Result<u64, ErrorCode> {
549        let fd = &*self.as_std_view()?;
550        let v = sockopt::tcp_keepidle(fd)?;
551        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
552    }
553
554    pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> {
555        let value = {
556            let fd = self.as_std_view()?;
557            set_keep_alive_idle_time(&*fd, value)?
558        };
559        self.options.set_keep_alive_idle_time(value);
560        Ok(())
561    }
562
563    pub(crate) fn keep_alive_interval(&self) -> Result<u64, ErrorCode> {
564        let fd = &*self.as_std_view()?;
565        let v = sockopt::tcp_keepintvl(fd)?;
566        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
567    }
568
569    pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> {
570        let fd = &*self.as_std_view()?;
571        set_keep_alive_interval(fd, Duration::from_nanos(value))?;
572        Ok(())
573    }
574
575    pub(crate) fn keep_alive_count(&self) -> Result<u32, ErrorCode> {
576        let fd = &*self.as_std_view()?;
577        let v = sockopt::tcp_keepcnt(fd)?;
578        Ok(v)
579    }
580
581    pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> {
582        let fd = &*self.as_std_view()?;
583        set_keep_alive_count(fd, value)?;
584        Ok(())
585    }
586
587    pub(crate) fn hop_limit(&self) -> Result<u8, ErrorCode> {
588        let fd = &*self.as_std_view()?;
589        let n = get_unicast_hop_limit(fd, self.family)?;
590        Ok(n)
591    }
592
593    pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> {
594        {
595            let fd = &*self.as_std_view()?;
596            set_unicast_hop_limit(fd, self.family, value)?;
597        }
598        self.options.set_hop_limit(value);
599        Ok(())
600    }
601
602    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
603        let fd = &*self.as_std_view()?;
604        let n = receive_buffer_size(fd)?;
605        Ok(n)
606    }
607
608    pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
609        let res = {
610            let fd = &*self.as_std_view()?;
611            set_receive_buffer_size(fd, value)?
612        };
613        self.options.set_receive_buffer_size(res);
614        Ok(())
615    }
616
617    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
618        let fd = &*self.as_std_view()?;
619        let n = send_buffer_size(fd)?;
620        Ok(n)
621    }
622
623    pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
624        let res = {
625            let fd = &*self.as_std_view()?;
626            set_send_buffer_size(fd, value)?
627        };
628        self.options.set_send_buffer_size(res);
629        Ok(())
630    }
631
632    #[cfg(feature = "p3")]
633    pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions {
634        &self.options
635    }
636
637    #[cfg(feature = "p3")]
638    pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc<tokio::net::TcpListener>, ErrorCode> {
639        match &self.tcp_state {
640            TcpState::Listening { listener, .. } => Ok(listener),
641            #[cfg(feature = "p3")]
642            TcpState::Error(err) => Err(err.into()),
643            _ => Err(ErrorCode::InvalidState),
644        }
645    }
646
647    pub(crate) fn tcp_stream_arc(&self) -> Result<&Arc<tokio::net::TcpStream>, ErrorCode> {
648        match &self.tcp_state {
649            TcpState::Connected(socket) => Ok(socket),
650            #[cfg(feature = "p3")]
651            TcpState::Receiving(socket) => Ok(socket),
652            TcpState::P2Streaming(state) => Ok(&state.stream),
653            #[cfg(feature = "p3")]
654            TcpState::Error(err) => Err(err.into()),
655            _ => Err(ErrorCode::InvalidState),
656        }
657    }
658
659    pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> {
660        match &self.tcp_state {
661            TcpState::P2Streaming(state) => Ok(state),
662            #[cfg(feature = "p3")]
663            TcpState::Error(err) => Err(err.into()),
664            _ => Err(ErrorCode::InvalidState),
665        }
666    }
667
668    pub(crate) fn set_p2_streaming_state(
669        &mut self,
670        state: P2TcpStreamingState,
671    ) -> Result<(), ErrorCode> {
672        if !matches!(self.tcp_state, TcpState::Connected(_)) {
673            return Err(ErrorCode::InvalidState);
674        }
675        self.tcp_state = TcpState::P2Streaming(Box::new(state));
676        Ok(())
677    }
678
679    /// Used for `Pollable` in the WASIp2 implementation this awaits the socket
680    /// to be connected, if in the connecting state, or for a TCP accept to be
681    /// ready, if this is in the listening state.
682    ///
683    /// For all other states this method immediately returns.
684    pub(crate) async fn ready(&mut self) {
685        match &mut self.tcp_state {
686            TcpState::Default(..)
687            | TcpState::BindStarted(..)
688            | TcpState::Bound(..)
689            | TcpState::ListenStarted(..)
690            | TcpState::ConnectReady(..)
691            | TcpState::Closed
692            | TcpState::Connected { .. }
693            | TcpState::Connecting(None)
694            | TcpState::Listening {
695                pending_accept: Some(_),
696                ..
697            }
698            | TcpState::P2Streaming(_) => {}
699
700            #[cfg(feature = "p3")]
701            TcpState::Receiving(_) | TcpState::Error(_) => {}
702
703            TcpState::Connecting(Some(future)) => {
704                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
705            }
706
707            TcpState::Listening {
708                listener,
709                pending_accept: slot @ None,
710            } => {
711                let result = futures::future::poll_fn(|cx| {
712                    listener.poll_accept(cx).map_ok(|(stream, _)| stream)
713                })
714                .await;
715                *slot = Some(result);
716            }
717        }
718    }
719}
720
721#[cfg(not(target_os = "macos"))]
722pub use inherits_option::*;
723#[cfg(not(target_os = "macos"))]
724mod inherits_option {
725    use crate::sockets::SocketAddressFamily;
726    use tokio::net::TcpStream;
727
728    #[derive(Default, Clone)]
729    pub struct NonInheritedOptions;
730
731    impl NonInheritedOptions {
732        pub fn set_keep_alive_idle_time(&mut self, _value: u64) {}
733
734        pub fn set_hop_limit(&mut self, _value: u8) {}
735
736        pub fn set_receive_buffer_size(&mut self, _value: usize) {}
737
738        pub fn set_send_buffer_size(&mut self, _value: usize) {}
739
740        pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {}
741    }
742}
743
#[cfg(target_os = "macos")]
pub use does_not_inherit_options::*;
#[cfg(target_os = "macos")]
mod does_not_inherit_options {
    use crate::sockets::SocketAddressFamily;
    use rustix::net::sockopt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed};
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Listener socket options that accepted sockets do not inherit on this
    /// platform.
    ///
    /// Each explicitly set value is recorded here so it can be copied onto
    /// connections accepted later. A stored value of zero means "never set".
    /// Clones share the same storage (an `Arc`), so an update made through one
    /// handle is observed by every clone.
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions(Arc<Inner>);

    /// Shared storage behind `NonInheritedOptions`; zero means "not set".
    #[derive(Default)]
    struct Inner {
        receive_buffer_size: AtomicUsize,
        send_buffer_size: AtomicUsize,
        hop_limit: AtomicU8,
        /// Keep-alive idle time, in nanoseconds.
        keep_alive_idle_time: AtomicU64,
    }

    impl NonInheritedOptions {
        /// Record the keep-alive idle time, in nanoseconds.
        pub fn set_keep_alive_idle_time(&mut self, nanos: u64) {
            self.0.keep_alive_idle_time.store(nanos, Relaxed);
        }

        /// Record the unicast hop limit.
        pub fn set_hop_limit(&mut self, hops: u8) {
            self.0.hop_limit.store(hops, Relaxed);
        }

        /// Record the receive buffer size.
        pub fn set_receive_buffer_size(&mut self, bytes: usize) {
            self.0.receive_buffer_size.store(bytes, Relaxed);
        }

        /// Record the send buffer size.
        pub fn set_send_buffer_size(&mut self, bytes: usize) {
            self.0.send_buffer_size.store(bytes, Relaxed);
        }

        /// Copy every recorded option onto a freshly accepted `stream`.
        ///
        /// Options still at zero (never set) are skipped. OS errors are
        /// ignored: this is a best-effort emulation of inheritance.
        pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) {
            let inner = &*self.0;

            match inner.receive_buffer_size.load(Relaxed) {
                0 => {}
                bytes => {
                    let _ = sockopt::set_socket_recv_buffer_size(stream, bytes);
                }
            }

            match inner.send_buffer_size.load(Relaxed) {
                0 => {}
                bytes => {
                    let _ = sockopt::set_socket_send_buffer_size(stream, bytes);
                }
            }

            // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't,
            // so only IPv6 sockets need the hop limit copied over.
            if family == SocketAddressFamily::Ipv6 {
                match inner.hop_limit.load(Relaxed) {
                    0 => {}
                    hops => {
                        let _ = sockopt::set_ipv6_unicast_hops(stream, Some(hops));
                    }
                }
            }

            match inner.keep_alive_idle_time.load(Relaxed) {
                0 => {}
                nanos => {
                    let _ = sockopt::set_tcp_keepidle(stream, Duration::from_nanos(nanos));
                }
            }
        }
    }
}