1use crate::net::{SocketAddressFamily, DEFAULT_TCP_BACKLOG};
2use crate::p2::bindings::sockets::tcp::ErrorCode;
3use crate::p2::host::network;
4use crate::p2::{
5 DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
6 SocketResult, StreamError,
7};
8use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
9use anyhow::Result;
10use cap_net_ext::AddressFamily;
11use futures::Future;
12use io_lifetimes::views::SocketlikeView;
13use io_lifetimes::AsSocketlike;
14use rustix::io::Errno;
15use rustix::net::sockopt;
16use std::io;
17use std::mem;
18use std::net::{Shutdown, SocketAddr};
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::Poll;
22use tokio::sync::Mutex;
23
/// Lifecycle state of a [`TcpSocket`].
///
/// Transitions performed by the methods in this file:
/// * bind: `Default` -> `BindStarted` -> `Bound`
/// * listen: `Bound` -> `ListenStarted` -> `Listening`
/// * connect: `Default`/`Bound` -> `Connecting` -> (`ConnectReady`) -> `Connected`
///
/// A failed `finish_connect` or `finish_listen` leaves the socket `Closed`.
enum TcpState {
    /// Freshly created socket: not bound, listening or connected.
    Default(tokio::net::TcpSocket),

    /// `start_bind` has performed the OS bind; waiting for `finish_bind`.
    BindStarted(tokio::net::TcpSocket),

    /// Bound to a local address (`finish_bind` completed).
    Bound(tokio::net::TcpSocket),

    /// `start_listen` was called; the OS `listen` call happens in `finish_listen`.
    ListenStarted(tokio::net::TcpSocket),

    /// Listening for incoming connections.
    Listening {
        /// The listening socket.
        listener: tokio::net::TcpListener,
        /// Outcome of an accept performed by `Pollable::ready`, buffered
        /// until the next `accept` call takes it.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },

    /// An outgoing connect is in flight. The future is driven by
    /// `Pollable::ready` or polled once per `finish_connect` call.
    Connecting(Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>),

    /// The connect future completed inside `Pollable::ready`; its result is
    /// held here until `finish_connect` consumes it.
    ConnectReady(io::Result<tokio::net::TcpStream>),

    /// Connected. The read/write halves are shared with the streams handed
    /// out to the caller.
    Connected {
        /// The connected stream, also referenced by `reader` and `writer`.
        stream: Arc<tokio::net::TcpStream>,

        /// Read half, shared with the returned `TcpReadStream`.
        reader: Arc<Mutex<TcpReader>>,
        /// Write half, shared with the returned `TcpWriteStream`.
        writer: Arc<Mutex<TcpWriter>>,
    },

    /// Terminal state. Also used as a temporary placeholder while a method
    /// moves the previous state out with `std::mem::replace`.
    Closed,
}
65
66impl std::fmt::Debug for TcpState {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 match self {
69 Self::Default(_) => f.debug_tuple("Default").finish(),
70 Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
71 Self::Bound(_) => f.debug_tuple("Bound").finish(),
72 Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(),
73 Self::Listening { pending_accept, .. } => f
74 .debug_struct("Listening")
75 .field("pending_accept", pending_accept)
76 .finish(),
77 Self::Connecting(_) => f.debug_tuple("Connecting").finish(),
78 Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(),
79 Self::Connected { .. } => f.debug_tuple("Connected").finish(),
80 Self::Closed => write!(f, "Closed"),
81 }
82 }
83}
84
/// Host-side TCP socket backing the `sockets::tcp` bindings.
pub struct TcpSocket {
    /// Current lifecycle state; owns the underlying OS socket.
    tcp_state: TcpState,

    /// Backlog passed to `listen` in `finish_listen`; changed through
    /// `set_listen_backlog_size`.
    listen_backlog_size: u32,

    /// Address family fixed at creation. Addresses given to bind/connect are
    /// validated against it.
    family: SocketAddressFamily,

