wasmtime_wasi/sockets/
tcp.rs

1use crate::p2::P2TcpStreamingState;
2use crate::runtime::with_ambient_tokio_runtime;
3use crate::sockets::util::{
4    ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
5    is_valid_unicast_address, receive_buffer_size, send_buffer_size, set_keep_alive_count,
6    set_keep_alive_idle_time, set_keep_alive_interval, set_receive_buffer_size,
7    set_send_buffer_size, set_unicast_hop_limit, tcp_bind,
8};
9use crate::sockets::{DEFAULT_TCP_BACKLOG, SocketAddressFamily, WasiSocketsCtx};
10use io_lifetimes::AsSocketlike as _;
11use io_lifetimes::views::SocketlikeView;
12use rustix::io::Errno;
13use rustix::net::sockopt;
14use std::fmt::Debug;
15use std::io;
16use std::mem;
17use std::net::SocketAddr;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::{Context, Poll, Waker};
21use std::time::Duration;
22
23/// The state of a TCP socket.
24///
25/// This represents the various states a socket can be in during the
26/// activities of binding, listening, accepting, and connecting. Note that this
27/// state machine encompasses both WASIp2 and WASIp3.
28enum TcpState {
29    /// The initial state for a newly-created socket.
30    ///
31    /// From here a socket can transition to `BindStarted`, `ListenStarted`, or
32    /// `Connecting`.
33    Default(tokio::net::TcpSocket),
34
35    /// A state indicating that a bind has been started and must be finished
36    /// subsequently with `finish_bind`.
37    ///
38    /// From here a socket can transition to `Bound`.
39    BindStarted(tokio::net::TcpSocket),
40
41    /// Binding finished. The socket has an address but is not yet listening for
42    /// connections.
43    ///
44    /// From here a socket can transition to `ListenStarted`, or `Connecting`.
45    Bound(tokio::net::TcpSocket),
46
47    /// Listening on a socket has started and must be completed with
48    /// `finish_listen`.
49    ///
50    /// From here a socket can transition to `Listening`.
51    ListenStarted(tokio::net::TcpSocket),
52
53    /// The socket is now listening and waiting for an incoming connection.
54    ///
55    /// Sockets will not leave this state.
56    Listening {
57        /// The raw tokio-basd TCP listener managing the underyling socket.
58        listener: Arc<tokio::net::TcpListener>,
59
60        /// The last-accepted connection, set during the `ready` method and read
61        /// during the `accept` method. Note that this is only used for WASIp2
62        /// at this time.
63        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
64    },
65
66    /// An outgoing connection is started.
67    ///
68    /// This is created via the `start_connect` method. The payload here is an
69    /// optionally-specified owned future for the result of the connect. In
70    /// WASIp2 the future lives here, but in WASIp3 it lives on the event loop
71    /// so this is `None`.
72    ///
73    /// From here a socket can transition to `ConnectReady` or `Connected`.
74    Connecting(Option<Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>>),
75
76    /// A connection via `Connecting` has completed.
77    ///
78    /// This is present for WASIp2 where the `Connecting` state stores `Some` of
79    /// a future, and the result of that future is recorded here when it
80    /// finishes as part of the `ready` method.
81    ///
82    /// From here a socket can transition to `Connected`.
83    ConnectReady(io::Result<tokio::net::TcpStream>),
84
85    /// A connection has been established.
86    ///
87    /// This is created either via `finish_connect` or for freshly accepted
88    /// sockets from a TCP listener.
89    ///
90    /// From here a socket can transition to `Receiving` or `P2Streaming`.
91    Connected(Arc<tokio::net::TcpStream>),
92
93    /// A connection has been established and `receive` has been called.
94    ///
95    /// A socket will not transition out of this state.
96    #[cfg(feature = "p3")]
97    Receiving(Arc<tokio::net::TcpStream>),
98
99    /// This is a WASIp2-bound socket which stores some extra state for
100    /// read/write streams to handle TCP shutdown.
101    ///
102    /// A socket will not transition out of this state.
103    P2Streaming(Box<P2TcpStreamingState>),
104
105    /// This is not actually a socket but a deferred error.
106    ///
107    /// This error came out of `accept` and is deferred until the socket is
108    /// operated on.
109    #[cfg(feature = "p3")]
110    Error(io::Error),
111
112    /// The socket is closed and no more operations can be performed.
113    Closed,
114}
115
116impl Debug for TcpState {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            Self::Default(_) => f.debug_tuple("Default").finish(),
120            Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
121            Self::Bound(_) => f.debug_tuple("Bound").finish(),
122            Self::ListenStarted { .. } => f.debug_tuple("ListenStarted").finish(),
123            Self::Listening { .. } => f.debug_tuple("Listening").finish(),
124            Self::Connecting(..) => f.debug_tuple("Connecting").finish(),
125            Self::ConnectReady(..) => f.debug_tuple("ConnectReady").finish(),
126            Self::Connected { .. } => f.debug_tuple("Connected").finish(),
127            #[cfg(feature = "p3")]
128            Self::Receiving { .. } => f.debug_tuple("Receiving").finish(),
129            Self::P2Streaming(_) => f.debug_tuple("P2Streaming").finish(),
130            #[cfg(feature = "p3")]
131            Self::Error(..) => f.debug_tuple("Error").finish(),
132            Self::Closed => write!(f, "Closed"),
133        }
134    }
135}
136
137/// A host TCP socket, plus associated bookkeeping.
138pub struct TcpSocket {
139    /// The current state in the bind/listen/accept/connect progression.
140    tcp_state: TcpState,
141
142    /// The desired listen queue size.
143    listen_backlog_size: u32,
144
145    family: SocketAddressFamily,
146
147    options: NonInheritedOptions,
148}
149
150impl TcpSocket {
151    /// Create a new socket in the given family.
152    pub(crate) fn new(ctx: &WasiSocketsCtx, family: SocketAddressFamily) -> std::io::Result<Self> {
153        ctx.allowed_network_uses.check_allowed_tcp()?;
154
155        with_ambient_tokio_runtime(|| {
156            let socket = match family {
157                SocketAddressFamily::Ipv4 => tokio::net::TcpSocket::new_v4()?,
158                SocketAddressFamily::Ipv6 => {
159                    let socket = tokio::net::TcpSocket::new_v6()?;
160                    sockopt::set_ipv6_v6only(&socket, true)?;
161                    socket
162                }
163            };
164
165            Ok(Self::from_state(TcpState::Default(socket), family))
166        })
167    }
168
169    #[cfg(feature = "p3")]
170    pub(crate) fn new_error(err: io::Error, family: SocketAddressFamily) -> Self {
171        TcpSocket::from_state(TcpState::Error(err), family)
172    }
173
174    /// Creates a new socket with the `result` of an accepted socket from a
175    /// `TcpListener`.
176    ///
177    /// This will handle the `result` internally and `result` should be the raw
178    /// result from a TCP listen operation.
179    pub(crate) fn new_accept(
180        result: io::Result<tokio::net::TcpStream>,
181        options: &NonInheritedOptions,
182        family: SocketAddressFamily,
183    ) -> io::Result<Self> {
184        let client = result.map_err(|err| match Errno::from_io_error(&err) {
185            // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
186            // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
187            // > or the service provider is still processing a callback function.
188            //
189            // wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
190            // because in POSIX this error is only returned by a non-blocking
191            // `connect` and wasi-sockets has a different solution for that.
192            #[cfg(windows)]
193            Some(Errno::INPROGRESS) => Errno::INTR.into(),
194
195            // Normalize Linux' non-standard behavior.
196            //
197            // From https://man7.org/linux/man-pages/man2/accept.2.html:
198            // > Linux accept() passes already-pending network errors on the
199            // > new socket as an error code from accept(). This behavior
200            // > differs from other BSD socket implementations. (...)
201            #[cfg(target_os = "linux")]
202            Some(
203                Errno::CONNRESET
204                | Errno::NETRESET
205                | Errno::HOSTUNREACH
206                | Errno::HOSTDOWN
207                | Errno::NETDOWN
208                | Errno::NETUNREACH
209                | Errno::PROTO
210                | Errno::NOPROTOOPT
211                | Errno::NONET
212                | Errno::OPNOTSUPP,
213            ) => Errno::CONNABORTED.into(),
214
215            _ => err,
216        })?;
217        options.apply(family, &client);
218        Ok(Self::from_state(
219            TcpState::Connected(Arc::new(client)),
220            family,
221        ))
222    }
223
224    /// Create a `TcpSocket` from an existing socket.
225    fn from_state(state: TcpState, family: SocketAddressFamily) -> Self {
226        Self {
227            tcp_state: state,
228            listen_backlog_size: DEFAULT_TCP_BACKLOG,
229            family,
230            options: Default::default(),
231        }
232    }
233
234    pub(crate) fn as_std_view(&self) -> Result<SocketlikeView<'_, std::net::TcpStream>, ErrorCode> {
235        match &self.tcp_state {
236            TcpState::Default(socket)
237            | TcpState::BindStarted(socket)
238            | TcpState::Bound(socket)
239            | TcpState::ListenStarted(socket) => Ok(socket.as_socketlike_view()),
240            TcpState::Connected(stream) => Ok(stream.as_socketlike_view()),
241            #[cfg(feature = "p3")]
242            TcpState::Receiving(stream) => Ok(stream.as_socketlike_view()),
243            TcpState::Listening { listener, .. } => Ok(listener.as_socketlike_view()),
244            TcpState::P2Streaming(state) => Ok(state.stream.as_socketlike_view()),
245            TcpState::Connecting(..) | TcpState::ConnectReady(_) | TcpState::Closed => {
246                Err(ErrorCode::InvalidState)
247            }
248            #[cfg(feature = "p3")]
249            TcpState::Error(err) => Err(err.into()),
250        }
251    }
252
253    pub(crate) fn start_bind(&mut self, addr: SocketAddr) -> Result<(), ErrorCode> {
254        let ip = addr.ip();
255        if !is_valid_unicast_address(ip) || !is_valid_address_family(ip, self.family) {
256            return Err(ErrorCode::InvalidArgument);
257        }
258        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
259            TcpState::Default(sock) => {
260                if let Err(err) = tcp_bind(&sock, addr) {
261                    self.tcp_state = TcpState::Default(sock);
262                    Err(err)
263                } else {
264                    self.tcp_state = TcpState::BindStarted(sock);
265                    Ok(())
266                }
267            }
268            tcp_state => {
269                self.tcp_state = tcp_state;
270                Err(ErrorCode::InvalidState)
271            }
272        }
273    }
274
275    pub(crate) fn finish_bind(&mut self) -> Result<(), ErrorCode> {
276        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
277            TcpState::BindStarted(socket) => {
278                self.tcp_state = TcpState::Bound(socket);
279                Ok(())
280            }
281            current_state => {
282                // Reset the state so that the outside world doesn't see this socket as closed
283                self.tcp_state = current_state;
284                Err(ErrorCode::NotInProgress)
285            }
286        }
287    }
288
289    pub(crate) fn start_connect(
290        &mut self,
291        addr: &SocketAddr,
292    ) -> Result<tokio::net::TcpSocket, ErrorCode> {
293        match self.tcp_state {
294            TcpState::Default(..) | TcpState::Bound(..) => {}
295            TcpState::Connecting(..) => {
296                return Err(ErrorCode::ConcurrencyConflict);
297            }
298            _ => return Err(ErrorCode::InvalidState),
299        };
300
301        if !is_valid_unicast_address(addr.ip())
302            || !is_valid_remote_address(*addr)
303            || !is_valid_address_family(addr.ip(), self.family)
304        {
305            return Err(ErrorCode::InvalidArgument);
306        };
307
308        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
309            mem::replace(&mut self.tcp_state, TcpState::Connecting(None))
310        else {
311            unreachable!();
312        };
313
314        Ok(tokio_socket)
315    }
316
317    /// For WASIp2 this is used to record the actual connection future as part
318    /// of `start_connect` within this socket state.
319    pub(crate) fn set_pending_connect(
320        &mut self,
321        future: impl Future<Output = io::Result<tokio::net::TcpStream>> + Send + 'static,
322    ) -> Result<(), ErrorCode> {
323        match &mut self.tcp_state {
324            TcpState::Connecting(slot @ None) => {
325                *slot = Some(Box::pin(future));
326                Ok(())
327            }
328            _ => Err(ErrorCode::InvalidState),
329        }
330    }
331
332    /// For WASIp2 this retreives the result from the future passed to
333    /// `set_pending_connect`.
334    ///
335    /// Return states here are:
336    ///
337    /// * `Ok(Some(res))` - where `res` is the result of the connect operation.
338    /// * `Ok(None)` - the connect operation isn't ready yet.
339    /// * `Err(e)` - a connect operation is not in progress.
340    pub(crate) fn take_pending_connect(
341        &mut self,
342    ) -> Result<Option<io::Result<tokio::net::TcpStream>>, ErrorCode> {
343        match mem::replace(&mut self.tcp_state, TcpState::Connecting(None)) {
344            TcpState::ConnectReady(result) => Ok(Some(result)),
345            TcpState::Connecting(Some(mut future)) => {
346                let mut cx = Context::from_waker(Waker::noop());
347                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
348                    Poll::Ready(result) => Ok(Some(result)),
349                    Poll::Pending => {
350                        self.tcp_state = TcpState::Connecting(Some(future));
351                        Ok(None)
352                    }
353                }
354            }
355            current_state => {
356                self.tcp_state = current_state;
357                Err(ErrorCode::NotInProgress)
358            }
359        }
360    }
361
362    pub(crate) fn finish_connect(
363        &mut self,
364        result: io::Result<tokio::net::TcpStream>,
365    ) -> Result<(), ErrorCode> {
366        if !matches!(self.tcp_state, TcpState::Connecting(None)) {
367            return Err(ErrorCode::InvalidState);
368        }
369        match result {
370            Ok(stream) => {
371                self.tcp_state = TcpState::Connected(Arc::new(stream));
372                Ok(())
373            }
374            Err(err) => {
375                self.tcp_state = TcpState::Closed;
376                Err(ErrorCode::from(err))
377            }
378        }
379    }
380
381    pub(crate) fn start_listen(&mut self) -> Result<(), ErrorCode> {
382        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
383            TcpState::Bound(tokio_socket) => {
384                self.tcp_state = TcpState::ListenStarted(tokio_socket);
385                Ok(())
386            }
387            previous_state => {
388                self.tcp_state = previous_state;
389                Err(ErrorCode::InvalidState)
390            }
391        }
392    }
393
394    pub(crate) fn finish_listen(&mut self) -> Result<(), ErrorCode> {
395        let tokio_socket = match mem::replace(&mut self.tcp_state, TcpState::Closed) {
396            TcpState::ListenStarted(tokio_socket) => tokio_socket,
397            previous_state => {
398                self.tcp_state = previous_state;
399                return Err(ErrorCode::NotInProgress);
400            }
401        };
402
403        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
404            Ok(listener) => {
405                self.tcp_state = TcpState::Listening {
406                    listener: Arc::new(listener),
407                    pending_accept: None,
408                };
409                Ok(())
410            }
411            Err(err) => {
412                self.tcp_state = TcpState::Closed;
413
414                Err(match Errno::from_io_error(&err) {
415                    // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
416                    // According to the docs, `listen` can return EMFILE on Windows.
417                    // This is odd, because we're not trying to create a new socket
418                    // or file descriptor of any kind. So we rewrite it to less
419                    // surprising error code.
420                    //
421                    // At the time of writing, this behavior has never been experimentally
422                    // observed by any of the wasmtime authors, so we're relying fully
423                    // on Microsoft's documentation here.
424                    #[cfg(windows)]
425                    Some(Errno::MFILE) => Errno::NOBUFS.into(),
426
427                    _ => err.into(),
428                })
429            }
430        }
431    }
432
433    pub(crate) fn accept(&mut self) -> Result<Option<Self>, ErrorCode> {
434        let TcpState::Listening {
435            listener,
436            pending_accept,
437        } = &mut self.tcp_state
438        else {
439            return Err(ErrorCode::InvalidState);
440        };
441
442        let result = match pending_accept.take() {
443            Some(result) => result,
444            None => {
445                let mut cx = std::task::Context::from_waker(Waker::noop());
446                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
447                    .map_ok(|(stream, _)| stream)
448                {
449                    Poll::Ready(result) => result,
450                    Poll::Pending => return Ok(None),
451                }
452            }
453        };
454
455        Ok(Some(Self::new_accept(result, &self.options, self.family)?))
456    }
457
458    #[cfg(feature = "p3")]
459    pub(crate) fn start_receive(&mut self) -> Option<&Arc<tokio::net::TcpStream>> {
460        match mem::replace(&mut self.tcp_state, TcpState::Closed) {
461            TcpState::Connected(stream) => {
462                self.tcp_state = TcpState::Receiving(stream);
463                Some(self.tcp_stream_arc().unwrap())
464            }
465            prev => {
466                self.tcp_state = prev;
467                None
468            }
469        }
470    }
471
472    pub(crate) fn local_address(&self) -> Result<SocketAddr, ErrorCode> {
473        match &self.tcp_state {
474            TcpState::Bound(socket) => Ok(socket.local_addr()?),
475            TcpState::Connected(stream) => Ok(stream.local_addr()?),
476            #[cfg(feature = "p3")]
477            TcpState::Receiving(stream) => Ok(stream.local_addr()?),
478            TcpState::P2Streaming(state) => Ok(state.stream.local_addr()?),
479            TcpState::Listening { listener, .. } => Ok(listener.local_addr()?),
480            #[cfg(feature = "p3")]
481            TcpState::Error(err) => Err(err.into()),
482            _ => Err(ErrorCode::InvalidState),
483        }
484    }
485
486    pub(crate) fn remote_address(&self) -> Result<SocketAddr, ErrorCode> {
487        let stream = self.tcp_stream_arc()?;
488        let addr = stream.peer_addr()?;
489        Ok(addr)
490    }
491
492    pub(crate) fn is_listening(&self) -> bool {
493        matches!(self.tcp_state, TcpState::Listening { .. })
494    }
495
496    pub(crate) fn address_family(&self) -> SocketAddressFamily {
497        self.family
498    }
499
500    pub(crate) fn set_listen_backlog_size(&mut self, value: u64) -> Result<(), ErrorCode> {
501        const MIN_BACKLOG: u32 = 1;
502        const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
503
504        if value == 0 {
505            return Err(ErrorCode::InvalidArgument);
506        }
507        // Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
508        let value = value
509            .try_into()
510            .unwrap_or(MAX_BACKLOG)
511            .clamp(MIN_BACKLOG, MAX_BACKLOG);
512        match &self.tcp_state {
513            TcpState::Default(..) | TcpState::Bound(..) => {
514                // Socket not listening yet. Stash value for first invocation to `listen`.
515                self.listen_backlog_size = value;
516                Ok(())
517            }
518            TcpState::Listening { listener, .. } => {
519                // Try to update the backlog by calling `listen` again.
520                // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
521                if rustix::net::listen(&listener, value.try_into().unwrap_or(i32::MAX)).is_err() {
522                    return Err(ErrorCode::NotSupported);
523                }
524                self.listen_backlog_size = value;
525                Ok(())
526            }
527            #[cfg(feature = "p3")]
528            TcpState::Error(err) => Err(err.into()),
529            _ => Err(ErrorCode::InvalidState),
530        }
531    }
532
533    pub(crate) fn keep_alive_enabled(&self) -> Result<bool, ErrorCode> {
534        let fd = &*self.as_std_view()?;
535        let v = sockopt::socket_keepalive(fd)?;
536        Ok(v)
537    }
538
539    pub(crate) fn set_keep_alive_enabled(&self, value: bool) -> Result<(), ErrorCode> {
540        let fd = &*self.as_std_view()?;
541        sockopt::set_socket_keepalive(fd, value)?;
542        Ok(())
543    }
544
545    pub(crate) fn keep_alive_idle_time(&self) -> Result<u64, ErrorCode> {
546        let fd = &*self.as_std_view()?;
547        let v = sockopt::tcp_keepidle(fd)?;
548        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
549    }
550
551    pub(crate) fn set_keep_alive_idle_time(&mut self, value: u64) -> Result<(), ErrorCode> {
552        let value = {
553            let fd = self.as_std_view()?;
554            set_keep_alive_idle_time(&*fd, value)?
555        };
556        self.options.set_keep_alive_idle_time(value);
557        Ok(())
558    }
559
560    pub(crate) fn keep_alive_interval(&self) -> Result<u64, ErrorCode> {
561        let fd = &*self.as_std_view()?;
562        let v = sockopt::tcp_keepintvl(fd)?;
563        Ok(v.as_nanos().try_into().unwrap_or(u64::MAX))
564    }
565
566    pub(crate) fn set_keep_alive_interval(&self, value: u64) -> Result<(), ErrorCode> {
567        let fd = &*self.as_std_view()?;
568        set_keep_alive_interval(fd, Duration::from_nanos(value))?;
569        Ok(())
570    }
571
572    pub(crate) fn keep_alive_count(&self) -> Result<u32, ErrorCode> {
573        let fd = &*self.as_std_view()?;
574        let v = sockopt::tcp_keepcnt(fd)?;
575        Ok(v)
576    }
577
578    pub(crate) fn set_keep_alive_count(&self, value: u32) -> Result<(), ErrorCode> {
579        let fd = &*self.as_std_view()?;
580        set_keep_alive_count(fd, value)?;
581        Ok(())
582    }
583
584    pub(crate) fn hop_limit(&self) -> Result<u8, ErrorCode> {
585        let fd = &*self.as_std_view()?;
586        let n = get_unicast_hop_limit(fd, self.family)?;
587        Ok(n)
588    }
589
590    pub(crate) fn set_hop_limit(&mut self, value: u8) -> Result<(), ErrorCode> {
591        {
592            let fd = &*self.as_std_view()?;
593            set_unicast_hop_limit(fd, self.family, value)?;
594        }
595        self.options.set_hop_limit(value);
596        Ok(())
597    }
598
599    pub(crate) fn receive_buffer_size(&self) -> Result<u64, ErrorCode> {
600        let fd = &*self.as_std_view()?;
601        let n = receive_buffer_size(fd)?;
602        Ok(n)
603    }
604
605    pub(crate) fn set_receive_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
606        let res = {
607            let fd = &*self.as_std_view()?;
608            set_receive_buffer_size(fd, value)?
609        };
610        self.options.set_receive_buffer_size(res);
611        Ok(())
612    }
613
614    pub(crate) fn send_buffer_size(&self) -> Result<u64, ErrorCode> {
615        let fd = &*self.as_std_view()?;
616        let n = send_buffer_size(fd)?;
617        Ok(n)
618    }
619
620    pub(crate) fn set_send_buffer_size(&mut self, value: u64) -> Result<(), ErrorCode> {
621        let res = {
622            let fd = &*self.as_std_view()?;
623            set_send_buffer_size(fd, value)?
624        };
625        self.options.set_send_buffer_size(res);
626        Ok(())
627    }
628
629    #[cfg(feature = "p3")]
630    pub(crate) fn non_inherited_options(&self) -> &NonInheritedOptions {
631        &self.options
632    }
633
634    #[cfg(feature = "p3")]
635    pub(crate) fn tcp_listener_arc(&self) -> Result<&Arc<tokio::net::TcpListener>, ErrorCode> {
636        match &self.tcp_state {
637            TcpState::Listening { listener, .. } => Ok(listener),
638            #[cfg(feature = "p3")]
639            TcpState::Error(err) => Err(err.into()),
640            _ => Err(ErrorCode::InvalidState),
641        }
642    }
643
644    pub(crate) fn tcp_stream_arc(&self) -> Result<&Arc<tokio::net::TcpStream>, ErrorCode> {
645        match &self.tcp_state {
646            TcpState::Connected(socket) => Ok(socket),
647            #[cfg(feature = "p3")]
648            TcpState::Receiving(socket) => Ok(socket),
649            TcpState::P2Streaming(state) => Ok(&state.stream),
650            #[cfg(feature = "p3")]
651            TcpState::Error(err) => Err(err.into()),
652            _ => Err(ErrorCode::InvalidState),
653        }
654    }
655
656    pub(crate) fn p2_streaming_state(&self) -> Result<&P2TcpStreamingState, ErrorCode> {
657        match &self.tcp_state {
658            TcpState::P2Streaming(state) => Ok(state),
659            #[cfg(feature = "p3")]
660            TcpState::Error(err) => Err(err.into()),
661            _ => Err(ErrorCode::InvalidState),
662        }
663    }
664
665    pub(crate) fn set_p2_streaming_state(
666        &mut self,
667        state: P2TcpStreamingState,
668    ) -> Result<(), ErrorCode> {
669        if !matches!(self.tcp_state, TcpState::Connected(_)) {
670            return Err(ErrorCode::InvalidState);
671        }
672        self.tcp_state = TcpState::P2Streaming(Box::new(state));
673        Ok(())
674    }
675
676    /// Used for `Pollable` in the WASIp2 implementation this awaits the socket
677    /// to be connected, if in the connecting state, or for a TCP accept to be
678    /// ready, if this is in the listening state.
679    ///
680    /// For all other states this method immediately returns.
681    pub(crate) async fn ready(&mut self) {
682        match &mut self.tcp_state {
683            TcpState::Default(..)
684            | TcpState::BindStarted(..)
685            | TcpState::Bound(..)
686            | TcpState::ListenStarted(..)
687            | TcpState::ConnectReady(..)
688            | TcpState::Closed
689            | TcpState::Connected { .. }
690            | TcpState::Connecting(None)
691            | TcpState::Listening {
692                pending_accept: Some(_),
693                ..
694            }
695            | TcpState::P2Streaming(_) => {}
696
697            #[cfg(feature = "p3")]
698            TcpState::Receiving(_) | TcpState::Error(_) => {}
699
700            TcpState::Connecting(Some(future)) => {
701                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
702            }
703
704            TcpState::Listening {
705                listener,
706                pending_accept: slot @ None,
707            } => {
708                let result = futures::future::poll_fn(|cx| {
709                    listener.poll_accept(cx).map_ok(|(stream, _)| stream)
710                })
711                .await;
712                *slot = Some(result);
713            }
714        }
715    }
716}
717
718#[cfg(not(target_os = "macos"))]
719pub use inherits_option::*;
720#[cfg(not(target_os = "macos"))]
721mod inherits_option {
722    use crate::sockets::SocketAddressFamily;
723    use tokio::net::TcpStream;
724
725    #[derive(Default, Clone)]
726    pub struct NonInheritedOptions;
727
728    impl NonInheritedOptions {
729        pub fn set_keep_alive_idle_time(&mut self, _value: u64) {}
730
731        pub fn set_hop_limit(&mut self, _value: u8) {}
732
733        pub fn set_receive_buffer_size(&mut self, _value: usize) {}
734
735        pub fn set_send_buffer_size(&mut self, _value: usize) {}
736
737        pub(crate) fn apply(&self, _family: SocketAddressFamily, _stream: &TcpStream) {}
738    }
739}
740
#[cfg(target_os = "macos")]
pub use does_not_inherit_options::*;
#[cfg(target_os = "macos")]
mod does_not_inherit_options {
    use crate::sockets::SocketAddressFamily;
    use rustix::net::sockopt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, AtomicU64, AtomicUsize, Ordering::Relaxed};
    use std::time::Duration;
    use tokio::net::TcpStream;

    /// Socket options that were explicitly set on a socket and must be copied
    /// by hand onto connections accepted from it.
    ///
    /// The values live behind an `Arc`, so every clone shares them. A stored
    /// value of zero means "never set explicitly".
    #[derive(Default, Clone)]
    pub struct NonInheritedOptions(Arc<Inner>);

    #[derive(Default)]
    struct Inner {
        receive_buffer_size: AtomicUsize,
        send_buffer_size: AtomicUsize,
        hop_limit: AtomicU8,
        /// Keep-alive idle time, in nanoseconds.
        keep_alive_idle_time: AtomicU64,
    }

    impl NonInheritedOptions {
        /// Shared access to the recorded values.
        fn inner(&self) -> &Inner {
            &self.0
        }

        /// Remember the keep-alive idle time (nanoseconds).
        pub fn set_keep_alive_idle_time(&mut self, value: u64) {
            self.inner().keep_alive_idle_time.store(value, Relaxed);
        }

        /// Remember the unicast hop limit.
        pub fn set_hop_limit(&mut self, value: u8) {
            self.inner().hop_limit.store(value, Relaxed);
        }

        /// Remember the receive buffer size.
        pub fn set_receive_buffer_size(&mut self, value: usize) {
            self.inner().receive_buffer_size.store(value, Relaxed);
        }

        /// Remember the send buffer size.
        pub fn set_send_buffer_size(&mut self, value: usize) {
            self.inner().send_buffer_size.store(value, Relaxed);
        }

        /// Copy every explicitly-set option onto a freshly accepted `stream`.
        ///
        /// This is best-effort: failures to set an option are ignored, and
        /// options still at zero are left at the OS default.
        pub(crate) fn apply(&self, family: SocketAddressFamily, stream: &TcpStream) {
            let inner = self.inner();

            match inner.receive_buffer_size.load(Relaxed) {
                0 => {}
                size => {
                    _ = sockopt::set_socket_recv_buffer_size(&stream, size);
                }
            }

            match inner.send_buffer_size.load(Relaxed) {
                0 => {}
                size => {
                    _ = sockopt::set_socket_send_buffer_size(&stream, size);
                }
            }

            // IP_TTL is carried over from the listener but IPV6_UNICAST_HOPS
            // is not, so only IPv6 streams need the hop limit re-applied.
            if family == SocketAddressFamily::Ipv6 {
                match inner.hop_limit.load(Relaxed) {
                    0 => {}
                    hops => {
                        _ = sockopt::set_ipv6_unicast_hops(&stream, Some(hops));
                    }
                }
            }

            match inner.keep_alive_idle_time.load(Relaxed) {
                0 => {}
                nanos => {
                    _ = sockopt::set_tcp_keepidle(&stream, Duration::from_nanos(nanos));
                }
            }
        }
    }
}