1use alloc::boxed::Box;
2use alloc::collections::VecDeque;
3use alloc::vec::Vec;
4#[cfg(feature = "std")]
5use core::fmt::Debug;
6
7use crate::common_state::Side;
9use crate::crypto::cipher::{AeadKey, Iv};
10use crate::crypto::tls13::{Hkdf, HkdfExpander, OkmBlock};
11use crate::enums::AlertDescription;
12use crate::error::Error;
13use crate::tls13::Tls13CipherSuite;
14use crate::tls13::key_schedule::{
15 hkdf_expand_label, hkdf_expand_label_aead_key, hkdf_expand_label_block,
16};
17
18#[cfg(feature = "std")]
19mod connection {
20 use alloc::vec::Vec;
21 use core::fmt::{self, Debug};
22 use core::ops::{Deref, DerefMut};
23
24 use pki_types::{DnsName, ServerName};
25
26 use super::{DirectionalKeys, KeyChange, Version};
27 use crate::client::{ClientConfig, ClientConnectionData};
28 use crate::common_state::{CommonState, DEFAULT_BUFFER_LIMIT, Protocol};
29 use crate::conn::{ConnectionCore, SideData};
30 use crate::enums::{AlertDescription, ContentType, ProtocolVersion};
31 use crate::error::Error;
32 use crate::msgs::base::Payload;
33 use crate::msgs::deframer::buffers::{DeframerVecBuffer, Locator};
34 use crate::msgs::handshake::{
35 ClientExtensionsInput, ServerExtensionsInput, TransportParameters,
36 };
37 use crate::msgs::message::InboundPlainMessage;
38 use crate::server::{ServerConfig, ServerConnectionData};
39 use crate::sync::Arc;
40 use crate::vecbuf::ChunkVecBuffer;
41
    /// A QUIC connection, either the client side or the server side.
    ///
    /// Most accessors on this type simply dispatch to the wrapped
    /// [`ClientConnection`] or [`ServerConnection`].
    // NOTE(review): the lint allowance presumably reflects that there will only
    // ever be these two sides, so exhaustive matching by callers is intended.
    #[allow(clippy::exhaustive_enums)]
    #[derive(Debug)]
    pub enum Connection {
        /// The client side of a QUIC connection.
        Client(ClientConnection),
        /// The server side of a QUIC connection.
        Server(ServerConnection),
    }
51
52 impl Connection {
53 pub fn quic_transport_parameters(&self) -> Option<&[u8]> {
57 match self {
58 Self::Client(conn) => conn.quic_transport_parameters(),
59 Self::Server(conn) => conn.quic_transport_parameters(),
60 }
61 }
62
63 pub fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
65 match self {
66 Self::Client(conn) => conn.zero_rtt_keys(),
67 Self::Server(conn) => conn.zero_rtt_keys(),
68 }
69 }
70
71 pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
75 match self {
76 Self::Client(conn) => conn.read_hs(plaintext),
77 Self::Server(conn) => conn.read_hs(plaintext),
78 }
79 }
80
81 pub fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
85 match self {
86 Self::Client(conn) => conn.write_hs(buf),
87 Self::Server(conn) => conn.write_hs(buf),
88 }
89 }
90
91 pub fn alert(&self) -> Option<AlertDescription> {
95 match self {
96 Self::Client(conn) => conn.alert(),
97 Self::Server(conn) => conn.alert(),
98 }
99 }
100
101 #[inline]
117 pub fn export_keying_material<T: AsMut<[u8]>>(
118 &self,
119 output: T,
120 label: &[u8],
121 context: Option<&[u8]>,
122 ) -> Result<T, Error> {
123 match self {
124 Self::Client(conn) => conn
125 .core
126 .export_keying_material(output, label, context),
127 Self::Server(conn) => conn
128 .core
129 .export_keying_material(output, label, context),
130 }
131 }
132 }
133
134 impl Deref for Connection {
135 type Target = CommonState;
136
137 fn deref(&self) -> &Self::Target {
138 match self {
139 Self::Client(conn) => &conn.core.common_state,
140 Self::Server(conn) => &conn.core.common_state,
141 }
142 }
143 }
144
145 impl DerefMut for Connection {
146 fn deref_mut(&mut self) -> &mut Self::Target {
147 match self {
148 Self::Client(conn) => &mut conn.core.common_state,
149 Self::Server(conn) => &mut conn.core.common_state,
150 }
151 }
152 }
153
154 pub struct ClientConnection {
156 inner: ConnectionCommon<ClientConnectionData>,
157 }
158
159 impl ClientConnection {
160 pub fn new(
165 config: Arc<ClientConfig>,
166 quic_version: Version,
167 name: ServerName<'static>,
168 params: Vec<u8>,
169 ) -> Result<Self, Error> {
170 Self::new_with_alpn(
171 config.clone(),
172 quic_version,
173 name,
174 params,
175 config.alpn_protocols.clone(),
176 )
177 }
178
179 pub fn new_with_alpn(
181 config: Arc<ClientConfig>,
182 quic_version: Version,
183 name: ServerName<'static>,
184 params: Vec<u8>,
185 alpn_protocols: Vec<Vec<u8>>,
186 ) -> Result<Self, Error> {
187 if !config.supports_version(ProtocolVersion::TLSv1_3) {
188 return Err(Error::General(
189 "TLS 1.3 support is required for QUIC".into(),
190 ));
191 }
192
193 if !config.supports_protocol(Protocol::Quic) {
194 return Err(Error::General(
195 "at least one ciphersuite must support QUIC".into(),
196 ));
197 }
198
199 let exts = ClientExtensionsInput {
200 transport_parameters: Some(match quic_version {
201 Version::V1 | Version::V2 => TransportParameters::Quic(Payload::new(params)),
202 }),
203
204 ..ClientExtensionsInput::from_alpn(alpn_protocols)
205 };
206
207 let mut inner = ConnectionCore::for_client(config, name, exts, Protocol::Quic)?;
208 inner.common_state.quic.version = quic_version;
209 Ok(Self {
210 inner: inner.into(),
211 })
212 }
213
214 pub fn is_early_data_accepted(&self) -> bool {
220 self.inner.core.is_early_data_accepted()
221 }
222
223 pub fn tls13_tickets_received(&self) -> u32 {
225 self.inner.tls13_tickets_received
226 }
227 }
228
229 impl Deref for ClientConnection {
230 type Target = ConnectionCommon<ClientConnectionData>;
231
232 fn deref(&self) -> &Self::Target {
233 &self.inner
234 }
235 }
236
237 impl DerefMut for ClientConnection {
238 fn deref_mut(&mut self) -> &mut Self::Target {
239 &mut self.inner
240 }
241 }
242
243 impl Debug for ClientConnection {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 f.debug_struct("quic::ClientConnection")
246 .finish()
247 }
248 }
249
250 impl From<ClientConnection> for Connection {
251 fn from(c: ClientConnection) -> Self {
252 Self::Client(c)
253 }
254 }
255
256 pub struct ServerConnection {
258 inner: ConnectionCommon<ServerConnectionData>,
259 }
260
261 impl ServerConnection {
262 pub fn new(
267 config: Arc<ServerConfig>,
268 quic_version: Version,
269 params: Vec<u8>,
270 ) -> Result<Self, Error> {
271 if !config.supports_version(ProtocolVersion::TLSv1_3) {
272 return Err(Error::General(
273 "TLS 1.3 support is required for QUIC".into(),
274 ));
275 }
276
277 if !config.supports_protocol(Protocol::Quic) {
278 return Err(Error::General(
279 "at least one ciphersuite must support QUIC".into(),
280 ));
281 }
282
283 if config.max_early_data_size != 0 && config.max_early_data_size != 0xffff_ffff {
284 return Err(Error::General(
285 "QUIC sessions must set a max early data of 0 or 2^32-1".into(),
286 ));
287 }
288
289 let exts = ServerExtensionsInput {
290 transport_parameters: Some(match quic_version {
291 Version::V1 | Version::V2 => TransportParameters::Quic(Payload::new(params)),
292 }),
293 };
294
295 let mut core = ConnectionCore::for_server(config, exts)?;
296 core.common_state.protocol = Protocol::Quic;
297 core.common_state.quic.version = quic_version;
298 Ok(Self { inner: core.into() })
299 }
300
301 pub fn reject_early_data(&mut self) {
307 self.inner.core.reject_early_data()
308 }
309
310 pub fn server_name(&self) -> Option<&DnsName<'_>> {
326 self.inner.core.data.sni.as_ref()
327 }
328 }
329
330 impl Deref for ServerConnection {
331 type Target = ConnectionCommon<ServerConnectionData>;
332
333 fn deref(&self) -> &Self::Target {
334 &self.inner
335 }
336 }
337
338 impl DerefMut for ServerConnection {
339 fn deref_mut(&mut self) -> &mut Self::Target {
340 &mut self.inner
341 }
342 }
343
344 impl Debug for ServerConnection {
345 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346 f.debug_struct("quic::ServerConnection")
347 .finish()
348 }
349 }
350
351 impl From<ServerConnection> for Connection {
352 fn from(c: ServerConnection) -> Self {
353 Self::Server(c)
354 }
355 }
356
357 pub struct ConnectionCommon<Data> {
359 core: ConnectionCore<Data>,
360 deframer_buffer: DeframerVecBuffer,
361 sendable_plaintext: ChunkVecBuffer,
362 }
363
364 impl<Data: SideData> ConnectionCommon<Data> {
365 pub fn quic_transport_parameters(&self) -> Option<&[u8]> {
373 self.core
374 .common_state
375 .quic
376 .params
377 .as_ref()
378 .map(|v| v.as_ref())
379 }
380
381 pub fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
383 let suite = self
384 .core
385 .common_state
386 .suite
387 .and_then(|suite| suite.tls13())?;
388 Some(DirectionalKeys::new(
389 suite,
390 suite.quic?,
391 self.core
392 .common_state
393 .quic
394 .early_secret
395 .as_ref()?,
396 self.core.common_state.quic.version,
397 ))
398 }
399
400 pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
404 let range = self.deframer_buffer.extend(plaintext);
405
406 self.core.hs_deframer.input_message(
407 InboundPlainMessage {
408 typ: ContentType::Handshake,
409 version: ProtocolVersion::TLSv1_3,
410 payload: &self.deframer_buffer.filled()[range.clone()],
411 },
412 &Locator::new(self.deframer_buffer.filled()),
413 range.end,
414 );
415
416 self.core
417 .hs_deframer
418 .coalesce(self.deframer_buffer.filled_mut())?;
419
420 self.core
421 .process_new_packets(&mut self.deframer_buffer, &mut self.sendable_plaintext)?;
422
423 Ok(())
424 }
425
426 pub fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
430 self.core
431 .common_state
432 .quic
433 .write_hs(buf)
434 }
435
436 pub fn alert(&self) -> Option<AlertDescription> {
440 self.core.common_state.quic.alert
441 }
442 }
443
444 impl<Data> Deref for ConnectionCommon<Data> {
445 type Target = CommonState;
446
447 fn deref(&self) -> &Self::Target {
448 &self.core.common_state
449 }
450 }
451
452 impl<Data> DerefMut for ConnectionCommon<Data> {
453 fn deref_mut(&mut self) -> &mut Self::Target {
454 &mut self.core.common_state
455 }
456 }
457
458 impl<Data> From<ConnectionCore<Data>> for ConnectionCommon<Data> {
459 fn from(core: ConnectionCore<Data>) -> Self {
460 Self {
461 core,
462 deframer_buffer: DeframerVecBuffer::default(),
463 sendable_plaintext: ChunkVecBuffer::new(Some(DEFAULT_BUFFER_LIMIT)),
464 }
465 }
466 }
467}
468
469#[cfg(feature = "std")]
470pub use connection::{ClientConnection, Connection, ConnectionCommon, ServerConnection};
471
/// QUIC-specific handshake state, stored in the common connection state and
/// shared by both client and server connections.
#[derive(Default)]
pub(crate) struct Quic {
    /// Transport parameters returned by `quic_transport_parameters()`.
    // NOTE(review): presumably the peer's parameters, set during the
    // handshake; confirm against the code that populates this field.
    pub(crate) params: Option<Vec<u8>>,
    /// Alert returned by `alert()`.
    pub(crate) alert: Option<AlertDescription>,
    /// Outgoing handshake messages, drained in order by `write_hs`.
    /// `write_hs` stops before an entry flagged `true` while `hs_secrets` is
    /// still pending. This suggests the flag marks the first message to be sent
    /// under the new handshake keys.
    pub(crate) hs_queue: VecDeque<(bool, Vec<u8>)>,
    /// Secret from which `zero_rtt_keys()` derives 0-RTT keys.
    pub(crate) early_secret: Option<OkmBlock>,
    /// Handshake secrets not yet returned as a `KeyChange::Handshake`.
    pub(crate) hs_secrets: Option<Secrets>,
    /// Traffic secrets not yet returned as a `KeyChange::OneRtt`.
    pub(crate) traffic_secrets: Option<Secrets>,
    /// Set once `write_hs` has returned the 1-RTT keys derived from
    /// `traffic_secrets`; it never returns them a second time.
    #[cfg(feature = "std")]
    pub(crate) returned_traffic_keys: bool,
    /// QUIC version; selects the salts and labels used for key derivation.
    pub(crate) version: Version,
}
486
#[cfg(feature = "std")]
impl Quic {
    /// Drains queued handshake messages into `buf` and reports any key change
    /// the caller must apply.
    ///
    /// Draining stops before a message flagged `true` while handshake
    /// secrets are still pending. This lets the caller install the handshake
    /// keys before writing the rest.
    ///
    /// At most one [`KeyChange`] is returned per call. Handshake keys take
    /// priority. 1-RTT keys are returned only once, together with secrets
    /// already advanced for the next key update.
    pub(crate) fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
        while let Some((_, message)) = self.hs_queue.pop_front() {
            buf.extend_from_slice(&message);
            let next_needs_new_keys = matches!(self.hs_queue.front(), Some(&(true, _)));
            if next_needs_new_keys && self.hs_secrets.is_some() {
                // Give the caller a chance to switch keys before continuing.
                break;
            }
        }