    /// macOS only: receive buffer size explicitly set on this socket, if
    /// any. `accept` re-applies it to accepted client sockets.
    #[cfg(target_os = "macos")]
    receive_buffer_size: Option<usize>,
    /// macOS only: send buffer size explicitly set on this socket, if any.
    /// Re-applied to accepted client sockets.
    #[cfg(target_os = "macos")]
    send_buffer_size: Option<usize>,
    /// macOS only: hop limit explicitly set on this socket, if any.
    /// Re-applied to accepted IPv6 client sockets.
    #[cfg(target_os = "macos")]
    hop_limit: Option<u8>,
    /// macOS only: keep-alive idle time explicitly set on this socket, if
    /// any. Re-applied to accepted client sockets.
    #[cfg(target_os = "macos")]
    keep_alive_idle_time: Option<std::time::Duration>,
}
107
108impl TcpSocket {
109 pub fn new(family: AddressFamily) -> io::Result<Self> {
111 with_ambient_tokio_runtime(|| {
112 let (socket, family) = match family {
113 AddressFamily::Ipv4 => {
114 let socket = tokio::net::TcpSocket::new_v4()?;
115 (socket, SocketAddressFamily::Ipv4)
116 }
117 AddressFamily::Ipv6 => {
118 let socket = tokio::net::TcpSocket::new_v6()?;
119 sockopt::set_ipv6_v6only(&socket, true)?;
120 (socket, SocketAddressFamily::Ipv6)
121 }
122 };
123
124 Self::from_state(TcpState::Default(socket), family)
125 })
126 }
127
128 fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
130 Ok(Self {
131 tcp_state: state,
132 listen_backlog_size: DEFAULT_TCP_BACKLOG,
133 family,
134 #[cfg(target_os = "macos")]
135 receive_buffer_size: None,
136 #[cfg(target_os = "macos")]
137 send_buffer_size: None,
138 #[cfg(target_os = "macos")]
139 hop_limit: None,
140 #[cfg(target_os = "macos")]
141 keep_alive_idle_time: None,
142 })
143 }
144
145 fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
146 use crate::p2::bindings::sockets::network::ErrorCode;
147
148 match &self.tcp_state {
149 TcpState::Default(socket) | TcpState::Bound(socket) => {
150 Ok(socket.as_socketlike_view::<std::net::TcpStream>())
151 }
152 TcpState::Connected { stream, .. } => {
153 Ok(stream.as_socketlike_view::<std::net::TcpStream>())
154 }
155 TcpState::Listening { listener, .. } => {
156 Ok(listener.as_socketlike_view::<std::net::TcpStream>())
157 }
158
159 TcpState::BindStarted(..)
160 | TcpState::ListenStarted(..)
161 | TcpState::Connecting(..)
162 | TcpState::ConnectReady(..)
163 | TcpState::Closed => Err(ErrorCode::InvalidState.into()),
164 }
165 }
166}
167
impl TcpSocket {
    /// Begins binding this socket to `local_address`.
    ///
    /// Valid only in the `Default` state. Returns `EALREADY` while a bind is
    /// already in progress (`BindStarted`) and `EISCONN` in every other
    /// state. The address is validated (unicast, matching address family)
    /// before anything is changed. The OS-level `bind` is performed here.
    /// On success the state moves to `BindStarted`, to be completed by
    /// [`TcpSocket::finish_bind`]. On any error the state is left unchanged.
    pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> {
        let tokio_socket = match &self.tcp_state {
            TcpState::Default(socket) => socket,
            TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()),
            _ => return Err(Errno::ISCONN.into()),
        };

        network::util::validate_unicast(&local_address)?;
        network::util::validate_address_family(&local_address, &self.family)?;

        {
            // Address reuse is requested only when binding to an explicit
            // (non-zero) port.
            let reuse_addr = local_address.port() > 0;

            // The option is written on every call, even when it is `false`.
            // Presumably this stops a value left behind by an earlier, failed
            // `start_bind` from affecting this attempt.
            network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;

            tokio_socket.bind(local_address).map_err(|error| {
                match Errno::from_io_error(&error) {
                    // Surface EAFNOSUPPORT as `InvalidInput`. The
                    // address-family validation above is expected to catch
                    // most of these already; this is a fallback.
                    Some(Errno::AFNOSUPPORT) => io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "The specified address is not a valid address for the address family of the specified socket",
                    ),

                    // On Windows, ENOBUFS from `bind` is reported as
                    // `AddrInUse` ("no more free local ports").
                    #[cfg(windows)]
                    Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"),

                    _ => error,
                }
            })?;

