1use std::{
2 collections::VecDeque,
3 net::{IpAddr, SocketAddr},
4 pin::{Pin, pin},
5 sync::{Arc, Mutex, MutexGuard},
6 task::{Context, Poll, Waker},
7 time::{Duration, Instant},
8};
9
10use compio_buf::{BufResult, bytes::Bytes};
11use compio_log::{Instrument, error};
12use compio_runtime::JoinHandle;
13use flume::{Receiver, Sender};
14use futures_util::{
15 Future, FutureExt, StreamExt,
16 future::{self, Fuse, FusedFuture, LocalBoxFuture},
17 select, stream,
18};
19#[cfg(rustls)]
20use quinn_proto::crypto::rustls::HandshakeData;
21use quinn_proto::{
22 ConnectionHandle, ConnectionStats, Dir, EndpointEvent, StreamEvent, StreamId, VarInt,
23 congestion::Controller,
24};
25use rustc_hash::FxHashMap as HashMap;
26use thiserror::Error;
27
28use crate::{RecvStream, SendStream, Socket};
29
/// Events delivered to a connection's worker through its `events_rx` channel.
#[derive(Debug)]
pub(crate) enum ConnectionEvent {
    /// Request to close the connection with the given error code and reason;
    /// the worker forwards both to `ConnectionState::close`.
    Close(VarInt, Bytes),
    /// A protocol-level event that the worker hands to
    /// `quinn_proto::Connection::handle_event`.
    Proto(quinn_proto::ConnectionEvent),
}
35
/// Mutable per-connection state, kept behind the mutex in `ConnectionInner`.
///
/// Besides the protocol state machine it stores the wakers of every task that
/// is currently parked waiting for something to happen on the connection.
#[derive(Debug)]
pub(crate) struct ConnectionState {
    /// The underlying `quinn_proto` state machine.
    pub(crate) conn: quinn_proto::Connection,
    /// Reason the connection was terminated; `None` while it is still usable.
    pub(crate) error: Option<ConnectionError>,
    /// Set when the `Connected` event is observed, cleared again on termination.
    connected: bool,
    /// Join handle of the spawned worker task (taken by `Connection::closed`).
    worker: Option<JoinHandle<()>>,
    /// Waker of the worker's poll loop; taking and waking it makes the worker
    /// run another iteration.
    poller: Option<Waker>,
    /// Waker of the task waiting for the connection to become connected.
    on_connected: Option<Waker>,
    /// Waker of the task waiting for handshake data to become available.
    on_handshake_data: Option<Waker>,
    /// Tasks waiting for an incoming datagram.
    datagram_received: VecDeque<Waker>,
    /// Tasks waiting for room to send a datagram.
    datagrams_unblocked: VecDeque<Waker>,
    /// Tasks waiting to accept a peer-opened stream, indexed by `Dir as usize`.
    stream_opened: [VecDeque<Waker>; 2],
    /// Tasks waiting to open a new stream, indexed by `Dir as usize`.
    stream_available: [VecDeque<Waker>; 2],
    /// Per-stream wakers, woken on the `Writable` and `Stopped` stream events.
    pub(crate) writable: HashMap<StreamId, Waker>,
    /// Per-stream wakers, woken on the `Readable` stream event.
    pub(crate) readable: HashMap<StreamId, Waker>,
    /// Per-stream wakers, woken on the `Finished` and `Stopped` stream events.
    pub(crate) stopped: HashMap<StreamId, Waker>,
}
53
54impl ConnectionState {
55 fn terminate(&mut self, reason: ConnectionError) {
56 self.error = Some(reason);
57 self.connected = false;
58
59 if let Some(waker) = self.on_handshake_data.take() {
60 waker.wake()
61 }
62 if let Some(waker) = self.on_connected.take() {
63 waker.wake()
64 }
65 self.datagram_received.drain(..).for_each(Waker::wake);
66 self.datagrams_unblocked.drain(..).for_each(Waker::wake);
67 for e in &mut self.stream_opened {
68 e.drain(..).for_each(Waker::wake);
69 }
70 for e in &mut self.stream_available {
71 e.drain(..).for_each(Waker::wake);
72 }
73 wake_all_streams(&mut self.writable);
74 wake_all_streams(&mut self.readable);
75 wake_all_streams(&mut self.stopped);
76 }
77
78 fn close(&mut self, error_code: VarInt, reason: Bytes) {
79 self.conn.close(Instant::now(), error_code, reason);
80 self.terminate(ConnectionError::LocallyClosed);
81 self.wake();
82 }
83
84 pub(crate) fn wake(&mut self) {
85 if let Some(waker) = self.poller.take() {
86 waker.wake()
87 }
88 }
89
90 #[cfg(rustls)]
91 fn handshake_data(&self) -> Option<Box<HandshakeData>> {
92 self.conn
93 .crypto_session()
94 .handshake_data()
95 .map(|data| data.downcast::<HandshakeData>().unwrap())
96 }
97
98 pub(crate) fn check_0rtt(&self) -> bool {
99 self.conn.side().is_server() || self.conn.is_handshaking() || self.conn.accepted_0rtt()
100 }
101}
102
103fn wake_stream(stream: StreamId, wakers: &mut HashMap<StreamId, Waker>) {
104 if let Some(waker) = wakers.remove(&stream) {
105 waker.wake();
106 }
107}
108
109fn wake_all_streams(wakers: &mut HashMap<StreamId, Waker>) {
110 wakers.drain().for_each(|(_, waker)| waker.wake())
111}
112
/// Shared core of a connection, referenced by the user-facing handles and by
/// the worker task that drives it.
#[derive(Debug)]
pub(crate) struct ConnectionInner {
    /// Protocol state plus parked wakers.
    state: Mutex<ConnectionState>,
    /// Identifier sent along with every endpoint event on `events_tx`.
    handle: ConnectionHandle,
    /// Socket the worker transmits on.
    socket: Socket,
    /// Channel on which the worker reports endpoint events.
    events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
    /// Channel on which the worker receives [`ConnectionEvent`]s.
    events_rx: Receiver<ConnectionEvent>,
}
121
122fn implicit_close(this: &Arc<ConnectionInner>) {
123 if Arc::strong_count(this) == 2 {
124 this.state().close(0u32.into(), Bytes::new())
125 }
126}
127
impl ConnectionInner {
    /// Builds the shared connection core with no wakers registered, no error
    /// recorded and no worker attached yet.
    fn new(
        handle: ConnectionHandle,
        conn: quinn_proto::Connection,
        socket: Socket,
        events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
        events_rx: Receiver<ConnectionEvent>,
    ) -> Self {
        Self {
            state: Mutex::new(ConnectionState {
                conn,
                connected: false,
                error: None,
                worker: None,
                poller: None,
                on_connected: None,
                on_handshake_data: None,
                datagram_received: VecDeque::new(),
                datagrams_unblocked: VecDeque::new(),
                stream_opened: [VecDeque::new(), VecDeque::new()],
                stream_available: [VecDeque::new(), VecDeque::new()],
                writable: HashMap::default(),
                readable: HashMap::default(),
                stopped: HashMap::default(),
            }),
            handle,
            socket,
            events_tx,
            events_rx,
        }
    }