        if let Some(handshake) = self.hs_secrets.take() {
            return Some(KeyChange::Handshake {
                keys: Keys::new(&handshake),
            });
        }

        // Traffic secrets are consumed even if their keys were already
        // handed out.
        let mut traffic = self.traffic_secrets.take()?;
        if self.returned_traffic_keys {
            return None;
        }
        self.returned_traffic_keys = true;

        let keys = Keys::new(&traffic);
        traffic.update();
        Some(KeyChange::OneRtt {
            keys,
            next: traffic,
        })
    }
}
521
522#[derive(Clone)]
524pub struct Secrets {
525 pub(crate) client: OkmBlock,
527 pub(crate) server: OkmBlock,
529 suite: &'static Tls13CipherSuite,
531 quic: &'static dyn Algorithm,
532 side: Side,
533 version: Version,
534}
535
536impl Secrets {
537 pub(crate) fn new(
538 client: OkmBlock,
539 server: OkmBlock,
540 suite: &'static Tls13CipherSuite,
541 quic: &'static dyn Algorithm,
542 side: Side,
543 version: Version,
544 ) -> Self {
545 Self {
546 client,
547 server,
548 suite,
549 quic,
550 side,
551 version,
552 }
553 }
554
555 pub fn next_packet_keys(&mut self) -> PacketKeySet {
557 let keys = PacketKeySet::new(self);
558 self.update();
559 keys
560 }
561
562 pub(crate) fn update(&mut self) {
563 self.client = hkdf_expand_label_block(
564 self.suite
565 .hkdf_provider
566 .expander_for_okm(&self.client)
567 .as_ref(),
568 self.version.key_update_label(),
569 &[],
570 );
571 self.server = hkdf_expand_label_block(
572 self.suite
573 .hkdf_provider
574 .expander_for_okm(&self.server)
575 .as_ref(),
576 self.version.key_update_label(),
577 &[],
578 );
579 }
580
581 fn local_remote(&self) -> (&OkmBlock, &OkmBlock) {
582 match self.side {
583 Side::Client => (&self.client, &self.server),
584 Side::Server => (&self.server, &self.client),
585 }
586 }
587}
588
589#[allow(clippy::exhaustive_structs)]
591pub struct DirectionalKeys {
592 pub header: Box<dyn HeaderProtectionKey>,
594 pub packet: Box<dyn PacketKey>,
596}
597
598impl DirectionalKeys {
599 pub(crate) fn new(
600 suite: &'static Tls13CipherSuite,
601 quic: &'static dyn Algorithm,
602 secret: &OkmBlock,
603 version: Version,
604 ) -> Self {
605 let builder = KeyBuilder::new(secret, version, quic, suite.hkdf_provider);
606 Self {
607 header: builder.header_protection_key(),
608 packet: builder.packet_key(),
609 }
610 }
611}
612
/// Length in bytes of a [`Tag`].
const TAG_LEN: usize = 16;