            // The match at the top guarantees we are still in `Default`.
            self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
                TcpState::Default(socket) => TcpState::BindStarted(socket),
                _ => unreachable!(),
            };

            Ok(())
        }
    }

    /// Completes a bind started by [`TcpSocket::start_bind`], moving
    /// `BindStarted` to `Bound`.
    ///
    /// Returns `NotInProgress`, with the state untouched, when no bind is in
    /// progress.
    pub fn finish_bind(&mut self) -> SocketResult<()> {
        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::BindStarted(socket) => {
                self.tcp_state = TcpState::Bound(socket);
                Ok(())
            }
            current_state => {
                // Restore the state; `Closed` was only a placeholder.
                self.tcp_state = current_state;
                Err(ErrorCode::NotInProgress.into())
            }
        }
    }

    /// Begins connecting to `remote_address`.
    ///
    /// Valid from `Default` or `Bound`. Returns `ConcurrencyConflict` while a
    /// connect is already pending (`Connecting`/`ConnectReady`) and
    /// `InvalidState` otherwise. The address is validated before the state
    /// changes. The connect future is stored in `Connecting`; it is driven by
    /// `Pollable::ready` or polled by [`TcpSocket::finish_connect`].
    pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> {
        match self.tcp_state {
            TcpState::Default(..) | TcpState::Bound(..) => {}

            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
                return Err(ErrorCode::ConcurrencyConflict.into())
            }

            _ => return Err(ErrorCode::InvalidState.into()),
        };

        network::util::validate_unicast(&remote_address)?;
        network::util::validate_remote_address(&remote_address)?;
        network::util::validate_address_family(&remote_address, &self.family)?;

        // The match at the top guarantees one of these two states.
        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
            std::mem::replace(&mut self.tcp_state, TcpState::Closed)
        else {
            unreachable!();
        };

        let future = tokio_socket.connect(remote_address);

        self.tcp_state = TcpState::Connecting(Box::pin(future));
        Ok(())
    }

    /// Completes a connect started by [`TcpSocket::start_connect`].
    ///
    /// If the connect future has not already finished, it is polled once
    /// with a no-op waker, and `WouldBlock` is returned while it is pending.
    /// On success the socket becomes `Connected` and the connection's input
    /// and output streams are returned. On failure the socket is `Closed`.
    /// Returns `NotInProgress`, with the state untouched, when no connect was
    /// started.
    pub fn finish_connect(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
        let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed);
        let result = match previous_state {
            TcpState::ConnectReady(result) => result,
            TcpState::Connecting(mut future) => {
                // Poll once without registering interest; callers are
                // expected to wait via `Pollable::ready` instead.
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
                    Poll::Ready(result) => result,
                    Poll::Pending => {
                        self.tcp_state = TcpState::Connecting(future);
                        return Err(ErrorCode::WouldBlock.into());
                    }
                }
            }
            previous_state => {
                self.tcp_state = previous_state;
                return Err(ErrorCode::NotInProgress.into());
            }
        };

        match result {
            Ok(stream) => {
                // The socket state and the returned streams all share the same
                // stream and read/write halves.
                let stream = Arc::new(stream);
                let reader = Arc::new(Mutex::new(TcpReader::new(stream.clone())));
                let writer = Arc::new(Mutex::new(TcpWriter::new(stream.clone())));
                self.tcp_state = TcpState::Connected {
                    stream,
                    reader: reader.clone(),
                    writer: writer.clone(),
                };
                let input: DynInputStream = Box::new(TcpReadStream(reader));
                let output: DynOutputStream = Box::new(TcpWriteStream(writer));
                Ok((input, output))
            }
            Err(err) => {
                self.tcp_state = TcpState::Closed;
                Err(err.into())
            }
        }
    }

    /// Begins listening, moving `Bound` to `ListenStarted`. The OS `listen`
    /// call is made in [`TcpSocket::finish_listen`].
    ///
    /// Returns `ConcurrencyConflict` if a listen is already in progress and
    /// `InvalidState` for any other non-`Bound` state (state untouched).
    pub fn start_listen(&mut self) -> SocketResult<()> {
        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::Bound(tokio_socket) => {
                self.tcp_state = TcpState::ListenStarted(tokio_socket);
                Ok(())
            }
            TcpState::ListenStarted(tokio_socket) => {
                self.tcp_state = TcpState::ListenStarted(tokio_socket);
                Err(ErrorCode::ConcurrencyConflict.into())
            }
            previous_state => {
                self.tcp_state = previous_state;
                Err(ErrorCode::InvalidState.into())
            }
        }
    }

    /// Completes a listen started by [`TcpSocket::start_listen`], calling
    /// the OS `listen` with the configured backlog.
    ///
    /// Success moves the socket to `Listening`; failure leaves it `Closed`.
    /// Returns `NotInProgress`, with the state untouched, when no listen was
    /// started.
    pub fn finish_listen(&mut self) -> SocketResult<()> {
        let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::ListenStarted(tokio_socket) => tokio_socket,
            previous_state => {
                self.tcp_state = previous_state;
                return Err(ErrorCode::NotInProgress.into());
            }
        };

        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
            Ok(listener) => {
                self.tcp_state = TcpState::Listening {
                    listener,
                    pending_accept: None,
                };
                Ok(())
            }
            Err(err) => {
                self.tcp_state = TcpState::Closed;

                Err(match Errno::from_io_error(&err) {
                    // On Windows, EMFILE from `listen` is reported as ENOBUFS.
                    #[cfg(windows)]
                    Some(Errno::MFILE) => Errno::NOBUFS.into(),

                    _ => err.into(),
                })
            }
        }
    }

    /// Accepts one incoming connection on a listening socket.
    ///
    /// Uses the result buffered by `Pollable::ready` if there is one.
    /// Otherwise it polls the listener once with a no-op waker and reports
    /// `WOULDBLOCK` if nothing is pending. Some platform-specific error codes
    /// are normalized (see below). Returns the new connected socket together
    /// with its input and output streams. Returns `InvalidState` if the
    /// socket is not `Listening`.
    pub fn accept(&mut self) -> SocketResult<(Self, DynInputStream, DynOutputStream)> {
        let TcpState::Listening {
            listener,
            pending_accept,
        } = &mut self.tcp_state
        else {
            return Err(ErrorCode::InvalidState.into());
        };

        let result = match pending_accept.take() {
            Some(result) => result,
            None => {
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
                    .map_ok(|(stream, _)| stream)
                {
                    Poll::Ready(result) => result,
                    Poll::Pending => Err(Errno::WOULDBLOCK.into()),
                }
            }
        };

        let client = result.map_err(|err| match Errno::from_io_error(&err) {
            // On Windows, EINPROGRESS from `accept` is reported as EINTR.
            #[cfg(windows)]
            Some(Errno::INPROGRESS) => Errno::INTR.into(),

            // On Linux, this set of network/protocol errors coming out of
            // `accept` is collapsed into ECONNABORTED. Presumably these are
            // errors pending on the new connection that Linux reports through
            // `accept` itself.
            #[cfg(target_os = "linux")]
            Some(
                Errno::CONNRESET
                | Errno::NETRESET
                | Errno::HOSTUNREACH
                | Errno::HOSTDOWN
                | Errno::NETDOWN
                | Errno::NETUNREACH
                | Errno::PROTO
                | Errno::NOPROTOOPT
                | Errno::NONET
                | Errno::OPNOTSUPP,
            ) => Errno::CONNABORTED.into(),

            _ => err,
        })?;

        // macOS: copy options that were explicitly set on the listener onto
        // the accepted socket. This is best effort; errors are ignored.
        #[cfg(target_os = "macos")]
        {
            if let Some(size) = self.receive_buffer_size {
                _ = network::util::set_socket_recv_buffer_size(&client, size);
            }

            if let Some(size) = self.send_buffer_size {
                _ = network::util::set_socket_send_buffer_size(&client, size);
            }

            if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) {
                _ = network::util::set_ipv6_unicast_hops(&client, ttl);
            }

            if let Some(value) = self.keep_alive_idle_time {
                _ = network::util::set_tcp_keepidle(&client, value);
            }
        }

        let client = Arc::new(client);

        let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
        let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));

        let input: DynInputStream = Box::new(TcpReadStream(reader.clone()));
        let output: DynOutputStream = Box::new(TcpWriteStream(writer.clone()));
        let tcp_socket = TcpSocket::from_state(
            TcpState::Connected {
                stream: client,
                reader,
                writer,
            },
            self.family,
        )?;

        Ok((tcp_socket, input, output))
    }

    /// Returns the locally bound address.
    ///
    /// Errors: `InvalidState` when unbound (`Default`), `ConcurrencyConflict`
    /// while a bind is in progress. Other states defer to `as_std_view`,
    /// which rejects in-progress and closed states.
    pub fn local_address(&self) -> SocketResult<SocketAddr> {
        let view = match self.tcp_state {
            TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
            TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
            _ => self.as_std_view()?,
        };

        Ok(view.local_addr()?)
    }

    /// Returns the peer address of a connected socket.
    ///
    /// Errors: `ConcurrencyConflict` while a connect is pending,
    /// `InvalidState` in any other non-`Connected` state.
    pub fn remote_address(&self) -> SocketResult<SocketAddr> {
        let view = match self.tcp_state {
            TcpState::Connected { .. } => self.as_std_view()?,
            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
                return Err(ErrorCode::ConcurrencyConflict.into())
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        };

        Ok(view.peer_addr()?)
    }

    /// Whether the socket is in the `Listening` state.
    pub fn is_listening(&self) -> bool {
        matches!(self.tcp_state, TcpState::Listening { .. })
    }

    /// The address family this socket was created with.
    pub fn address_family(&self) -> SocketAddressFamily {
        self.family
    }

    /// Sets the listen backlog.
    ///
    /// `0` is rejected with `InvalidArgument`; other values are clamped to
    /// `1..=i32::MAX`. Before listening, the value is only stored for
    /// `finish_listen`. While `Listening`, `listen` is called again with the
    /// new value and `NotSupported` is returned (stored value unchanged) if
    /// that fails. Other states yield `InvalidState`.
    pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> {
        const MIN_BACKLOG: u32 = 1;
        const MAX_BACKLOG: u32 = i32::MAX as u32;

        if value == 0 {
            return Err(ErrorCode::InvalidArgument.into());
        }

        let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG);

        match &self.tcp_state {
            TcpState::Default(..) | TcpState::Bound(..) => {
                // Not listening yet: just remember the value for `finish_listen`.
            }
            TcpState::Listening { listener, .. } => {
                // `value <= i32::MAX` after clamping, so the conversion cannot fail.
                rustix::net::listen(&listener, value.try_into().unwrap())
                    .map_err(|_| ErrorCode::NotSupported)?;
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        }
        self.listen_backlog_size = value;

        Ok(())
    }

    /// Reads `SO_KEEPALIVE`.
    pub fn keep_alive_enabled(&self) -> SocketResult<bool> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::socket_keepalive(view)?)
    }

    /// Sets `SO_KEEPALIVE`.
    pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::set_socket_keepalive(view, value)?)
    }

    /// Reads the TCP keep-alive idle time.
    pub fn keep_alive_idle_time(&self) -> SocketResult<std::time::Duration> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::tcp_keepidle(view)?)
    }

    /// Sets the TCP keep-alive idle time. On macOS the value is also recorded
    /// so that `accept` can copy it to accepted sockets.
    pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;
            network::util::set_tcp_keepidle(view, duration)?;
        }

        #[cfg(target_os = "macos")]
        {
            self.keep_alive_idle_time = Some(duration);
        }

        Ok(())
    }

    /// Reads the TCP keep-alive probe interval.
    pub fn keep_alive_interval(&self) -> SocketResult<std::time::Duration> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::tcp_keepintvl(view)?)
    }

    /// Sets the TCP keep-alive probe interval.
    pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(network::util::set_tcp_keepintvl(view, duration)?)
    }

    /// Reads the TCP keep-alive probe count.
    pub fn keep_alive_count(&self) -> SocketResult<u32> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::tcp_keepcnt(view)?)
    }

    /// Sets the TCP keep-alive probe count.
    pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(network::util::set_tcp_keepcnt(view, value)?)
    }

    /// Reads the hop limit: IP TTL for IPv4 sockets, unicast hops for IPv6.
    pub fn hop_limit(&self) -> SocketResult<u8> {
        let view = &*self.as_std_view()?;

        let ttl = match self.family {
            SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?,
            SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?,
        };

        Ok(ttl)
    }

    /// Sets the hop limit: IP TTL for IPv4 sockets, unicast hops for IPv6.
    /// On macOS the value is also recorded for `accept`.
    pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;

            match self.family {
                SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?,
                SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?,
            }
        }

        #[cfg(target_os = "macos")]
        {
            self.hop_limit = Some(value);
        }

        Ok(())
    }

    /// Reads the socket receive buffer size.
    pub fn receive_buffer_size(&self) -> SocketResult<usize> {
        let view = &*self.as_std_view()?;

        Ok(network::util::get_socket_recv_buffer_size(view)?)
    }

    /// Sets the socket receive buffer size. On macOS the value is also
    /// recorded for `accept`.
    pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;

            network::util::set_socket_recv_buffer_size(view, value)?;
        }

        #[cfg(target_os = "macos")]
        {
            self.receive_buffer_size = Some(value);
        }

        Ok(())
    }

    /// Reads the socket send buffer size.
    pub fn send_buffer_size(&self) -> SocketResult<usize> {
        let view = &*self.as_std_view()?;

        Ok(network::util::get_socket_send_buffer_size(view)?)
    }

    /// Sets the socket send buffer size. On macOS the value is also recorded
    /// for `accept`.
    pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;

            network::util::set_socket_send_buffer_size(view, value)?;
        }

        #[cfg(target_os = "macos")]
        {
            self.send_buffer_size = Some(value);
        }

        Ok(())
    }

    /// Shuts down the read and/or write half of a connected socket.
    ///
    /// Returns `InvalidState` when not `Connected`. Traps if the
    /// corresponding reader/writer lock is currently held elsewhere.
    pub fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
        let TcpState::Connected { reader, writer, .. } = &self.tcp_state else {
            return Err(ErrorCode::InvalidState.into());
        };

        if let Shutdown::Both | Shutdown::Read = how {
            try_lock_for_socket(reader)?.shutdown();
        }

        if let Shutdown::Both | Shutdown::Write = how {
            try_lock_for_socket(writer)?.shutdown();
        }

        Ok(())
    }
}
654
655#[async_trait::async_trait]
656impl Pollable for TcpSocket {
657 async fn ready(&mut self) {
658 match &mut self.tcp_state {
659 TcpState::Default(..)
660 | TcpState::BindStarted(..)
661 | TcpState::Bound(..)
662 | TcpState::ListenStarted(..)
663 | TcpState::ConnectReady(..)
664 | TcpState::Closed
665 | TcpState::Connected { .. } => {
666 }
668 TcpState::Connecting(future) => {
669 self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
670 }
671 TcpState::Listening {
672 listener,
673 pending_accept,
674 } => match pending_accept {
675 Some(_) => {}
676 None => {
677 let result = futures::future::poll_fn(|cx| {
678 listener.poll_accept(cx).map_ok(|(stream, _)| stream)
679 })
680 .await;
681 *pending_accept = Some(result);
682 }
683 },
684 }
685 }
686}
687
688struct TcpReader {
689 stream: Arc<tokio::net::TcpStream>,
690 closed: bool,
691}
692
693impl TcpReader {
694 fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
695 Self {
696 stream,
697 closed: false,
698 }
699 }
700 fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
701 if self.closed {
702 return Err(StreamError::Closed);
703 }
704 if size == 0 {
705 return Ok(bytes::Bytes::new());
706 }
707
708 let mut buf = bytes::BytesMut::with_capacity(size);
709 let n = match self.stream.try_read_buf(&mut buf) {
710 Ok(0) => {
712 self.closed = true;
713 return Err(StreamError::Closed);
714 }
715 Ok(n) => n,
716
717 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
720
721 Err(e) => {
722 self.closed = true;
723 return Err(StreamError::LastOperationFailed(e.into()));
724 }
725 };
726
727 buf.truncate(n);
728 Ok(buf.freeze())
729 }
730
731 fn shutdown(&mut self) {
732 native_shutdown(&self.stream, Shutdown::Read);
733 self.closed = true;
734 }
735
736 async fn ready(&mut self) {
737 if self.closed {
738 return;
739 }
740
741 self.stream.readable().await.unwrap();
742 }
743}
744
745struct TcpReadStream(Arc<Mutex<TcpReader>>);
746
747#[async_trait::async_trait]
748impl InputStream for TcpReadStream {
749 fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
750 try_lock_for_stream(&self.0)?.read(size)
751 }
752}
753
754#[async_trait::async_trait]
755impl Pollable for TcpReadStream {
756 async fn ready(&mut self) {
757 self.0.lock().await.ready().await
758 }
759}
760
/// Byte budget reported by `TcpWriter::check_write` when the socket is
/// writable and no background write is in flight (1 GiB).
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;