    /// Locks and returns the connection state.
    ///
    /// Panics if the mutex is poisoned.
    #[inline]
    pub(crate) fn state(&self) -> MutexGuard<'_, ConnectionState> {
        self.state.lock().unwrap()
    }

    /// Locks the connection state, returning a clone of the recorded error
    /// instead of the guard if the connection has already been terminated.
    #[inline]
    pub(crate) fn try_state(&self) -> Result<MutexGuard<'_, ConnectionState>, ConnectionError> {
        let state = self.state();
        if let Some(error) = &state.error {
            Err(error.clone())
        } else {
            Ok(state)
        }
    }

    /// Worker loop: drives the protocol state machine until the connection is
    /// drained.
    ///
    /// Each iteration waits for one of four triggers (an explicit wake-up, the
    /// protocol timer, incoming connection events, or completion of the
    /// in-flight transmit), then transmits if possible, re-arms the timer,
    /// forwards endpoint events and dispatches application-facing events to
    /// the parked wakers.
    async fn run(&self) {
        // Yields whenever `state.poller` is empty, i.e. on the first poll and
        // after `ConnectionState::wake` has taken the waker. The current waker
        // is (re-)stored on every poll so the next `wake` reaches this task.
        let mut poller = stream::poll_fn(|cx| {
            let mut state = self.state();
            let ready = state.poller.is_none();
            match &state.poller {
                Some(waker) if waker.will_wake(cx.waker()) => {}
                _ => state.poller = Some(cx.waker().clone()),
            };
            if ready {
                Poll::Ready(Some(()))
            } else {
                Poll::Pending
            }
        })
        .fuse();

        let mut timer = Timer::new();
        // Handle events in batches of up to 100 under a single lock.
        let mut event_stream = self.events_rx.stream().ready_chunks(100);
        // `send_buf` is `Some` while no transmit is in flight; the buffer is
        // moved into the transmit future and handed back when it completes.
        let mut send_buf = Some(Vec::with_capacity(self.state().conn.current_mtu() as usize));
        let mut transmit_fut = pin!(Fuse::terminated());

        loop {
            // Every arm ends by locking the state; the guard is held for the
            // rest of the iteration and released before the next `select!`.
            let mut state = select! {
                _ = poller.select_next_some() => self.state(),
                _ = timer => {
                    // Disarm first; the timer is re-armed from `poll_timeout` below.
                    timer.reset(None);
                    let mut state = self.state();
                    state.conn.handle_timeout(Instant::now());
                    state
                }
                events = event_stream.select_next_some() => {
                    let mut state = self.state();
                    for event in events {
                        match event {
                            ConnectionEvent::Close(error_code, reason) => state.close(error_code, reason),
                            ConnectionEvent::Proto(event) => state.conn.handle_event(event),
                        }
                    }
                    state
                },
                BufResult::<(), Vec<u8>>(res, mut buf) = transmit_fut => {
                    // Send failures are only logged; the loop keeps running.
                    #[allow(unused)]
                    if let Err(e) = res {
                        error!("I/O error: {}", e);
                    }
                    // Recycle the buffer for the next transmit.
                    buf.clear();
                    send_buf = Some(buf);
                    self.state()
                },
            };

            // Only one transmit is in flight at a time: a new one is started
            // only when the buffer is back in `send_buf`.
            if let Some(mut buf) = send_buf.take() {
                if let Some(transmit) = state.conn.poll_transmit(
                    Instant::now(),
                    self.socket.max_gso_segments(),
                    &mut buf,
                ) {
                    transmit_fut.set(async move { self.socket.send(buf, &transmit).await }.fuse())
                } else {
                    send_buf = Some(buf);
                }
            }

            timer.reset(state.conn.poll_timeout());

            // A send error on the endpoint channel is deliberately ignored.
            while let Some(event) = state.conn.poll_endpoint_events() {
                let _ = self.events_tx.send((self.handle, event));
            }

            // Translate protocol events into wake-ups of the parked tasks.
            while let Some(event) = state.conn.poll() {
                use quinn_proto::Event::*;
                match event {
                    HandshakeDataReady => {
                        if let Some(waker) = state.on_handshake_data.take() {
                            waker.wake()
                        }
                    }
                    Connected => {
                        state.connected = true;
                        if let Some(waker) = state.on_connected.take() {
                            waker.wake()
                        }
                        // On a client whose 0-RTT was not accepted, wake every
                        // stream task so it can re-examine its state.
                        if state.conn.side().is_client() && !state.conn.accepted_0rtt() {
                            wake_all_streams(&mut state.writable);
                            wake_all_streams(&mut state.readable);
                            wake_all_streams(&mut state.stopped);
                        }
                    }
                    ConnectionLost { reason } => state.terminate(reason.into()),
                    Stream(StreamEvent::Readable { id }) => wake_stream(id, &mut state.readable),
                    Stream(StreamEvent::Writable { id }) => wake_stream(id, &mut state.writable),
                    Stream(StreamEvent::Finished { id }) => wake_stream(id, &mut state.stopped),
                    Stream(StreamEvent::Stopped { id, .. }) => {
                        wake_stream(id, &mut state.stopped);
                        wake_stream(id, &mut state.writable);
                    }
                    Stream(StreamEvent::Available { dir }) => state.stream_available[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    Stream(StreamEvent::Opened { dir }) => state.stream_opened[dir as usize]
                        .drain(..)
                        .for_each(Waker::wake),
                    DatagramReceived => state.datagram_received.drain(..).for_each(Waker::wake),
                    DatagramsUnblocked => state.datagrams_unblocked.drain(..).for_each(Waker::wake),
                }
            }

            // Nothing left to do once the protocol reports the connection drained.
            if state.conn.is_drained() {
                break;
            }
        }
    }
}
290
// Generates the accessors shared by `Connecting` and `Connection`. Each one
// locks the connection state and forwards to the `quinn_proto` connection.
macro_rules! conn_fn {
    () => {
        /// Local IP address as reported by the protocol connection, if known.
        pub fn local_ip(&self) -> Option<IpAddr> {
            self.0.state().conn.local_ip()
        }

        /// Current address of the peer as reported by the protocol connection.
        pub fn remote_address(&self) -> SocketAddr {
            self.0.state().conn.remote_address()
        }

        /// Round-trip time as reported by the protocol connection.
        pub fn rtt(&self) -> Duration {
            self.0.state().conn.rtt()
        }

        /// Snapshot of the connection statistics.
        pub fn stats(&self) -> ConnectionStats {
            self.0.state().conn.stats()
        }

        /// A boxed clone of the current congestion controller state.
        pub fn congestion_state(&self) -> Box<dyn Controller> {
            self.0.state().conn.congestion_state().clone_box()
        }

        /// Peer identity from the crypto session, as a certificate chain.
        ///
        /// Returns `None` when the session reports no identity.
        ///
        /// # Panics
        ///
        /// Panics if the identity is not a
        /// `Vec<rustls::pki_types::CertificateDer<'static>>`.
        pub fn peer_identity(
            &self,
        ) -> Option<Box<Vec<rustls::pki_types::CertificateDer<'static>>>> {
            self.0
                .state()
                .conn
                .crypto_session()
                .peer_identity()
                .map(|v| v.downcast().unwrap())
        }

        /// Fills `output` with keying material exported from the crypto
        /// session for the given `label` and `context`.
        ///
        /// # Errors
        ///
        /// Returns whatever error the crypto session's
        /// `export_keying_material` reports.
        pub fn export_keying_material(
            &self,
            output: &mut [u8],
            label: &[u8],
            context: &[u8],
        ) -> Result<(), quinn_proto::crypto::ExportKeyingMaterialError> {
            self.0
                .state()
                .conn
                .crypto_session()
                .export_keying_material(output, label, context)
        }
    };
}
365
/// A connection whose handshake is still in progress.
///
/// Awaiting it yields a [`Connection`] once connected, or the
/// [`ConnectionError`] that terminated the attempt.
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting(Arc<ConnectionInner>);
370
371impl Connecting {
372 conn_fn!();
373
374 pub(crate) fn new(
375 handle: ConnectionHandle,
376 conn: quinn_proto::Connection,
377 socket: Socket,
378 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
379 events_rx: Receiver<ConnectionEvent>,
380 ) -> Self {
381 let inner = Arc::new(ConnectionInner::new(
382 handle, conn, socket, events_tx, events_rx,
383 ));
384 let worker = compio_runtime::spawn({
385 let inner = inner.clone();
386 async move { inner.run().await }.in_current_span()
387 });
388 inner.state().worker = Some(worker);
389 Self(inner)
390 }
391
392 #[cfg(rustls)]
394 pub async fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
395 future::poll_fn(|cx| {
396 let mut state = self.0.try_state()?;
397 if let Some(data) = state.handshake_data() {
398 return Poll::Ready(Ok(data));
399 }
400
401 match &state.on_handshake_data {
402 Some(waker) if waker.will_wake(cx.waker()) => {}
403 _ => state.on_handshake_data = Some(cx.waker().clone()),
404 }
405
406 Poll::Pending
407 })
408 .await
409 }
410
411 pub fn into_0rtt(self) -> Result<Connection, Self> {
459 let is_ok = {
460 let state = self.0.state();
461 state.conn.has_0rtt() || state.conn.side().is_server()
462 };
463 if is_ok {
464 Ok(Connection(self.0.clone()))
465 } else {
466 Err(self)
467 }
468 }
469}
470
471impl Future for Connecting {
472 type Output = Result<Connection, ConnectionError>;
473
474 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
475 let mut state = self.0.try_state()?;
476
477 if state.connected {
478 return Poll::Ready(Ok(Connection(self.0.clone())));
479 }
480
481 match &state.on_connected {
482 Some(waker) if waker.will_wake(cx.waker()) => {}
483 _ => state.on_connected = Some(cx.waker().clone()),
484 }
485
486 Poll::Pending
487 }
488}
489
impl Drop for Connecting {
    /// Delegates to `implicit_close`, which closes the connection when the
    /// remaining reference count says no other handle is left.
    fn drop(&mut self) {
        implicit_close(&self.0)
    }
}
495
/// Handle to a QUIC connection.
///
/// Cloning is cheap: all clones share the same underlying connection.
#[derive(Debug, Clone)]
pub struct Connection(Arc<ConnectionInner>);
499
500impl Connection {
501 conn_fn!();
502
503 #[cfg(rustls)]
505 pub fn handshake_data(&mut self) -> Result<Box<HandshakeData>, ConnectionError> {
506 Ok(self.0.try_state()?.handshake_data().unwrap())
507 }
508
509 pub fn max_datagram_size(&self) -> Option<usize> {
522 self.0.state().conn.datagrams().max_size()
523 }
524
525 pub fn datagram_send_buffer_space(&self) -> usize {
531 self.0.state().conn.datagrams().send_buffer_space()
532 }
533
534 pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
541 let mut state = self.0.state();
542 state.conn.set_max_concurrent_streams(Dir::Uni, count);
543 state.wake();
545 }
546
547 pub fn set_receive_window(&self, receive_window: VarInt) {
549 let mut state = self.0.state();
550 state.conn.set_receive_window(receive_window);
551 state.wake();
552 }
553
554 pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
561 let mut state = self.0.state();
562 state.conn.set_max_concurrent_streams(Dir::Bi, count);
563 state.wake();
565 }
566
567 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
604 self.0
605 .state()
606 .close(error_code, Bytes::copy_from_slice(reason));
607 }
608
609 pub async fn closed(&self) -> ConnectionError {
611 let worker = self.0.state().worker.take();
612 if let Some(worker) = worker {
613 let _ = worker.await;
614 }
615
616 self.0.try_state().unwrap_err()
617 }
618
619 pub fn close_reason(&self) -> Option<ConnectionError> {
623 self.0.try_state().err()
624 }
625
626 fn poll_recv_datagram(&self, cx: &mut Context) -> Poll<Result<Bytes, ConnectionError>> {
627 let mut state = self.0.try_state()?;
628 if let Some(bytes) = state.conn.datagrams().recv() {
629 return Poll::Ready(Ok(bytes));
630 }
631 state.datagram_received.push_back(cx.waker().clone());
632 Poll::Pending
633 }
634
635 pub async fn recv_datagram(&self) -> Result<Bytes, ConnectionError> {
637 future::poll_fn(|cx| self.poll_recv_datagram(cx)).await
638 }
639
640 fn try_send_datagram(
641 &self,
642 cx: Option<&mut Context>,
643 data: Bytes,
644 ) -> Result<(), Result<SendDatagramError, Bytes>> {
645 use quinn_proto::SendDatagramError::*;
646 let mut state = self.0.try_state().map_err(|e| Ok(e.into()))?;
647 state
648 .conn
649 .datagrams()
650 .send(data, cx.is_none())
651 .map_err(|err| match err {
652 UnsupportedByPeer => Ok(SendDatagramError::UnsupportedByPeer),
653 Disabled => Ok(SendDatagramError::Disabled),
654 TooLarge => Ok(SendDatagramError::TooLarge),
655 Blocked(data) => {
656 state
657 .datagrams_unblocked
658 .push_back(cx.unwrap().waker().clone());
659 Err(data)
660 }
661 })?;
662 state.wake();
663 Ok(())
664 }
665
666 pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
672 self.try_send_datagram(None, data).map_err(Result::unwrap)
673 }
674
675 pub async fn send_datagram_wait(&self, data: Bytes) -> Result<(), SendDatagramError> {
685 let mut data = Some(data);
686 future::poll_fn(
687 |cx| match self.try_send_datagram(Some(cx), data.take().unwrap()) {
688 Ok(()) => Poll::Ready(Ok(())),
689 Err(Ok(e)) => Poll::Ready(Err(e)),
690 Err(Err(b)) => {
691 data.replace(b);
692 Poll::Pending
693 }
694 },
695 )
696 .await
697 }
698
699 fn poll_open_stream(
700 &self,
701 cx: Option<&mut Context>,
702 dir: Dir,
703 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
704 let mut state = self.0.try_state()?;
705 if let Some(stream) = state.conn.streams().open(dir) {
706 Poll::Ready(Ok((
707 stream,
708 state.conn.side().is_client() && state.conn.is_handshaking(),
709 )))
710 } else {
711 if let Some(cx) = cx {
712 state.stream_available[dir as usize].push_back(cx.waker().clone());
713 }
714 Poll::Pending
715 }
716 }
717
718 pub fn open_uni(&self) -> Result<SendStream, OpenStreamError> {
724 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Uni)? {
725 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
726 } else {
727 Err(OpenStreamError::StreamsExhausted)
728 }
729 }
730
731 pub async fn open_uni_wait(&self) -> Result<SendStream, ConnectionError> {
740 let (stream, is_0rtt) =
741 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Uni)).await?;
742 Ok(SendStream::new(self.0.clone(), stream, is_0rtt))
743 }
744
745 pub fn open_bi(&self) -> Result<(SendStream, RecvStream), OpenStreamError> {
751 if let Poll::Ready((stream, is_0rtt)) = self.poll_open_stream(None, Dir::Bi)? {
752 Ok((
753 SendStream::new(self.0.clone(), stream, is_0rtt),
754 RecvStream::new(self.0.clone(), stream, is_0rtt),
755 ))
756 } else {
757 Err(OpenStreamError::StreamsExhausted)
758 }
759 }
760
761 pub async fn open_bi_wait(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
770 let (stream, is_0rtt) =
771 future::poll_fn(|cx| self.poll_open_stream(Some(cx), Dir::Bi)).await?;
772 Ok((
773 SendStream::new(self.0.clone(), stream, is_0rtt),
774 RecvStream::new(self.0.clone(), stream, is_0rtt),
775 ))
776 }
777
778 fn poll_accept_stream(
779 &self,
780 cx: &mut Context,
781 dir: Dir,
782 ) -> Poll<Result<(StreamId, bool), ConnectionError>> {
783 let mut state = self.0.try_state()?;
784 if let Some(stream) = state.conn.streams().accept(dir) {
785 state.wake();
786 Poll::Ready(Ok((stream, state.conn.is_handshaking())))
787 } else {
788 state.stream_opened[dir as usize].push_back(cx.waker().clone());
789 Poll::Pending
790 }
791 }
792
793 pub async fn accept_uni(&self) -> Result<RecvStream, ConnectionError> {
795 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Uni)).await?;
796 Ok(RecvStream::new(self.0.clone(), stream, is_0rtt))
797 }
798
799 pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
811 let (stream, is_0rtt) = future::poll_fn(|cx| self.poll_accept_stream(cx, Dir::Bi)).await?;
812 Ok((
813 SendStream::new(self.0.clone(), stream, is_0rtt),
814 RecvStream::new(self.0.clone(), stream, is_0rtt),
815 ))
816 }
817
818 pub async fn accepted_0rtt(&self) -> Result<bool, ConnectionError> {
823 future::poll_fn(|cx| {
824 let mut state = self.0.try_state()?;
825
826 if state.connected {
827 return Poll::Ready(Ok(state.conn.accepted_0rtt()));
828 }
829
830 match &state.on_connected {
831 Some(waker) if waker.will_wake(cx.waker()) => {}
832 _ => state.on_connected = Some(cx.waker().clone()),
833 }
834
835 Poll::Pending
836 })
837 .await
838 }
839}
840
/// Two handles compare equal when they share the same underlying
/// `ConnectionInner` allocation.
impl PartialEq for Connection {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}

