Skip to main content

wasmtime_wasi/sockets/
tcp.rs

1use crate::p2::P2TcpStreamingState;
2use crate::runtime::with_ambient_tokio_runtime;
3use crate::sockets::util::{
4    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
5    is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count,
6    set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size,
7    set_send_buffer_size, set_unicast_hop_limit, tcp_bind,
8};
9use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx};
10use io_lifetimes::AsSocketlike as _;
11use io_lifetimes::views::SocketlikeView;
12use rustix::io::Errno;
13use rustix::net::sockopt;
14use std::fmt::Debug;
15use std::io;
16use std::mem;
17use std::net::SocketAddr;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll, Waker};
21use std::time::Duration;
22
/// The state of a TCP socket.
///
/// This represents the various states a socket can be in during the
/// activities of binding, listening, accepting, and connecting. Note that this
/// state machine encompasses both WASIp2 and WASIp3.
///
/// Several methods on `TcpSocket` temporarily swap `Closed` into the state
/// slot with `mem::replace` while they inspect the previous state, and put
/// the previous state back when the operation is not applicable.
enum TcpState {
    /// The initial state for a newly-created socket.
    ///
    /// From here a socket can transition to `BindStarted` (`start_bind`) or
    /// `Connecting` (`start_connect`). With WASIp3's `listen_p3` the socket is
    /// implicitly bound and moves straight on to `Listening`.
    Default(tokio::net::TcpSocket),

    /// A state indicating that a bind has been started and must be finished
    /// subsequently with `finish_bind`.
    ///
    /// From here a socket can transition to `Bound`.
    BindStarted(tokio::net::TcpSocket),

    /// Binding finished. The socket has an address but is not yet listening for
    /// connections.
    ///
    /// From here a socket can transition to `ListenStarted` (WASIp2),
    /// `Listening` (WASIp3), or `Connecting`.
    Bound(tokio::net::TcpSocket),

    /// Listening on a socket has started and must be completed with
    /// `finish_listen_p2`.
    ///
    /// From here a socket can transition to `Listening`, or to `Closed` if the
    /// underlying listen operation fails.
    ListenStarted(tokio::net::TcpSocket),