/// Write half of a connected TCP stream.
struct TcpWriter {
    /// Connected stream, shared with the reader and the owning socket.
    stream: Arc<tokio::net::TcpStream>,
    /// Progress of pending writes and shutdown.
    state: WriteState,
}

/// State machine of a [`TcpWriter`].
enum WriteState {
    /// Idle; a new `write` is permitted.
    Ready,
    /// A spawned task is writing bytes the synchronous write could not send.
    Writing(AbortOnDropJoinHandle<io::Result<()>>),
    /// A spawned task finishes an in-flight background write and then shuts
    /// down the write side.
    Closing(AbortOnDropJoinHandle<io::Result<()>>),
    /// Write side finished: shut down, broken pipe, cancelled, or a stored
    /// error has already been reported.
    Closed,
    /// A background task failed; reported by the next `check_write`.
    Error(io::Error),
}
775
impl TcpWriter {
    /// Creates an idle writer over `stream`.
    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
        Self {
            stream,
            state: WriteState::Ready,
        }
    }

    /// Non-blocking `try_write` that normalizes one platform-specific error.
    /// On Windows, `ESHUTDOWN` is reported as `BrokenPipe`, which `write`
    /// treats as a closed stream.
    fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
        stream.try_write(buf).map_err(|error| {
            match Errno::from_io_error(&error) {
                #[cfg(windows)]
                Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),

                _ => error,
            }
        })
    }

    /// Spawns a task that writes all of `bytes`, waiting for writability
    /// between attempts, and moves the writer into `Writing`.
    ///
    /// # Panics
    /// If the writer is not in the `Ready` state.
    fn background_write(&mut self, mut bytes: bytes::Bytes) {
        assert!(matches!(self.state, WriteState::Ready));

        let stream = self.stream.clone();
        self.state = WriteState::Writing(crate::runtime::spawn(async move {
            while !bytes.is_empty() {
                stream.writable().await?;
                match Self::try_write_portable(&stream, &bytes) {
                    Ok(n) => {
                        // Drop the prefix that was written.
                        let _ = bytes.split_to(n);
                    }
                    // Not actually writable after all; wait again.
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                    Err(e) => return Err(e.into()),
                }
            }

            Ok(())
        }));
    }

    /// Writes `bytes`. The caller must have been permitted by a prior
    /// `check_write`; writing in any state other than `Ready`/`Closed` traps.
    ///
    /// Bytes are written synchronously for as long as the socket accepts
    /// them. If it would block, the remainder is handed to a background task
    /// and `Ok(())` is returned. A broken pipe closes the writer and reports
    /// `Closed`; other errors are returned as `LastOperationFailed`.
    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
        match self.state {
            WriteState::Ready => {}
            WriteState::Closed => return Err(StreamError::Closed),
            WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
                return Err(StreamError::Trap(anyhow::anyhow!(
                    "unpermitted: must call check_write first"
                )));
            }
        }
        while !bytes.is_empty() {
            match Self::try_write_portable(&self.stream, &bytes) {
                Ok(n) => {
                    let _ = bytes.split_to(n);
                }

                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // Finish the rest asynchronously; `check_write` reports 0
                    // until the background task is joined by `ready`.
                    self.background_write(bytes);

                    return Ok(());
                }

                Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
                    self.state = WriteState::Closed;
                    return Err(StreamError::Closed);
                }

                Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
            }
        }

        Ok(())
    }

    /// No internal buffer is kept, so flushing is a no-op that only fails
    /// once the writer is `Closed`.
    fn flush(&mut self) -> Result<(), StreamError> {
        match self.state {
            WriteState::Ready
            | WriteState::Writing(_)
            | WriteState::Closing(_)
            | WriteState::Error(_) => Ok(()),
            WriteState::Closed => Err(StreamError::Closed),
        }
    }

    /// Reports how many bytes may be written now.
    ///
    /// Returns 0 while a background write or close is in flight. `Closed`
    /// yields `StreamError::Closed`. A stored background error is returned
    /// once as `LastOperationFailed`, after which the writer is `Closed`.
    /// When idle, writability is polled once without blocking: 0 if not
    /// writable, otherwise `SOCKET_READY_SIZE`.
    fn check_write(&mut self) -> Result<usize, StreamError> {
        match mem::replace(&mut self.state, WriteState::Closed) {
            WriteState::Writing(task) => {
                self.state = WriteState::Writing(task);
                return Ok(0);
            }
            WriteState::Closing(task) => {
                self.state = WriteState::Closing(task);
                return Ok(0);
            }
            WriteState::Ready => {
                self.state = WriteState::Ready;
            }
            WriteState::Closed => return Err(StreamError::Closed),
            WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
        }

        // `poll_noop` appears to yield `None` while the future is pending.
        let writable = self.stream.writable();
        futures::pin_mut!(writable);
        if crate::runtime::poll_noop(writable).is_none() {
            return Ok(0);
        }
        Ok(SOCKET_READY_SIZE)
    }

    /// Shuts down the write side.
    ///
    /// When idle, `shutdown(Write)` is issued immediately and the writer
    /// becomes `Closed`. When a background write is in flight, a new task
    /// waits for it, issues the shutdown, and the writer enters `Closing`.
    /// In `Closing`, `Closed` and `Error` nothing changes.
    fn shutdown(&mut self) {
        self.state = match mem::replace(&mut self.state, WriteState::Closed) {
            WriteState::Ready => {
                native_shutdown(&self.stream, Shutdown::Write);
                WriteState::Closed
            }

            WriteState::Writing(write) => {
                let stream = self.stream.clone();
                WriteState::Closing(crate::runtime::spawn(async move {
                    let result = write.await;
                    native_shutdown(&stream, Shutdown::Write);
                    result
                }))
            }

            s => s,
        };
    }

    /// Cancels any background write/close task and waits for the
    /// cancellation. The writer ends up `Closed` whatever its previous state.
    async fn cancel(&mut self) {
        match mem::replace(&mut self.state, WriteState::Closed) {
            WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
            _ => {}
        }
    }

    /// Waits for any background task to finish, recording its outcome
    /// (`Ready`/`Closed` on success, `Error` on failure). Then, if idle,
    /// waits for the socket to become writable.
    ///
    /// Panics if tokio reports an error while waiting for writability.
    async fn ready(&mut self) {
        match &mut self.state {
            WriteState::Writing(task) => {
                self.state = match task.await {
                    Ok(()) => WriteState::Ready,
                    Err(e) => WriteState::Error(e),
                }
            }
            WriteState::Closing(task) => {
                self.state = match task.await {
                    Ok(()) => WriteState::Closed,
                    Err(e) => WriteState::Error(e),
                }
            }
            _ => {}
        }

        if let WriteState::Ready = self.state {
            self.stream.writable().await.unwrap();
        }
    }
}
950
951struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
952
953#[async_trait::async_trait]
954impl OutputStream for TcpWriteStream {
955 fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
956 try_lock_for_stream(&self.0)?.write(bytes)
957 }
958
959 fn flush(&mut self) -> Result<(), StreamError> {
960 try_lock_for_stream(&self.0)?.flush()
961 }
962
963 fn check_write(&mut self) -> Result<usize, StreamError> {
964 try_lock_for_stream(&self.0)?.check_write()
965 }
966
967 async fn cancel(&mut self) {
968 self.0.lock().await.cancel().await
969 }
970}
971
972#[async_trait::async_trait]
973impl Pollable for TcpWriteStream {
974 async fn ready(&mut self) {
975 self.0.lock().await.ready().await
976 }
977}
978
979fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
980 _ = stream
981 .as_socketlike_view::<std::net::TcpStream>()
982 .shutdown(how);
983}
984
985fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
986 mutex
987 .try_lock()
988 .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
989}
990
991fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, SocketError> {
992 mutex.try_lock().map_err(|_| {
993 SocketError::trap(anyhow::anyhow!(
994 "concurrent access to resource not supported"
995 ))
996 })
997}