impl Eq for Connection {}
848
impl Drop for Connection {
    /// Delegates to `implicit_close`, which closes the connection when the
    /// remaining reference count says no other handle is left.
    fn drop(&mut self) {
        implicit_close(&self.0)
    }
}
854
/// Resettable one-shot timer used by the connection worker.
struct Timer {
    /// Deadline passed to the most recent `reset`, if any.
    deadline: Option<Instant>,
    /// Sleep future for `deadline`; terminated while the timer is disarmed.
    fut: Fuse<LocalBoxFuture<'static, ()>>,
}
859
860impl Timer {
861 fn new() -> Self {
862 Self {
863 deadline: None,
864 fut: Fuse::terminated(),
865 }
866 }
867
868 fn reset(&mut self, deadline: Option<Instant>) {
869 if let Some(deadline) = deadline {
870 if self.deadline.is_none() || self.deadline != Some(deadline) {
871 self.fut = compio_runtime::time::sleep_until(deadline)
872 .boxed_local()
873 .fuse();
874 }
875 } else {
876 self.fut = Fuse::terminated();
877 }
878 self.deadline = deadline;
879 }
880}
881
impl Future for Timer {
    type Output = ();

    /// Polls the inner (fused) sleep future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.fut.poll_unpin(cx)
    }
}
889
impl FusedFuture for Timer {
    /// Reports whether the inner future is terminated, which is the case
    /// while the timer is disarmed or after it has fired.
    fn is_terminated(&self) -> bool {
        self.fut.is_terminated()
    }
}
895
/// Reasons why a connection might be lost.
///
/// Mirrors `quinn_proto::ConnectionError` variant for variant (see the
/// `From` conversion).
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
    /// The peer doesn't implement any supported version.
    #[error("peer doesn't implement any supported version")]
    VersionMismatch,
    /// A transport-level error, reported transparently.
    #[error(transparent)]
    TransportError(#[from] quinn_proto::TransportError),
    /// The connection was aborted by the peer.
    #[error("aborted by peer: {0}")]
    ConnectionClosed(quinn_proto::ConnectionClose),
    /// The connection was closed by the peer's application.
    #[error("closed by peer: {0}")]
    ApplicationClosed(quinn_proto::ApplicationClose),
    /// The connection was reset by the peer.
    #[error("reset by peer")]
    Reset,
    /// The connection timed out.
    #[error("timed out")]
    TimedOut,
    /// The connection was closed locally.
    #[error("closed")]
    LocallyClosed,
    /// Connection IDs were exhausted.
    #[error("CIDs exhausted")]
    CidsExhausted,
}
935
936impl From<quinn_proto::ConnectionError> for ConnectionError {
937 fn from(value: quinn_proto::ConnectionError) -> Self {
938 use quinn_proto::ConnectionError::*;
939
940 match value {
941 VersionMismatch => ConnectionError::VersionMismatch,
942 TransportError(e) => ConnectionError::TransportError(e),
943 ConnectionClosed(e) => ConnectionError::ConnectionClosed(e),
944 ApplicationClosed(e) => ConnectionError::ApplicationClosed(e),
945 Reset => ConnectionError::Reset,
946 TimedOut => ConnectionError::TimedOut,
947 LocallyClosed => ConnectionError::LocallyClosed,
948 CidsExhausted => ConnectionError::CidsExhausted,
949 }
950 }
951}
952
/// Errors that can arise when sending a datagram.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support datagrams.
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled.
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is too large to be sent.
    #[error("datagram too large")]
    TooLarge,
    /// The connection was lost.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}