/// Authentication tag returned by `PacketKey::encrypt_in_place`.
pub struct Tag([u8; TAG_LEN]);

impl From<&[u8]> for Tag {
    /// Copies a tag out of `value`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not exactly `TAG_LEN` (16) bytes long.
    fn from(value: &[u8]) -> Self {
        let mut tag = Self([0; TAG_LEN]);
        tag.0.copy_from_slice(value);
        tag
    }
}

impl AsRef<[u8]> for Tag {
    /// Borrows the raw tag bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
632
/// A QUIC packet-protection algorithm. It builds packet-protection and
/// header-protection keys from HKDF-derived key material.
pub trait Algorithm: Send + Sync {
    /// Builds a packet-protection key from `key` and `iv`.
    fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn PacketKey>;

    /// Builds a header-protection key from `key`.
    fn header_protection_key(&self, key: AeadKey) -> Box<dyn HeaderProtectionKey>;

    /// Key length in bytes. Used when expanding both the packet key and the
    /// header-protection key.
    fn aead_key_len(&self) -> usize;

    /// Reports whether the implementation claims FIPS approval. The default
    /// implementation returns `false`.
    fn fips(&self) -> bool {
        false
    }
}
656
/// A key that applies and removes QUIC header protection.
// NOTE(review): presumably implements the scheme of RFC 9001 §5.4, where a
// mask computed from `sample` is applied to the first byte and the packet
// number bytes; confirm against implementations.
pub trait HeaderProtectionKey: Send + Sync {
    /// Protects `first` and `packet_number` in place, using `sample` as the
    /// mask input.
    ///
    /// # Errors
    ///
    /// Returns an error when protection cannot be applied. The exact
    /// conditions (e.g. a wrong-length `sample`) are implementation-defined.
    fn encrypt_in_place(
        &self,
        sample: &[u8],
        first: &mut u8,
        packet_number: &mut [u8],
    ) -> Result<(), Error>;