    /// The socket is now listening and waiting for an incoming connection.
    ///
    /// Sockets will not leave this state.
    Listening {
        /// The raw tokio-based TCP listener managing the underlying socket.
        listener: Arc<tokio::net::TcpListener>,

        /// The last-accepted connection, set during the `ready` method and read
        /// during the `accept` method. Note that this is only used for WASIp2
        /// at this time.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },

    /// An outgoing connection is started.
    ///
    /// This is created via the `start_connect` method. The payload here is an
    /// optionally-specified owned future for the result of the connect. In
    /// WASIp2 the future lives here, but in WASIp3 it lives on the event loop
    /// so this is `None`.
    ///
    /// From here a socket can transition to `ConnectReady` (via `ready`),
    /// `Connected`, or `Closed` when `finish_connect` is given an error.
    Connecting(Option<Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>>),

    /// A connection via `Connecting` has completed.
    ///
    /// This is present for WASIp2 where the `Connecting` state stores `Some` of
    /// a future, and the result of that future is recorded here when it
    /// finishes as part of the `ready` method.
    ///
    /// `take_pending_connect` moves the result out and leaves
    /// `Connecting(None)` behind, from which `finish_connect` transitions to
    /// `Connected` or `Closed`.
    ConnectReady(io::Result<tokio::net::TcpStream>),

    /// A connection has been established.
    ///
    /// This is created either via `finish_connect` or for freshly accepted
    /// sockets from a TCP listener.
    ///
    /// A socket will not transition out of this state.
    Connected {
        /// Shared handle to the connected stream; clones are handed out by
        /// `take_receive_stream` and `take_send_stream`.
        stream: Arc<tokio::net::TcpStream>,
        /// Records which directions have already been handed out so each can
        /// only be taken once.
        taken_streams: TakenStreams,
        /// WASIp2 streaming state, `None` until `set_p2_streaming_state` is
        /// called.
        p2_state: Option<P2TcpStreamingState>,
    },

    /// This is not actually a socket but a deferred error.
    ///
    /// This error came out of `accept` and is deferred until the socket is
    /// operated on.
    #[cfg(feature = "p3")]
    Error(io::Error),

    /// The socket is closed and no more operations can be performed.
    ///
    /// Entered when `finish_connect` receives an error or when the listen
    /// operation in `listen_common` fails.
    Closed,
}
107impl TcpState {
108    fn connected(stream: tokio::net::TcpStream) -> Self {
109        TcpState::Connected {
110            stream: Arc::new(stream),
111            taken_streams: TakenStreams {
112                receive: false,
113                send: false,
114            },
115            p2_state: None,
116        }
117    }
118}
119impl Debug for TcpState {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            Self::Default(_) => f.debug_tuple("Default").finish(),
123            Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
124            Self::Bound(_) => f.debug_tuple("Bound").finish(),
125            Self::ListenStarted { .. } => f.debug_tuple("ListenStarted").finish(),
126            Self::Listening { .. } => f.debug_tuple("Listening").finish(),
127            Self::Connecting(..) => f.debug_tuple("Connecting").finish(),
128            Self::ConnectReady(..) => f.debug_tuple("ConnectReady").finish(),
129            Self::Connected { .. } => f.debug_tuple("Connected").finish(),
130            #[cfg(feature = "p3")]
131            Self::Error(..) => f.debug_tuple("Error").finish(),
132            Self::Closed => write!(f, "Closed"),
133        }
134    }
135}
136
/// Tracks which directions of a connected socket's stream have already been
/// handed out through `take_receive_stream` / `take_send_stream`.
struct TakenStreams {
    /// `true` once the receive direction has been taken.
    receive: bool,
    /// `true` once the send direction has been taken.
    send: bool,
}
141
/// A host TCP socket, plus associated bookkeeping.
pub struct TcpSocket {
    /// The current state in the bind/listen/accept/connect progression.
    tcp_state: TcpState,

    /// The desired listen queue size.
    listen_backlog_size: u32,

    /// The address family this socket was created with; addresses passed to
    /// bind/connect are validated against it.
    family: SocketAddressFamily,