972
/// Errors that can arise when opening a stream without waiting.
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum OpenStreamError {
    /// The connection was lost.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// No further stream can be opened at the moment.
    #[error("streams exhausted")]
    StreamsExhausted,
}
983
984#[cfg(feature = "h3")]
985pub(crate) mod h3_impl {
986 use compio_buf::bytes::Buf;
987 use futures_util::ready;
988 use h3::{
989 error::Code,
990 quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, WriteBuf},
991 };
992 use h3_datagram::{
993 datagram::EncodedDatagram,
994 quic_traits::{
995 DatagramConnectionExt, RecvDatagram, SendDatagram, SendDatagramErrorIncoming,
996 },
997 };
998
999 use super::*;
1000 use crate::send_stream::h3_impl::SendStream;
1001
1002 impl From<ConnectionError> for ConnectionErrorIncoming {
1003 fn from(e: ConnectionError) -> Self {
1004 use ConnectionError::*;
1005 match e {
1006 ApplicationClosed(e) => Self::ApplicationClose {
1007 error_code: e.error_code.into_inner(),
1008 },
1009 TimedOut => Self::Timeout,
1010
1011 e => Self::Undefined(Arc::new(e)),
1012 }
1013 }
1014 }
1015
    impl From<ConnectionError> for StreamErrorIncoming {
        /// Wraps the error as a connection-level failure of the stream.
        fn from(e: ConnectionError) -> Self {
            Self::ConnectionErrorIncoming {
                connection_error: e.into(),
            }
        }
    }