    /// Removes protection from `first` and `packet_number` in place, using
    /// `sample` as the mask input.
    ///
    /// # Errors
    ///
    /// Returns an error when protection cannot be removed. The exact
    /// conditions are implementation-defined.
    fn decrypt_in_place(
        &self,
        sample: &[u8],
        first: &mut u8,
        packet_number: &mut [u8],
    ) -> Result<(), Error>;

    /// Number of sample bytes this key expects.
    fn sample_len(&self) -> usize;
}
717
/// A key that protects and unprotects QUIC packet payloads.
pub trait PacketKey: Send + Sync {
    /// Encrypts `payload` in place for packet `packet_number` and returns
    /// the authentication [`Tag`].
    ///
    /// `header` is presumably authenticated as associated data (confirm
    /// against implementations). `path_id` is presumably the multipath QUIC
    /// path identifier, `None` for single-path use. TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if encryption fails.
    fn encrypt_in_place(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &mut [u8],
        path_id: Option<u32>,
    ) -> Result<Tag, Error>;

    /// Decrypts `payload` in place for packet `packet_number`. Returns the
    /// plaintext, borrowed from `payload`.
    ///
    /// # Errors
    ///
    /// Returns an error if decryption or authentication fails.
    fn decrypt_in_place<'a>(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &'a mut [u8],
        path_id: Option<u32>,
    ) -> Result<&'a [u8], Error>;