    /// Socket options recorded by the setters on this type and applied to
    /// sockets returned from `accept` (see `NonInheritedOptions::apply`).
    options: NonInheritedOptions,
}
154
155impl TcpSocket {
156    /// Create a new socket in the given family.
157    pub(crate) fn new(
158        ctx: &WasiSocketsCtx,
159        family: SocketAddressFamily,
160    ) -> Result<Self, ErrorCode> {
161        ctx.allowed_network_uses.check_allowed_tcp()?;
162
163        with_ambient_tokio_runtime(|| {
164            let socket = match family {
165                SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?,
166                SocketAddressFamily::Ipv6 => {
167                    let socket = tokio::net::TcpSocket::new_v6()?;
168                    sockopt::set_ipv6_v6only(&socket, true)?;
169                    socket
170                }
171            };
172
173            Ok(Self::from_state(TcpState::Default(socket), family))
174        })
175    }
176
177    #[cfg(feature = "p3")]
178    pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self {
179        TcpSocket::from_state(TcpState::Error(err), family)
180    }
181
182    /// Creates a new socket with the `result` of an accepted socket from a
183    /// `TcpListener`.
184    ///
185    /// This will handle the `result` internally and `result` should be the raw
186    /// result from a TCP listen operation.
187    pub(crate) fn new_accept(
188        result: io::Result<tokio::net::TcpStream>,
189        options: &NonInheritedOptions,
190        family: SocketAddressFamily,
191    ) -> io::Result<Self> {
192        let client = result.map_err(|err| match Errno::from_io_error(&err) {
193            // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
194            // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
195            // > or the service provider is still processing a callback function.
196            //
197            // wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
198            // because in POSIX this error is only returned by a non-blocking
199            // `connect` and wasi-sockets has a different solution for that.
200            #[cfg(windows)]
201            Some(Errno::INPROGRESS) => Errno::INTR.into(),
202
203            // Normalize Linux' non-standard behavior.
204            //
205            // From https://man7.org/linux/man-pages/man2/accept.2.html:
206            // > Linux accept() passes already-pending network errors on the
207            // > new socket as an error code from accept(). This behavior
208            // > differs from other BSD socket implementations. (...)
209            #[cfg(target_os = "linux")]
210            Some(
211                Errno::CONNRESET
212                | Errno::NETRESET
213                | Errno::HOSTUNREACH
214                | Errno::HOSTDOWN
215                | Errno::NETDOWN
216                | Errno::NETUNREACH
217                | Errno::PROTO
218                | Errno::NOPROTOOPT
219                | Errno::NONET
220                | Errno::OPNOTSUPP,
221            ) => Errno::CONNABORTED.into(),
222
223            _ => err,
224        })?;
225        options.apply(family, &client);
226        Ok(Self::from_state(TcpState::connected(client), family))
227    }
228
229    /// Create a `TcpSocket` from an existing socket.
230    fn from_state(state: TcpState, family: SocketAddressFamily) -> Self {
231        Self {
232            tcp_state: state,
233            listen_backlog_size: DEFAULT_TCP_BACKLOG,
234            family,
235            options: Default::default(),
236        }
237    }
238
239    pub(crate) fn as_std_view(&self) -> Result<SocketlikeView<'_, std::net::TcpStream>, ErrorCode> {
240        match &self.tcp_state {
241            TcpState::Default(socket)
242            | TcpState::BindStarted(socket)
243            | TcpState::Bound(socket)
244            | TcpState::ListenStarted(socket) => Ok(socket.as_socketlike_view()),
245            TcpState::Connected { stream, .. } => Ok(stream.as_socketlike_view()),
246            TcpState::Listening { listener, .. } => Ok(listener.as_socketlike_view()),
247            TcpState::Connecting(..) | TcpState::ConnectReady(_) | TcpState::Closed => {
248                Err(ErrorCode::InvalidState)
249            }
250            #[cfg(feature = "p3")]
251            TcpState::Error(err) => Err(err.into()),
252        }
253    }
254
255    pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
256        let ip = addr.ip();
257        if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) {
258            return Err(ErrorCode::InvalidArgument);
259        }
260        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
261            TcpState::Default(sock) => {
262                if let Err(err) = tcp_bind(&sock, addr) {
263                    self.tcp_state = TcpState::Default(sock);
264                    Err(err)
265                } else {
266                    self.tcp_state = TcpState::BindStarted(sock);
267                    Ok(())
268                }
269            }
270            tcp_state => {
271                self.tcp_state = tcp_state;
272                Err(ErrorCode::InvalidState)
273            }
274        }
275    }
276
277    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
278        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
279            TcpState::BindStarted(socket) => {
280                self.tcp_state = TcpState::Bound(socket);
281                Ok(())
282            }
283            current_state => {
284                // Reset the state so that the outside world doesn't see this socket as closed
285                self.tcp_state = current_state;
286                Err(ErrorCode::NotInProgress)
287            }
288        }
289    }
290
291    pub(crate) fn start_connect(
292        &mut self,
293        addr: &SocketAddr,
294    ) -> Result<tokio::net::TcpSocket, ErrorCode> {
295        match self.tcp_state {
296            TcpState::Default(..) | TcpState::Bound(..) => {}
297            TcpState::Connecting(..) => {
298                return Err(ErrorCode::ConcurrencyConflict);
299            }
300            _ => return Err(ErrorCode::InvalidState),
301        };
302
303        if !is_valid_unicast_address(addr.ip())
304            || !is_valid_remote_address(*addr)
305            || !is_valid_address_family(addr.ip(), self.family)
306        {
307            return Err(ErrorCode::InvalidArgument);
308        };
309
310        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
311            mem::replace(&mut self.tcp_state, TcpState::Connecting(None))
312        else {
313            unreachable!();
314        };
315
316        Ok(tokio_socket)
317    }
318
319    /// For WASIp2 this is used to record the actual connection future as part
320    /// of `start_connect` within this socket state.
321    pub(crate) fn set_pending_connect(
322        &mut self,
323        future: impl Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static,
324    ) -> Result<(), ErrorCode> {
325        match &mut self.tcp_state {
326            TcpState::Connecting(slot @ None) => {
327                *slot = Some(Box::pin(future));
328                Ok(())
329            }
330            _ => Err(ErrorCode::InvalidState),
331        }
332    }
333
334    /// For WASIp2 this retrieves the result from the future passed to
335    /// `set_pending_connect`.
336    ///
337    /// Return states here are:
338    ///
339    /// * `Ok(Some(res))` - where `res` is the result of the connect operation.
340    /// * `Ok(None)` - the connect operation isn't ready yet.
341    /// * `Err(e)` - a connect operation is not in progress.
342    pub(crate) fn take_pending_connect(
343        &mut self,
344    ) -> Result<Option<io::Result<tokio::net::TcpStream>>, ErrorCode> {
345        match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) {
346            TcpState::ConnectReady(result) => Ok(Some(result)),
347            TcpState::Connecting(Some(mut future)) => {
348                let mut cx = Context::from_waker(Waker::noop());
349                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
350                    Poll::Ready(result) => Ok(Some(result)),
351                    Poll::Pending => {
352                        self.tcp_state = TcpState::Connecting(Some(future));
353                        Ok(None)
354                    }
355                }
356            }
357            current_state => {
358                self.tcp_state = current_state;
359                Err(ErrorCode::NotInProgress)
360            }
361        }
362    }
363
364    pub(crate) fn finish_connect(
365        &mut self,
366        result: io::Result<tokio::net::TcpStream>,
367    ) -> Result<(), ErrorCode> {
368        if !matches!(self.tcp_state, TcpState::Connecting(None)) {
369            return Err(ErrorCode::InvalidState);
370        }
371        match result {
372            Ok(stream) => {
373                self.tcp_state = TcpState::connected(stream);
374                Ok(())
375            }
376            Err(err) => {
377                self.tcp_state = TcpState::Closed;
378                Err(ErrorCode::from(err))
379            }
380        }
381    }
382
383    /// Start listening using p2 semantics. (no implicit bind)
384    pub(crate) fn start_listen_p2(&mut self) -> Result<(), ErrorCode> {
385        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
386            TcpState::Bound(tokio_socket) => {
387                self.tcp_state = TcpState::ListenStarted(tokio_socket);
388                Ok(())
389            }
390            previous_state => {
391                self.tcp_state = previous_state;
392                Err(ErrorCode::InvalidState)
393            }
394        }
395    }
396
397    pub(crate) fn finish_listen_p2(&mut self) -> Result<(), ErrorCode> {
398        let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
399            TcpState::ListenStarted(tokio_socket) => tokio_socket,
400            previous_state => {
401                self.tcp_state = previous_state;
402                return Err(ErrorCode::NotInProgress);
403            }
404        };
405
406        self.listen_common(tokio_socket)
407    }
408
409    /// Start listening using p3 semantics. (with implicit bind)
410    #[cfg(feature = "p3")]
411    pub(crate) fn listen_p3(&mut self) -> Result<(), ErrorCode> {
412        let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
413            TcpState::Bound(tokio_socket) => tokio_socket,
414            TcpState::Default(tokio_socket) => {
415                // Some platforms automatically perform an implicit bind as part
416                // of the `listen` syscall. However this is not ubiquitous
417                // behavior:
418                // - Linux mentions it in their docs [0] that they perform an
419                //   implicit bind. This behavior has been experimentally verified.
420                // - Windows requires a `bind` before `listen`. This is both
421                //   documented [1] and experimentally verified.
422                // - Other platforms (e.g. macOS, FreeBSD) do not explicitly
423                //   document it either way and instead leave it up to the
424                //   individual protocol to decide [2]. However, experiments
425                //   show that MacOS in fact _does_ perform an implicit bind.
426                //
427                // To ensure consistent behavior across all platforms, we
428                // perform the implicit bind ourselves here.
429                //
430                // [0]: https://man7.org/linux/man-pages/man7/ip.7.html
431                // > An ephemeral port is allocated to a socket in the following
432                // > circumstances: (...) listen(2) is called on a stream socket
433                // > that was not previously bound;
434                //
435                // [1]: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen
436                // > WSAEINVAL: The socket has not been bound with bind.
437                //
438                // [2]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/listen.html
439                // > EDESTADDRREQ: The socket is not bound to a local address,
440                // > and the protocol does not support listening on an unbound
441                // > socket.
442                let implicit_addr = crate::sockets::util::implicit_bind_addr(self.family);
443                tcp_bind(&tokio_socket, implicit_addr)?;
444                tokio_socket
445            }
446            previous_state => {
447                self.tcp_state = previous_state;
448                return Err(ErrorCode::InvalidState);
449            }
450        };
451
452        self.listen_common(tokio_socket)
453    }
454
455    fn listen_common(&mut self, tokio_socket: tokio::net::TcpSocket) -> Result<(), ErrorCode> {
456        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
457            Ok(listener) => {
458                self.tcp_state = TcpState::Listening {
459                    listener: Arc::new(listener),
460                    pending_accept: None,
461                };
462                Ok(())
463            }
464            Err(err) => {
465                self.tcp_state = TcpState::Closed;
466
467                Err(match Errno::from_io_error(&err) {
468                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
469                    // According to the docs, `listen` can return EMFILE on Windows.
470                    // This is odd, because we're not trying to create a new socket
471                    // or file descriptor of any kind. So we rewrite it to less
472                    // surprising error code.
473                    //
474                    // At the time of writing, this behavior has never been experimentally
475                    // observed by any of the wasmtime authors, so we're relying fully
476                    // on Microsoft's documentation here.
477                    #[cfg(windows)]
478                    Some(Errno::MFILE) => Errno::NOBUFS.into(),
479
480                    _ => err.into(),
481                })
482            }
483        }
484    }
485
486    pub(crate) fn accept(&mut self) -> Result<Option<Self>, ErrorCode> {
487        let TcpState::Listening {
488            listener,
489            pending_accept,
490        } = &mut self.tcp_state
491        else {
492            return Err(ErrorCode::InvalidState);
493        };
494
495        let result = match pending_accept.take() {
496            Some(result) => result,
497            None => {
498                let mut cx = std::task::Context::from_waker(Waker::noop());
499                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
500                    .map_ok(|(stream, _)| stream)
501                {
502                    Poll::Ready(result) => result,
503                    Poll::Pending => return Ok(None),
504                }
505            }
506        };
507
508        Ok(Some(Self::new_accept(result, &self.options, self.family)?))
509    }
510
511    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
512        match &self.tcp_state {
513            TcpState::Bound(socket) => Ok(socket.local_addr()?),
514            TcpState::Connected { stream, .. } => Ok(stream.local_addr()?),
515            TcpState::Listening { listener, .. } => Ok(listener.local_addr()?),
516            #[cfg(feature = "p3")]
517            TcpState::Error(err) => Err(err.into()),
518            _ => Err(ErrorCode::InvalidState),
519        }
520    }
521
522    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
523        match &self.tcp_state {
524            TcpState::Connected { stream, .. } => Ok(stream.peer_addr()?),
525            #[cfg(feature = "p3")]
526            TcpState::Error(err) => Err(err.into()),
527            _ => Err(ErrorCode::InvalidState),
528        }
529    }
530
    /// Returns `true` if the socket is currently in the `Listening` state.
    pub(crate) fn is_listening(&self) -> bool {
        matches!(self.tcp_state, TcpState::Listening { .. })
    }