1023
1024 impl From<SendDatagramError> for SendDatagramErrorIncoming {
1025 fn from(e: SendDatagramError) -> Self {
1026 use SendDatagramError::*;
1027 match e {
1028 UnsupportedByPeer | Disabled => Self::NotAvailable,
1029 TooLarge => Self::TooLarge,
1030 ConnectionLost(e) => Self::ConnectionError(e.into()),
1031 }
1032 }
1033 }
1034
1035 impl<B> SendDatagram<B> for Connection
1036 where
1037 B: Buf,
1038 {
1039 fn send_datagram<T: Into<EncodedDatagram<B>>>(
1040 &mut self,
1041 data: T,
1042 ) -> Result<(), SendDatagramErrorIncoming> {
1043 let mut buf: EncodedDatagram<B> = data.into();
1044 let buf = buf.copy_to_bytes(buf.remaining());
1045 Ok(Connection::send_datagram(self, buf)?)
1046 }
1047 }
1048
1049 impl RecvDatagram for Connection {
1050 type Buffer = Bytes;
1051
1052 fn poll_incoming_datagram(
1053 &mut self,
1054 cx: &mut core::task::Context<'_>,
1055 ) -> Poll<Result<Self::Buffer, ConnectionErrorIncoming>> {
1056 Poll::Ready(Ok(ready!(self.poll_recv_datagram(cx))?))
1057 }
1058 }
1059
    impl<B: Buf> DatagramConnectionExt<B> for Connection {
        type RecvDatagramHandler = Self;
        type SendDatagramHandler = Self;

        /// The connection handle itself sends datagrams; returns a clone.
        fn send_datagram_handler(&self) -> Self::SendDatagramHandler {
            self.clone()
        }

        /// The connection handle itself receives datagrams; returns a clone.
        fn recv_datagram_handler(&self) -> Self::RecvDatagramHandler {
            self.clone()
        }
    }