    /// Length in bytes of the tag produced by `encrypt_in_place`.
    fn tag_len(&self) -> usize;

    /// Maximum number of packets that may be encrypted with this key.
    // NOTE(review): presumably the AEAD confidentiality limit of RFC 9001
    // §6.6; confirm.
    fn confidentiality_limit(&self) -> u64;

    /// Maximum number of failed decryptions tolerated with this key.
    // NOTE(review): presumably the AEAD integrity limit of RFC 9001 §6.6;
    // confirm.
    fn integrity_limit(&self) -> u64;
}
781
782#[allow(clippy::exhaustive_structs)]
784pub struct PacketKeySet {
785 pub local: Box<dyn PacketKey>,
787 pub remote: Box<dyn PacketKey>,
789}
790
791impl PacketKeySet {
792 fn new(secrets: &Secrets) -> Self {
793 let (local, remote) = secrets.local_remote();
794 let (version, alg, hkdf) = (secrets.version, secrets.quic, secrets.suite.hkdf_provider);
795 Self {
796 local: KeyBuilder::new(local, version, alg, hkdf).packet_key(),
797 remote: KeyBuilder::new(remote, version, alg, hkdf).packet_key(),
798 }
799 }
800}
801
802pub(crate) struct KeyBuilder<'a> {
803 expander: Box<dyn HkdfExpander>,
804 version: Version,
805 alg: &'a dyn Algorithm,
806}
807
808impl<'a> KeyBuilder<'a> {
809 pub(crate) fn new(
810 secret: &OkmBlock,
811 version: Version,
812 alg: &'a dyn Algorithm,
813 hkdf: &'a dyn Hkdf,
814 ) -> Self {
815 Self {
816 expander: hkdf.expander_for_okm(secret),
817 version,
818 alg,
819 }
820 }
821
822 pub(crate) fn packet_key(&self) -> Box<dyn PacketKey> {
824 let aead_key_len = self.alg.aead_key_len();
825 let packet_key = hkdf_expand_label_aead_key(
826 self.expander.as_ref(),
827 aead_key_len,
828 self.version.packet_key_label(),
829 &[],
830 );
831
832 let packet_iv =
833 hkdf_expand_label(self.expander.as_ref(), self.version.packet_iv_label(), &[]);
834 self.alg
835 .packet_key(packet_key, packet_iv)
836 }
837
838 pub(crate) fn header_protection_key(&self) -> Box<dyn HeaderProtectionKey> {
840 let header_key = hkdf_expand_label_aead_key(
841 self.expander.as_ref(),
842 self.alg.aead_key_len(),
843 self.version.header_key_label(),
844 &[],
845 );
846 self.alg
847 .header_protection_key(header_key)
848 }
849}
850
851#[non_exhaustive]
853#[derive(Clone, Copy)]
854pub struct Suite {
855 pub suite: &'static Tls13CipherSuite,
857 pub quic: &'static dyn Algorithm,
859}
860
861impl Suite {
862 pub fn keys(&self, client_dst_connection_id: &[u8], side: Side, version: Version) -> Keys {
864 Keys::initial(
865 version,
866 self.suite,
867 self.quic,
868 client_dst_connection_id,
869 side,
870 )
871 }
872}
873
874#[allow(clippy::exhaustive_structs)]
876pub struct Keys {
877 pub local: DirectionalKeys,
879 pub remote: DirectionalKeys,
881}
882
883impl Keys {
884 pub fn initial(
886 version: Version,
887 suite: &'static Tls13CipherSuite,
888 quic: &'static dyn Algorithm,
889 client_dst_connection_id: &[u8],
890 side: Side,
891 ) -> Self {
892 const CLIENT_LABEL: &[u8] = b"client in";
893 const SERVER_LABEL: &[u8] = b"server in";
894 let salt = version.initial_salt();
895 let hs_secret = suite
896 .hkdf_provider
897 .extract_from_secret(Some(salt), client_dst_connection_id);
898
899 let secrets = Secrets {
900 version,
901 client: hkdf_expand_label_block(hs_secret.as_ref(), CLIENT_LABEL, &[]),
902 server: hkdf_expand_label_block(hs_secret.as_ref(), SERVER_LABEL, &[]),
903 suite,
904 quic,
905 side,
906 };
907 Self::new(&secrets)
908 }
909
910 fn new(secrets: &Secrets) -> Self {
911 let (local, remote) = secrets.local_remote();
912 Self {
913 local: DirectionalKeys::new(secrets.suite, secrets.quic, local, secrets.version),
914 remote: DirectionalKeys::new(secrets.suite, secrets.quic, remote, secrets.version),
915 }
916 }
917}
918
/// A change of packet-protection keys, reported by `write_hs` once new
/// secrets become available.
// NOTE(review): the lint allowance presumably reflects that callers are
// expected to match exhaustively; confirm.
#[allow(clippy::exhaustive_enums)]
pub enum KeyChange {
    /// Keys derived from the handshake secrets.
    Handshake {
        /// Handshake keys for both directions.
        keys: Keys,
    },
    /// Keys derived from the traffic secrets.
    OneRtt {
        /// 1-RTT keys for both directions.
        keys: Keys,
        /// The traffic secrets, already advanced once via `Secrets::update`.
        /// Later key generations can be derived from them.
        next: Secrets,
    },
}
947
/// A QUIC version. It determines the salt and labels used in key derivation.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum Version {
    /// QUIC version 1 (RFC 9001).
    V1,
    /// QUIC version 2 (RFC 9369).
    V2,
}