534
    /// Returns the address family this socket was created with.
    pub(crate) fn address_family(&self) -> SocketAddressFamily {
        self.family
    }
538
539    pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> {
540        const MIN_BACKLOG: u32 = 1;
541        const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
542
543        if value == 0 {
544            return Err(ErrorCode::InvalidArgument);
545        }
546        // Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
547        let value = value
548            .try_into()
549            .unwrap_or(MAX_BACKLOG)
550            .clamp(MIN_BACKLOG, MAX_BACKLOG);
551        match &self.tcp_state {
552            TcpState::Default(..) | TcpState::Bound(..) => {
553                // Socket not listening yet. Stash value for first invocation to `listen`.
554                self.listen_backlog_size = value;
555                Ok(())
556            }
557            TcpState::Listening { listener, .. } => {
558                // Try to update the backlog by calling `listen` again.
559                // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
560                if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
561                    return Err(ErrorCode::NotSupported);
562                }
563                self.listen_backlog_size = value;
564                Ok(())
565            }
566            #[cfg(feature = "p3")]
567            TcpState::Error(err) => Err(err.into()),
568            _ => Err(ErrorCode::InvalidState),
569        }
570    }
571
572    pub(crate) fn keep_alive_enabled(&self) -> Result<bool, ErrorCode> {
573        let fd = &*self.as_std_view()?;
574        let v = sockopt::socket_keepalive(fd)?;
575        Ok(v)
576    }
577
578    pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> {
579        let fd = &*self.as_std_view()?;
580        sockopt::set_socket_keepalive(fd, value)?;
581        Ok(())
582    }
583
584    pub(crate) fn keep_alive_idle_time(&self) -> Result<u64, ErrorCode> {
585        let fd = &*self.as_std_view()?;
586        let v = sockopt::tcp_keepidle(fd)?;
587        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
588    }
589
590    pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> {
591        let value = {
592            let fd = self.as_std_view()?;
593            set_keep_alive_idle_time(&*fd, value)?
594        };
595        self.options.set_keep_alive_idle_time(value);
596        Ok(())
597    }
598
599    pub(crate) fn keep_alive_interval(&self) -> Result<u64, ErrorCode> {
600        let fd = &*self.as_std_view()?;
601        let v = sockopt::tcp_keepintvl(fd)?;
602        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
603    }
604
605    pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> {
606        let fd = &*self.as_std_view()?;
607        set_keep_alive_interval(fd, Duration::from_nanos(value))?;
608        Ok(())
609    }
610
611    pub(crate) fn keep_alive_count(&self) -> Result<u32, ErrorCode> {
612        let fd = &*self.as_std_view()?;
613        let v = sockopt::tcp_keepcnt(fd)?;
614        Ok(v)
615    }
616
617    pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> {
618        let fd = &*self.as_std_view()?;
619        set_keep_alive_count(fd, value)?;
620        Ok(())
621    }
622
623    pub(crate) fn hop_limit(&self) -> Result<u8, ErrorCode> {
624        let fd = &*self.as_std_view()?;
625        let n = get_unicast_hop_limit(fd, self.family)?;
626        Ok(n)
627    }
628
629    pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> {
630        {
631            let fd = &*self.as_std_view()?;
632            set_unicast_hop_limit(fd, self.family, value)?;
633        }
634        self.options.set_hop_limit(value);
635        Ok(())
636    }
637
638    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
639        let fd = &*self.as_std_view()?;
640        let n = receive_buffer_size(fd)?;
641        Ok(n)
642    }
643
644    pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
645        let res = {
646            let fd = &*self.as_std_view()?;
647            set_receive_buffer_size(fd, value)?
648        };
649        self.options.set_receive_buffer_size(res);
650        Ok(())
651    }
652
653    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
654        let fd = &*self.as_std_view()?;
655        let n = send_buffer_size(fd)?;
656        Ok(n)
657    }
658
659    pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
660        let res = {
661            let fd = &*self.as_std_view()?;
662            set_send_buffer_size(fd, value)?
663        };
664        self.options.set_send_buffer_size(res);
665        Ok(())
666    }
667
    /// Returns the options recorded on this socket that are applied to
    /// accepted sockets via `NonInheritedOptions::apply`.
    #[cfg(feature = "p3")]
    pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions {
        &self.options
    }