1072
    /// Bidirectional stream for `h3`: a send half and a receive half that
    /// `BidiStream::new` creates for the same stream id.
    pub struct BidiStream<B> {
        /// Sending half.
        send: SendStream<B>,
        /// Receiving half.
        recv: RecvStream,
    }
1078
1079 impl<B> BidiStream<B> {
1080 pub(crate) fn new(conn: Arc<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
1081 Self {
1082 send: SendStream::new(conn.clone(), stream, is_0rtt),
1083 recv: RecvStream::new(conn, stream, is_0rtt),
1084 }
1085 }
1086 }
1087
1088 impl<B> quic::BidiStream<B> for BidiStream<B>
1089 where
1090 B: Buf,
1091 {
1092 type RecvStream = RecvStream;
1093 type SendStream = SendStream<B>;
1094
1095 fn split(self) -> (Self::SendStream, Self::RecvStream) {
1096 (self.send, self.recv)
1097 }
1098 }
1099
    // Receive side of `BidiStream`: every method forwards to the `recv` half.
    impl<B> quic::RecvStream for BidiStream<B>
    where
        B: Buf,
    {
        type Buf = Bytes;

        /// Forwards to the receiving half.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            self.recv.poll_data(cx)
        }

        /// Forwards to the receiving half.
        fn stop_sending(&mut self, error_code: u64) {
            self.recv.stop_sending(error_code)
        }

        /// Forwards to the receiving half.
        fn recv_id(&self) -> quic::StreamId {
            self.recv.recv_id()
        }
    }