impl Version {
    /// Salt used to extract the Initial secret.
    fn initial_salt(self) -> &'static [u8; 20] {
        static V1_SALT: [u8; 20] = [
            0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
            0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
        ];
        static V2_SALT: [u8; 20] = [
            0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26,
            0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9,
        ];
        match self {
            Self::V1 => &V1_SALT,
            Self::V2 => &V2_SALT,
        }
    }

    /// HKDF label for the packet-protection key.
    pub(crate) fn packet_key_label(&self) -> &'static [u8] {
        match *self {
            Self::V1 => b"quic key",
            Self::V2 => b"quicv2 key",
        }
    }

    /// HKDF label for the packet-protection IV.
    pub(crate) fn packet_iv_label(&self) -> &'static [u8] {
        match *self {
            Self::V1 => b"quic iv",
            Self::V2 => b"quicv2 iv",
        }
    }

    /// HKDF label for the header-protection key.
    pub(crate) fn header_key_label(&self) -> &'static [u8] {
        match *self {
            Self::V1 => b"quic hp",
            Self::V2 => b"quicv2 hp",
        }
    }

    /// HKDF label used to derive the next generation of secrets on key update.
    fn key_update_label(&self) -> &'static [u8] {
        match *self {
            Self::V1 => b"quic ku",
            Self::V2 => b"quicv2 ku",
        }
    }
}

impl Default for Version {
    /// Defaults to QUIC version 1.
    fn default() -> Self {
        Self::V1
    }
}
1013
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;

    use super::{HeaderProtectionKey, PacketKey};

    /// Boxed key trait objects must stay `Send + Sync`.
    #[test]
    fn auto_traits() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Box<dyn HeaderProtectionKey>>();
        require_send_sync::<Box<dyn PacketKey>>();
    }
}