672
673    #[cfg(feature = "p3")]
674    pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc<tokio::net::TcpListener>, ErrorCode> {
675        match &self.tcp_state {
676            TcpState::Listening { listener, .. } => Ok(listener),
677            #[cfg(feature = "p3")]
678            TcpState::Error(err) => Err(err.into()),
679            _ => Err(ErrorCode::InvalidState),
680        }
681    }
682
683    pub(crate) fn take_receive_stream(&mut self) -> Result<Arc<tokio::net::TcpStream>, ErrorCode> {
684        self.take_stream(|s| &mut s.receive)
685    }
686
687    pub(crate) fn take_send_stream(&mut self) -> Result<Arc<tokio::net::TcpStream>, ErrorCode> {
688        self.take_stream(|s| &mut s.send)
689    }
690
691    fn take_stream(
692        &mut self,
693        direction: impl FnOnce(&mut TakenStreams) -> &mut bool,
694    ) -> Result<Arc<tokio::net::TcpStream>, ErrorCode> {
695        match &mut self.tcp_state {
696            TcpState::Connected {
697                stream,
698                taken_streams,
699                ..
700            } => {
701                let taken = direction(taken_streams);
702                if *taken {
703                    return Err(ErrorCode::InvalidState);
704                }
705                *taken = true;
706                Ok(stream.clone())
707            }
708            #[cfg(feature = "p3")]
709            TcpState::Error(err) => Err((&*err).into()),
710            _ => Err(ErrorCode::InvalidState),
711        }
712    }
713
714    pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> {
715        match &self.tcp_state {
716            TcpState::Connected {
717                p2_state: Some(state),
718                ..
719            } => Ok(state),
720            #[cfg(feature = "p3")]
721            TcpState::Error(err) => Err(err.into()),
722            _ => Err(ErrorCode::InvalidState),
723        }
724    }
725
726    pub(crate) fn set_p2_streaming_state(
727        &mut self,
728        state: P2TcpStreamingState,
729    ) -> Result<(), ErrorCode> {
730        if let TcpState::Connected { p2_state, .. } = &mut self.tcp_state {
731            *p2_state = Some(state);
732            Ok(())
733        } else {
734            Err(ErrorCode::InvalidState)
735        }
736    }
737
738    /// Used for `Pollable` in the WASIp2 implementation this awaits the socket
739    /// to be connected, if in the connecting state, or for a TCP accept to be
740    /// ready, if this is in the listening state.
741    ///
742    /// For all other states this method immediately returns.
743    pub(crate) async fn ready(&mut self) {
744        match &mut self.tcp_state {
745            TcpState::Default(..)
746            | TcpState::BindStarted(..)
747            | TcpState::Bound(..)
748            | TcpState::ListenStarted(..)
749            | TcpState::ConnectReady(..)
750            | TcpState::Closed
751            | TcpState::Connected { .. }
752            | TcpState::Connecting(None)
753            | TcpState::Listening {
754                pending_accept: Some(_),
755                ..
756            } => {}
757
758            #[cfg(feature = "p3")]
759            TcpState::Error(_) => {}
760
761            TcpState::Connecting(Some(future)) => {
762                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
763            }
764
765            TcpState::Listening {
766                listener,
767                pending_accept: slot @ None,
768            } => {
769                let result = futures::future::poll_fn(|cx| {
770                    listener.poll_accept(cx).map_ok(|(stream, _)| stream)
771                })
772                .await;
773                *slot = Some(result);
774            }
775        }
776    }
777}
778
779#[cfg(not(target_os = "macos"))]
780pub use inherits_option::*;
#[cfg(not(target_os = "macos"))]
mod inherits_option {
    use crate::sockets::SocketAddressFamily;
    use tokio::net::TcpStream;