1121
    // Send side of `BidiStream`: every method forwards to the `send` half.
    impl<B> quic::SendStream<B> for BidiStream<B>
    where
        B: Buf,
    {
        /// Forwards to the sending half.
        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_ready(cx)
        }

        /// Forwards to the sending half.
        fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
            self.send.send_data(data)
        }

        /// Forwards to the sending half.
        fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
            self.send.poll_finish(cx)
        }

        /// Forwards to the sending half.
        fn reset(&mut self, reset_code: u64) {
            self.send.reset(reset_code)
        }

        /// Forwards to the sending half.
        fn send_id(&self) -> quic::StreamId {
            self.send.send_id()
        }
    }
1146
    impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
    where
        B: Buf,
    {
        /// Forwards to the sending half.
        fn poll_send<D: Buf>(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut D,
        ) -> Poll<Result<usize, StreamErrorIncoming>> {
            self.send.poll_send(cx, buf)
        }
    }
1159
    /// Stream opener for `h3`, wrapping a clone of the [`Connection`] handle.
    #[derive(Clone)]
    pub struct OpenStreams(Connection);
1163
1164 impl<B> quic::OpenStreams<B> for OpenStreams
1165 where
1166 B: Buf,
1167 {
1168 type BidiStream = BidiStream<B>;
1169 type SendStream = SendStream<B>;
1170
1171 fn poll_open_bidi(
1172 &mut self,
1173 cx: &mut Context<'_>,
1174 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
1175 let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Bi))?;
1176 Poll::Ready(Ok(BidiStream::new(self.0.0.clone(), stream, is_0rtt)))
1177 }
1178
1179 fn poll_open_send(
1180 &mut self,
1181 cx: &mut Context<'_>,
1182 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
1183 let (stream, is_0rtt) = ready!(self.0.poll_open_stream(Some(cx), Dir::Uni))?;
1184 Poll::Ready(Ok(SendStream::new(self.0.0.clone(), stream, is_0rtt)))
1185 }
1186
1187 fn close(&mut self, code: Code, reason: &[u8]) {
1188 self.0
1189 .close(code.value().try_into().expect("invalid code"), reason)
1190 }
1191 }
1192
1193 impl<B> quic::OpenStreams<B> for Connection
1194 where
1195 B: Buf,
1196 {
1197 type BidiStream = BidiStream<B>;
1198 type SendStream = SendStream<B>;
1199
1200 fn poll_open_bidi(
1201 &mut self,
1202 cx: &mut Context<'_>,
1203 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
1204 let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Bi))?;
1205 Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
1206 }
1207
1208 fn poll_open_send(
1209 &mut self,
1210 cx: &mut Context<'_>,
1211 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
1212 let (stream, is_0rtt) = ready!(self.poll_open_stream(Some(cx), Dir::Uni))?;
1213 Poll::Ready(Ok(SendStream::new(self.0.clone(), stream, is_0rtt)))
1214 }
1215
1216 fn close(&mut self, code: Code, reason: &[u8]) {
1217 Connection::close(self, code.value().try_into().expect("invalid code"), reason)
1218 }
1219 }
1220
1221 impl<B> quic::Connection<B> for Connection
1222 where
1223 B: Buf,
1224 {
1225 type OpenStreams = OpenStreams;
1226 type RecvStream = RecvStream;
1227
1228 fn poll_accept_recv(
1229 &mut self,
1230 cx: &mut std::task::Context<'_>,
1231 ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
1232 let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Uni))?;
1233 Poll::Ready(Ok(RecvStream::new(self.0.clone(), stream, is_0rtt)))
1234 }
1235
1236 fn poll_accept_bidi(
1237 &mut self,
1238 cx: &mut std::task::Context<'_>,
1239 ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
1240 let (stream, is_0rtt) = ready!(self.poll_accept_stream(cx, Dir::Bi))?;
1241 Poll::Ready(Ok(BidiStream::new(self.0.clone(), stream, is_0rtt)))
1242 }
1243
1244 fn opener(&self) -> Self::OpenStreams {
1245 OpenStreams(self.clone())
1246 }
1247 }
1248}