    /// Zero-sized stand-in used on every target other than macOS: every
    /// method below is a no-op, so nothing is recorded and nothing is applied
    /// to accepted sockets.
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions;

    impl NonInheritedOptions {
        /// No-op on this platform.
        pub fn set_keep_alive_idle_time(&mut self, _value: u64) {}

        /// No-op on this platform.
        pub fn set_hop_limit(&mut self, _value: u8) {}

        /// No-op on this platform.
        pub fn set_receive_buffer_size(&mut self, _value: usize) {}

        /// No-op on this platform.
        pub fn set_send_buffer_size(&mut self, _value: usize) {}

        /// No-op on this platform.
        pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {}
    }
}
801
802#[cfg(target_os = "macos")]
803pub use does_not_inherit_options::*;
#[cfg(target_os = "macos")]
mod does_not_inherit_options {
    use crate::sockets::SocketAddressFamily;
    use rustix::net::sockopt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed};
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Socket options that accepted clients do not automatically inherit from
    /// their listener on this platform.
    ///
    /// Explicitly set values are remembered here and re-applied by hand to
    /// every newly accepted client. A stored value of zero means "never set".
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions(Arc<Inner>);

    /// Shared storage behind `NonInheritedOptions`.
    #[derive(Default)]
    struct Inner {
        receive_buffer_size: AtomicUsize,
        send_buffer_size: AtomicUsize,
        hop_limit: AtomicU8,
        keep_alive_idle_time: AtomicU64, // nanoseconds
    }

    impl NonInheritedOptions {
        /// Remembers the keep-alive idle time (in nanoseconds).
        pub fn set_keep_alive_idle_time(&mut self, value: u64) {
            let inner: &Inner = &self.0;
            inner.keep_alive_idle_time.store(value, Relaxed);
        }

        /// Remembers the hop limit.
        pub fn set_hop_limit(&mut self, value: u8) {
            let inner: &Inner = &self.0;
            inner.hop_limit.store(value, Relaxed);
        }

        /// Remembers the receive buffer size.
        pub fn set_receive_buffer_size(&mut self, value: usize) {
            let inner: &Inner = &self.0;
            inner.receive_buffer_size.store(value, Relaxed);
        }

        /// Remembers the send buffer size.
        pub fn set_send_buffer_size(&mut self, value: usize) {
            let inner: &Inner = &self.0;
            inner.send_buffer_size.store(value, Relaxed);
        }

        /// Re-applies every remembered option to a freshly accepted `stream`.
        ///
        /// Only values that were explicitly set (non-zero) are applied, and
        /// failures to apply them are deliberately ignored.
        pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) {
            let Inner {
                receive_buffer_size,
                send_buffer_size,
                hop_limit,
                keep_alive_idle_time,
            } = &*self.0;

            match receive_buffer_size.load(Relaxed) {
                0 => {}
                size => {
                    // Ignore potential error.
                    _ = sockopt::set_socket_recv_buffer_size(&stream, size);
                }
            }

            match send_buffer_size.load(Relaxed) {
                0 => {}
                size => {
                    // Ignore potential error.
                    _ = sockopt::set_socket_send_buffer_size(&stream, size);
                }
            }

            // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
            if family == SocketAddressFamily::Ipv6 {
                match hop_limit.load(Relaxed) {
                    0 => {}
                    hops => {
                        // Ignore potential error.
                        _ = sockopt::set_ipv6_unicast_hops(&stream, Some(hops));
                    }
                }
            }

            match keep_alive_idle_time.load(Relaxed) {
                0 => {}
                nanos => {
                    // Ignore potential error.
                    _ = sockopt::set_tcp_keepidle(&stream, Duration::from_nanos(nanos));
                }
            }
        }
    }
}
877}