Skip to main content

rustls/
quic.rs

1use alloc::boxed::Box;
2use alloc::collections::VecDeque;
3use alloc::vec::Vec;
4use core::fmt::{self, Debug};
5use core::mem;
6use core::ops::{Deref, DerefMut};
7
8use pki_types::{DnsName, FipsStatus, ServerName};
9
10use crate::client::{ClientConfig, ClientSide};
11pub use crate::common_state::Side;
12use crate::common_state::{CommonState, ConnectionOutputs, Protocol};
13use crate::conn::{ConnectionCore, KeyingMaterialExporter, SideData};
14use crate::crypto::cipher::{AeadKey, EncodedMessage, Iv, Payload};
15use crate::crypto::tls13::{Hkdf, HkdfExpander, OkmBlock};
16use crate::enums::{ApplicationProtocol, ContentType, ProtocolVersion};
17use crate::error::{ApiMisuse, Error};
18use crate::msgs::{
19    ClientExtensionsInput, Message, MessagePayload, ServerExtensionsInput, TransportParameters,
20    VecInput,
21};
22use crate::server::{ChooseConfig, ClientHello, ServerConfig, ServerSide, ServerState};
23use crate::suites::SupportedCipherSuite;
24use crate::sync::Arc;
25use crate::tls13::Tls13CipherSuite;
26use crate::tls13::key_schedule::{
27    hkdf_expand_label, hkdf_expand_label_aead_key, hkdf_expand_label_block,
28};
29
30/// A QUIC client or server connection.
31pub trait Connection: Debug + Deref<Target = ConnectionOutputs> {
32    /// Return the TLS-encoded transport parameters for the session's peer.
33    ///
34    /// While the transport parameters are technically available prior to the
35    /// completion of the handshake, they cannot be fully trusted until the
36    /// handshake completes, and reliance on them should be minimized.
37    /// However, any tampering with the parameters will cause the handshake
38    /// to fail.
39    fn quic_transport_parameters(&self) -> Option<&[u8]>;
40
41    /// Compute the keys for encrypting/decrypting 0-RTT packets, if available
42    fn zero_rtt_keys(&self) -> Option<DirectionalKeys>;
43
44    /// Consume unencrypted TLS handshake data.
45    ///
46    /// Handshake data obtained from separate encryption levels should be supplied in separate calls.
47    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error>;
48
49    /// Emit unencrypted TLS handshake data.
50    ///
51    /// When this returns `Some(_)`, the new keys must be used for future handshake data.
52    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange>;
53
54    /// Returns true if the connection is currently performing the TLS handshake.
55    fn is_handshaking(&self) -> bool;
56}
57
58/// A QUIC client connection.
59pub struct ClientConnection {
60    inner: ConnectionCommon<ClientSide>,
61}
62
63impl ClientConnection {
64    /// Make a new QUIC ClientConnection.
65    ///
66    /// This differs from `ClientConnection::new()` in that it takes an extra `params` argument,
67    /// which contains the TLS-encoded transport parameters to send.
68    pub fn new(
69        config: Arc<ClientConfig>,
70        quic_version: Version,
71        name: ServerName<'static>,
72        params: Vec<u8>,
73    ) -> Result<Self, Error> {
74        let alpn_protocols = config.alpn_protocols.clone();
75        Self::new_with_alpn(config, quic_version, name, params, alpn_protocols)
76    }
77
78    /// Make a new QUIC ClientConnection with custom ALPN protocols.
79    pub fn new_with_alpn(
80        config: Arc<ClientConfig>,
81        version: Version,
82        name: ServerName<'static>,
83        params: Vec<u8>,
84        alpn_protocols: Vec<ApplicationProtocol<'static>>,
85    ) -> Result<Self, Error> {
86        let suites = &config.provider().tls13_cipher_suites;
87        if suites.is_empty() {
88            return Err(ApiMisuse::QuicRequiresTls13Support.into());
89        }
90
91        if !suites
92            .iter()
93            .any(|scs| scs.quic.is_some())
94        {
95            return Err(ApiMisuse::NoQuicCompatibleCipherSuites.into());
96        }
97
98        let exts = ClientExtensionsInput {
99            transport_parameters: Some(match version {
100                Version::V1 | Version::V2 => TransportParameters::Quic(Payload::new(params)),
101            }),
102
103            ..ClientExtensionsInput::from_alpn(alpn_protocols)
104        };
105
106        let mut quic = Quic {
107            version,
108            ..Quic::default()
109        };
110
111        let inner = ConnectionCore::for_client(
112            config,
113            name,
114            exts,
115            Some(&mut quic),
116            Protocol::Quic(version),
117        )?;
118
119        Ok(Self {
120            inner: ConnectionCommon::new(inner, quic),
121        })
122    }
123
124    /// Return the FIPS validation status of the connection.
125    pub fn fips(&self) -> FipsStatus {
126        self.inner.fips
127    }
128
129    /// Returns True if the server signalled it will process early data.
130    ///
131    /// If you sent early data and this returns false at the end of the
132    /// handshake then the server will not process the data.  This
133    /// is not an error, but you may wish to resend the data.
134    pub fn is_early_data_accepted(&self) -> bool {
135        self.inner.core.is_early_data_accepted()
136    }
137
138    /// Returns the number of TLS1.3 tickets that have been received.
139    pub fn tls13_tickets_received(&self) -> u32 {
140        self.inner
141            .core
142            .common
143            .recv
144            .tls13_tickets_received
145    }
146
147    /// Returns an object that can derive key material from the agreed connection secrets.
148    ///
149    /// See [RFC5705][] for more details on what this is for.
150    ///
151    /// This function can be called at most once per connection.
152    ///
153    /// This function will error:
154    ///
155    /// - if called prior to the handshake completing; (check with
156    ///   [`CommonState::is_handshaking`] first).
157    /// - if called more than once per connection.
158    ///
159    /// [RFC5705]: https://datatracker.ietf.org/doc/html/rfc5705
160    pub fn exporter(&mut self) -> Result<KeyingMaterialExporter, Error> {
161        self.inner.core.exporter()
162    }
163}
164
165impl Connection for ClientConnection {
166    fn quic_transport_parameters(&self) -> Option<&[u8]> {
167        self.inner.quic_transport_parameters()
168    }
169
170    fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
171        self.inner.zero_rtt_keys()
172    }
173
174    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
175        self.inner.read_hs(plaintext)
176    }
177
178    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
179        self.inner.write_hs(buf)
180    }
181
182    fn is_handshaking(&self) -> bool {
183        self.inner.is_handshaking()
184    }
185}
186
187impl Deref for ClientConnection {
188    type Target = ConnectionOutputs;
189
190    fn deref(&self) -> &Self::Target {
191        &self.inner
192    }
193}
194
195impl Debug for ClientConnection {
196    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
197        f.debug_struct("quic::ClientConnection")
198            .finish_non_exhaustive()
199    }
200}
201
202/// A QUIC server connection.
203pub struct ServerConnection {
204    inner: ConnectionCommon<ServerSide>,
205}
206
207impl ServerConnection {
208    /// Make a new QUIC ServerConnection.
209    ///
210    /// This differs from `ServerConnection::new()` in that it takes an extra `params` argument,
211    /// which contains the TLS-encoded transport parameters to send.
212    pub fn new(
213        config: Arc<ServerConfig>,
214        version: Version,
215        params: Vec<u8>,
216    ) -> Result<Self, Error> {
217        check_server_config(&config)?;
218        let exts = ServerExtensionsInput {
219            transport_parameters: Some(match version {
220                Version::V1 | Version::V2 => TransportParameters::Quic(Payload::new(params)),
221            }),
222        };
223
224        let core = ConnectionCore::for_server(config, exts, Protocol::Quic(version))?;
225        let inner = ConnectionCommon::new(
226            core,
227            Quic {
228                version,
229                ..Quic::default()
230            },
231        );
232        Ok(Self { inner })
233    }
234
235    /// Return the FIPS validation status of the connection.
236    pub fn fips(&self) -> FipsStatus {
237        self.inner.fips
238    }
239
240    /// Retrieves the server name, if any, used to select the certificate and
241    /// private key.
242    ///
243    /// This returns `None` until some time after the client's server name indication
244    /// (SNI) extension value is processed during the handshake. It will never be
245    /// `None` when the connection is ready to send or process application data,
246    /// unless the client does not support SNI.
247    ///
248    /// This is useful for application protocols that need to enforce that the
249    /// server name matches an application layer protocol hostname. For
250    /// example, HTTP/1.1 servers commonly expect the `Host:` header field of
251    /// every request on a connection to match the hostname in the SNI extension
252    /// when the client provides the SNI extension.
253    ///
254    /// The server name is also used to match sessions during session resumption.
255    pub fn server_name(&self) -> Option<&DnsName<'_>> {
256        self.inner.core.side.server_name()
257    }
258
259    /// Set the resumption data to embed in future resumption tickets supplied to the client.
260    ///
261    /// Defaults to the empty byte string. Must be less than 2^15 bytes to allow room for other
262    /// data. Should be called while `is_handshaking` returns true to ensure all transmitted
263    /// resumption tickets are affected (otherwise an error will be returned).
264    ///
265    /// Integrity will be assured by rustls, but the data will be visible to the client. If secrecy
266    /// from the client is desired, encrypt the data separately.
267    pub fn set_resumption_data(&mut self, resumption_data: &[u8]) -> Result<(), Error> {
268        assert!(resumption_data.len() < 2usize.pow(15));
269        match &mut self.inner.core.state {
270            Ok(st) => st.set_resumption_data(resumption_data),
271            Err(e) => Err(e.clone()),
272        }
273    }
274
275    /// Retrieves the resumption data supplied by the client, if any.
276    ///
277    /// Returns `Some` if and only if a valid resumption ticket has been received from the client.
278    pub fn received_resumption_data(&self) -> Option<&[u8]> {
279        self.inner
280            .core
281            .side
282            .received_resumption_data()
283    }
284
285    /// Returns an object that can derive key material from the agreed connection secrets.
286    ///
287    /// See [RFC5705][] for more details on what this is for.
288    ///
289    /// This function can be called at most once per connection.
290    ///
291    /// This function will error:
292    ///
293    /// - if called prior to the handshake completing; (check with
294    ///   [`CommonState::is_handshaking`] first).
295    /// - if called more than once per connection.
296    ///
297    /// [RFC5705]: https://datatracker.ietf.org/doc/html/rfc5705
298    pub fn exporter(&mut self) -> Result<KeyingMaterialExporter, Error> {
299        self.inner.core.exporter()
300    }
301}
302
303impl Connection for ServerConnection {
304    fn quic_transport_parameters(&self) -> Option<&[u8]> {
305        self.inner.quic_transport_parameters()
306    }
307
308    fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
309        self.inner.zero_rtt_keys()
310    }
311
312    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
313        self.inner.read_hs(plaintext)
314    }
315
316    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
317        self.inner.write_hs(buf)
318    }
319
320    fn is_handshaking(&self) -> bool {
321        self.inner.is_handshaking()
322    }
323}
324
325impl Deref for ServerConnection {
326    type Target = ConnectionOutputs;
327
328    fn deref(&self) -> &Self::Target {
329        &self.inner
330    }
331}
332
333impl Debug for ServerConnection {
334    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335        f.debug_struct("quic::ServerConnection")
336            .finish_non_exhaustive()
337    }
338}
339
340/// A QUIC server-side acceptor.
341///
342/// `Acceptor` allows callers to choose a [`ServerConfig`] after reading the
343/// [`ClientHello`] of an incoming QUIC connection.
344pub struct Acceptor {
345    inner: Option<ConnectionCommon<ServerSide>>,
346}
347
348impl Acceptor {
349    /// Make a new QUIC acceptor.
350    pub fn new(version: Version) -> Self {
351        Self {
352            inner: Some(ConnectionCommon::new(
353                ConnectionCore::for_acceptor(Protocol::Quic(version)),
354                Quic {
355                    version,
356                    ..Quic::default()
357                },
358            )),
359        }
360    }
361
362    /// Consume unencrypted TLS handshake data.
363    ///
364    /// The plaintext should be ordered QUIC CRYPTO stream data for one encryption level.
365    ///
366    /// Handshake data obtained from separate encryption levels should be supplied in separate calls.
367    pub fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
368        match &mut self.inner {
369            Some(conn) => conn.read_hs(plaintext),
370            None => Err(ApiMisuse::AcceptorPolledAfterCompletion.into()),
371        }
372    }
373
374    /// Check if a `ClientHello` message has been received.
375    ///
376    /// Returns `Ok(None)` if the complete `ClientHello` has not yet been received.
377    /// Supply more handshake data with [`Acceptor::read_hs()`] and call this function again.
378    ///
379    /// Returns `Ok(Some(accepted))` if the connection has been accepted. Call
380    /// [`Accepted::into_connection()`] to continue. Do not call this function again.
381    pub fn accept(&mut self) -> Result<Option<Accepted>, Error> {
382        let Some(mut connection) = self.inner.take() else {
383            return Err(ApiMisuse::AcceptorPolledAfterCompletion.into());
384        };
385
386        const MISUSED: Error = Error::Unreachable("Accepted misused state");
387        let state = mem::replace(&mut connection.core.state, Err(MISUSED))?;
388
389        Ok(match state {
390            ServerState::ChooseConfig(choose_config) => Some(Accepted {
391                connection,
392                choose_config,
393            }),
394            state => {
395                connection.core.state = Ok(state);
396                self.inner = Some(connection);
397                None
398            }
399        })
400    }
401}
402
403/// Represents a `ClientHello` message received through the [`Acceptor`].
404///
405/// Contains the state required to resume the connection through
406/// [`Accepted::into_connection()`].
407pub struct Accepted {
408    // invariant: `connection.core.state` is `Err(_)` and requires restoring
409    connection: ConnectionCommon<ServerSide>,
410    choose_config: Box<ChooseConfig>,
411}
412
413impl Accepted {
414    /// Get the [`ClientHello`] for this connection.
415    pub fn client_hello(&self) -> ClientHello<'_> {
416        self.choose_config.client_hello()
417    }
418
419    /// Convert the [`Accepted`] into a [`ServerConnection`].
420    ///
421    /// Takes the state returned from [`Acceptor::accept()`], the [`ServerConfig`]
422    /// that should be used for the session, and the TLS-encoded QUIC transport
423    /// parameters to send. Returns an error if configuration-dependent validation
424    /// of the received `ClientHello` message fails.
425    pub fn into_connection(
426        mut self,
427        config: Arc<ServerConfig>,
428        params: Vec<u8>,
429    ) -> Result<ServerConnection, Error> {
430        check_server_config(&config)?;
431
432        self.connection.core.accepted(
433            self.choose_config,
434            ServerExtensionsInput {
435                transport_parameters: Some(match self.connection.quic.version {
436                    Version::V1 | Version::V2 => TransportParameters::Quic(Payload::new(params)),
437                }),
438            },
439            Some(&mut self.connection.quic),
440            config,
441        )?;
442
443        Ok(ServerConnection {
444            inner: self.connection,
445        })
446    }
447}
448
449impl Debug for Accepted {
450    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451        f.debug_struct("quic::Accepted")
452            .finish_non_exhaustive()
453    }
454}
455
456fn check_server_config(config: &ServerConfig) -> Result<(), Error> {
457    let suites = &config.provider.tls13_cipher_suites;
458    if suites.is_empty() {
459        return Err(ApiMisuse::QuicRequiresTls13Support.into());
460    }
461
462    if !suites
463        .iter()
464        .any(|scs| scs.quic.is_some())
465    {
466        return Err(ApiMisuse::NoQuicCompatibleCipherSuites.into());
467    }
468
469    if config.max_early_data_size != 0 && config.max_early_data_size != 0xffff_ffff {
470        return Err(ApiMisuse::QuicRestrictsMaxEarlyDataSize.into());
471    }
472
473    Ok(())
474}
475
476/// A shared interface for QUIC connections.
477struct ConnectionCommon<Side: SideData> {
478    core: ConnectionCore<Side>,
479    deframer_buffer: VecInput,
480    quic: Quic,
481}
482
483impl<Side: SideData> ConnectionCommon<Side> {
484    fn new(core: ConnectionCore<Side>, quic: Quic) -> Self {
485        Self {
486            core,
487            deframer_buffer: VecInput::default(),
488            quic,
489        }
490    }
491
492    fn quic_transport_parameters(&self) -> Option<&[u8]> {
493        self.quic
494            .params
495            .as_ref()
496            .map(|v| v.as_ref())
497    }
498
499    fn zero_rtt_keys(&self) -> Option<DirectionalKeys> {
500        let suite = self
501            .core
502            .common
503            .negotiated_cipher_suite()
504            .and_then(|suite| match suite {
505                SupportedCipherSuite::Tls13(suite) => Some(suite),
506                _ => None,
507            })?;
508
509        Some(DirectionalKeys::new(
510            suite,
511            suite.quic?,
512            self.quic.early_secret.as_ref()?,
513            self.quic.version,
514        ))
515    }
516
517    fn read_hs(&mut self, plaintext: &[u8]) -> Result<(), Error> {
518        let range = self.deframer_buffer.extend(plaintext);
519
520        let deframer = &mut self.core.common.recv.deframer;
521        deframer.add_processed(range.len());
522        deframer.input_message(
523            EncodedMessage {
524                typ: ContentType::Handshake,
525                version: ProtocolVersion::TLSv1_3,
526                payload: &self.deframer_buffer.filled()[range.start..range.end],
527            },
528            range,
529        );
530
531        self.core
532            .common
533            .recv
534            .deframer
535            .coalesce(self.deframer_buffer.filled_mut())?;
536
537        self.core
538            .process_new_packets(&mut self.deframer_buffer, Some(&mut self.quic))?;
539
540        Ok(())
541    }
542
543    fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
544        self.quic.write_hs(buf)
545    }
546}
547
548impl<Side: SideData> Deref for ConnectionCommon<Side> {
549    type Target = CommonState;
550
551    fn deref(&self) -> &Self::Target {
552        &self.core.common
553    }
554}
555
556impl<Side: SideData> DerefMut for ConnectionCommon<Side> {
557    fn deref_mut(&mut self) -> &mut Self::Target {
558        &mut self.core.common
559    }
560}
561
562#[derive(Default)]
563pub(crate) struct Quic {
564    pub(crate) version: Version,
565    /// QUIC transport parameters received from the peer during the handshake
566    pub(crate) params: Option<Vec<u8>>,
567    pub(crate) hs_queue: VecDeque<(bool, Vec<u8>)>,
568    pub(crate) early_secret: Option<OkmBlock>,
569    pub(crate) hs_secrets: Option<Secrets>,
570    pub(crate) traffic_secrets: Option<Secrets>,
571    /// Whether keys derived from traffic_secrets have been passed to the QUIC implementation
572    pub(crate) returned_traffic_keys: bool,
573}
574
575impl Quic {
576    pub(crate) fn send_msg(&mut self, m: Message<'_>, must_encrypt: bool) {
577        if let MessagePayload::Alert(_) = m.payload {
578            // alerts are sent out-of-band in QUIC mode
579            return;
580        }
581
582        debug_assert!(
583            matches!(
584                m.payload,
585                MessagePayload::Handshake { .. } | MessagePayload::HandshakeFlight(_)
586            ),
587            "QUIC uses TLS for the cryptographic handshake only"
588        );
589        let mut bytes = Vec::new();
590        m.payload.encode(&mut bytes);
591        self.hs_queue
592            .push_back((must_encrypt, bytes));
593    }
594
595    pub(crate) fn write_hs(&mut self, buf: &mut Vec<u8>) -> Option<KeyChange> {
596        while let Some((_, msg)) = self.hs_queue.pop_front() {
597            buf.extend_from_slice(&msg);
598            if let Some(&(true, _)) = self.hs_queue.front() {
599                if self.hs_secrets.is_some() {
600                    // Allow the caller to switch keys before proceeding.
601                    break;
602                }
603            }
604        }
605
606        if let Some(secrets) = self.hs_secrets.take() {
607            return Some(KeyChange::Handshake {
608                keys: Keys::new(&secrets),
609            });
610        }
611
612        if let Some(mut secrets) = self.traffic_secrets.take() {
613            if !self.returned_traffic_keys {
614                self.returned_traffic_keys = true;
615                let keys = Keys::new(&secrets);
616                secrets.update();
617                return Some(KeyChange::OneRtt {
618                    keys,
619                    next: secrets,
620                });
621            }
622        }
623
624        None
625    }
626}
627
628impl QuicOutput for Quic {
629    fn transport_parameters(&mut self, params: Vec<u8>) {
630        self.params = Some(params);
631    }
632
633    fn early_secret(&mut self, secret: Option<OkmBlock>) {
634        self.early_secret = secret;
635    }
636
637    fn handshake_secrets(
638        &mut self,
639        client_secret: OkmBlock,
640        server_secret: OkmBlock,
641        suite: &'static Tls13CipherSuite,
642        quic: &'static dyn Algorithm,
643        side: Side,
644    ) {
645        self.hs_secrets = Some(Secrets::new(
646            client_secret,
647            server_secret,
648            suite,
649            quic,
650            side,
651            self.version,
652        ));
653    }
654
655    fn traffic_secrets(
656        &mut self,
657        client_secret: OkmBlock,
658        server_secret: OkmBlock,
659        suite: &'static Tls13CipherSuite,
660        quic: &'static dyn Algorithm,
661        side: Side,
662    ) {
663        self.traffic_secrets = Some(Secrets::new(
664            client_secret,
665            server_secret,
666            suite,
667            quic,
668            side,
669            self.version,
670        ));
671    }
672
673    fn send_msg(&mut self, m: Message<'_>, must_encrypt: bool) {
674        self.send_msg(m, must_encrypt);
675    }
676}
677
678pub(crate) trait QuicOutput {
679    fn transport_parameters(&mut self, params: Vec<u8>);
680
681    fn early_secret(&mut self, secret: Option<OkmBlock>);
682
683    fn handshake_secrets(
684        &mut self,
685        client_secret: OkmBlock,
686        server_secret: OkmBlock,
687        suite: &'static Tls13CipherSuite,
688        quic: &'static dyn Algorithm,
689        side: Side,
690    );
691
692    fn traffic_secrets(
693        &mut self,
694        client_secret: OkmBlock,
695        server_secret: OkmBlock,
696        suite: &'static Tls13CipherSuite,
697        quic: &'static dyn Algorithm,
698        side: Side,
699    );
700
701    fn send_msg(&mut self, m: Message<'_>, must_encrypt: bool);
702}
703
704/// Secrets used to encrypt/decrypt traffic
705#[derive(Clone)]
706pub struct Secrets {
707    /// Secret used to encrypt packets transmitted by the client
708    pub(crate) client: OkmBlock,
709    /// Secret used to encrypt packets transmitted by the server
710    pub(crate) server: OkmBlock,
711    /// Cipher suite used with these secrets
712    suite: &'static Tls13CipherSuite,
713    quic: &'static dyn Algorithm,
714    side: Side,
715    version: Version,
716}
717
718impl Secrets {
719    pub(crate) fn new(
720        client: OkmBlock,
721        server: OkmBlock,
722        suite: &'static Tls13CipherSuite,
723        quic: &'static dyn Algorithm,
724        side: Side,
725        version: Version,
726    ) -> Self {
727        Self {
728            client,
729            server,
730            suite,
731            quic,
732            side,
733            version,
734        }
735    }
736
737    /// Derive the next set of packet keys
738    pub fn next_packet_keys(&mut self) -> PacketKeySet {
739        let keys = PacketKeySet::new(self);
740        self.update();
741        keys
742    }
743
744    pub(crate) fn update(&mut self) {
745        self.client = hkdf_expand_label_block(
746            self.suite
747                .hkdf_provider
748                .expander_for_okm(&self.client)
749                .as_ref(),
750            self.version.key_update_label(),
751            &[],
752        );
753        self.server = hkdf_expand_label_block(
754            self.suite
755                .hkdf_provider
756                .expander_for_okm(&self.server)
757                .as_ref(),
758            self.version.key_update_label(),
759            &[],
760        );
761    }
762
763    fn local_remote(&self) -> (&OkmBlock, &OkmBlock) {
764        match self.side {
765            Side::Client => (&self.client, &self.server),
766            Side::Server => (&self.server, &self.client),
767        }
768    }
769}
770
771/// Keys used to communicate in a single direction
772#[expect(clippy::exhaustive_structs)]
773pub struct DirectionalKeys {
774    /// Encrypts or decrypts a packet's headers
775    pub header: Box<dyn HeaderProtectionKey>,
776    /// Encrypts or decrypts the payload of a packet
777    pub packet: Box<dyn PacketKey>,
778}
779
780impl DirectionalKeys {
781    pub(crate) fn new(
782        suite: &'static Tls13CipherSuite,
783        quic: &'static dyn Algorithm,
784        secret: &OkmBlock,
785        version: Version,
786    ) -> Self {
787        let builder = KeyBuilder::new(secret, version, quic, suite.hkdf_provider);
788        Self {
789            header: builder.header_protection_key(),
790            packet: builder.packet_key(),
791        }
792    }
793}
794
/// All AEADs we support have 16-byte tags.
const TAG_LEN: usize = 16;

/// Authentication tag from an AEAD seal operation.
pub struct Tag([u8; TAG_LEN]);

impl From<&[u8]> for Tag {
    /// Copies an AEAD authentication tag out of `value`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not exactly `TAG_LEN` (16) bytes long. Callers are expected to
    /// pass a tag produced by one of the supported AEADs, all of which emit 16-byte tags.
    fn from(value: &[u8]) -> Self {
        let array = <[u8; TAG_LEN]>::try_from(value).unwrap_or_else(|_| {
            panic!("AEAD tag must be {TAG_LEN} bytes, got {}", value.len())
        });
        Self(array)
    }
}

impl AsRef<[u8]> for Tag {
    /// Borrows the raw tag bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}
814
815/// How a `Tls13CipherSuite` generates `PacketKey`s and `HeaderProtectionKey`s.
816pub trait Algorithm: Send + Sync {
817    /// Produce a `PacketKey` encrypter/decrypter for this suite.
818    ///
819    /// `suite` is the entire suite this `Algorithm` appeared in.
820    /// `key` and `iv` is the key material to use.
821    fn packet_key(&self, key: AeadKey, iv: Iv) -> Box<dyn PacketKey>;
822
823    /// Produce a `HeaderProtectionKey` encrypter/decrypter for this suite.
824    ///
825    /// `key` is the key material, which is `aead_key_len()` bytes in length.
826    fn header_protection_key(&self, key: AeadKey) -> Box<dyn HeaderProtectionKey>;
827
828    /// The length in bytes of keys for this Algorithm.
829    ///
830    /// This controls the size of `AeadKey`s presented to `packet_key()` and `header_protection_key()`.
831    fn aead_key_len(&self) -> usize;
832
833    /// Whether this algorithm is FIPS-approved.
834    fn fips(&self) -> FipsStatus {
835        FipsStatus::Unvalidated
836    }
837}
838
839/// A QUIC header protection key
840pub trait HeaderProtectionKey: Send + Sync {
841    /// Adds QUIC Header Protection.
842    ///
843    /// `sample` must contain the sample of encrypted payload; see
844    /// [Header Protection Sample].
845    ///
846    /// `first` must reference the first byte of the header, referred to as
847    /// `packet[0]` in [Header Protection Application].
848    ///
849    /// `packet_number` must reference the Packet Number field; this is
850    /// `packet[pn_offset:pn_offset+pn_length]` in [Header Protection Application].
851    ///
852    /// Returns an error without modifying anything if `sample` is not
853    /// the correct length (see [Header Protection Sample] and [`Self::sample_len()`]),
854    /// or `packet_number` is longer than allowed (see [Packet Number Encoding and Decoding]).
855    ///
856    /// Otherwise, `first` and `packet_number` will have the header protection added.
857    ///
858    /// [Header Protection Application]: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.1
859    /// [Header Protection Sample]: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.2
860    /// [Packet Number Encoding and Decoding]: https://datatracker.ietf.org/doc/html/rfc9000#section-17.1
861    fn encrypt_in_place(
862        &self,
863        sample: &[u8],
864        first: &mut u8,
865        packet_number: &mut [u8],
866    ) -> Result<(), Error>;
867
868    /// Removes QUIC Header Protection.
869    ///
870    /// `sample` must contain the sample of encrypted payload; see
871    /// [Header Protection Sample].
872    ///
873    /// `first` must reference the first byte of the header, referred to as
874    /// `packet[0]` in [Header Protection Application].
875    ///
876    /// `packet_number` must reference the Packet Number field; this is
877    /// `packet[pn_offset:pn_offset+pn_length]` in [Header Protection Application].
878    ///
879    /// Returns an error without modifying anything if `sample` is not
880    /// the correct length (see [Header Protection Sample] and [`Self::sample_len()`]),
881    /// or `packet_number` is longer than allowed (see
882    /// [Packet Number Encoding and Decoding]).
883    ///
884    /// Otherwise, `first` and `packet_number` will have the header protection removed.
885    ///
886    /// [Header Protection Application]: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.1
887    /// [Header Protection Sample]: https://datatracker.ietf.org/doc/html/rfc9001#section-5.4.2
888    /// [Packet Number Encoding and Decoding]: https://datatracker.ietf.org/doc/html/rfc9000#section-17.1
889    fn decrypt_in_place(
890        &self,
891        sample: &[u8],
892        first: &mut u8,
893        packet_number: &mut [u8],
894    ) -> Result<(), Error>;
895
896    /// Expected sample length for the key's algorithm
897    fn sample_len(&self) -> usize;
898}
899
/// Keys to encrypt or decrypt the payload of a packet
///
/// One `PacketKey` protects a single direction of traffic. The `Send + Sync` supertrait bound
/// means boxed keys can be moved and shared across threads.
pub trait PacketKey: Send + Sync {
    /// Encrypt a QUIC packet
    ///
    /// Takes a `packet_number` and optional `path_id`, used to derive the nonce; the packet
    /// `header`, which is used as the additional authenticated data; and the `payload`, which is
    /// encrypted in place. The authentication tag is returned if encryption succeeds.
    ///
    /// Fails if and only if the payload is longer than allowed by the cipher suite's AEAD algorithm.
    ///
    /// When provided, the `path_id` is used for multipath encryption as described in
    /// <https://www.ietf.org/archive/id/draft-ietf-quic-multipath-15.html#section-2.4>.
    fn encrypt_in_place(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &mut [u8],
        path_id: Option<u32>,
    ) -> Result<Tag, Error>;

    /// Decrypt a QUIC packet
    ///
    /// Takes a `packet_number` and optional `path_id`, used to derive the nonce; the packet
    /// `header`, which is used as the additional authenticated data; and the `payload`, which
    /// includes the authentication tag and is decrypted in place.
    ///
    /// On success, returns the slice of `payload` containing the decrypted data. The returned
    /// slice borrows from `payload`.
    ///
    /// When provided, the `path_id` is used for multipath encryption as described in
    /// <https://www.ietf.org/archive/id/draft-ietf-quic-multipath-15.html#section-2.4>.
    fn decrypt_in_place<'a>(
        &self,
        packet_number: u64,
        header: &[u8],
        payload: &'a mut [u8],
        path_id: Option<u32>,
    ) -> Result<&'a [u8], Error>;

    /// Tag length for the underlying AEAD algorithm
    fn tag_len(&self) -> usize;

    /// Number of QUIC packets that can be safely encrypted with a single key of this type.
    ///
    /// Once a key of this type has encrypted more than `confidentiality_limit` packets, an
    /// attacker gains an advantage in distinguishing it from an ideal pseudorandom
    /// permutation (PRP).
    ///
    /// This is to be set on the assumption that messages are maximally sized --
    /// 2 ** 16. For non-QUIC TCP connections see [`CipherSuiteCommon::confidentiality_limit`][csc-limit].
    ///
    /// [csc-limit]: crate::crypto::CipherSuiteCommon::confidentiality_limit
    fn confidentiality_limit(&self) -> u64;

    /// Number of QUIC packets that may fail decryption with a single key of this type.
    ///
    /// Once a key of this type has failed to decrypt `integrity_limit` packets, an attacker
    /// gains an advantage in forging messages.
    ///
    /// This is not relevant for TLS over TCP (which is also implemented in this crate)
    /// because a single failed decryption is fatal to the connection.
    /// However, this quantity is used by QUIC.
    fn integrity_limit(&self) -> u64;
}
963
964/// Packet protection keys for bidirectional 1-RTT communication
965#[expect(clippy::exhaustive_structs)]
966pub struct PacketKeySet {
967    /// Encrypts outgoing packets
968    pub local: Box<dyn PacketKey>,
969    /// Decrypts incoming packets
970    pub remote: Box<dyn PacketKey>,
971}
972
973impl PacketKeySet {
974    fn new(secrets: &Secrets) -> Self {
975        let (local, remote) = secrets.local_remote();
976        let (version, alg, hkdf) = (secrets.version, secrets.quic, secrets.suite.hkdf_provider);
977        Self {
978            local: KeyBuilder::new(local, version, alg, hkdf).packet_key(),
979            remote: KeyBuilder::new(remote, version, alg, hkdf).packet_key(),
980        }
981    }
982}
983
984/// Helper for building QUIC packet and header protection keys
985pub struct KeyBuilder<'a> {
986    expander: Box<dyn HkdfExpander>,
987    version: Version,
988    alg: &'a dyn Algorithm,
989}
990
991impl<'a> KeyBuilder<'a> {
992    /// Create a new KeyBuilder
993    pub fn new(
994        secret: &OkmBlock,
995        version: Version,
996        alg: &'a dyn Algorithm,
997        hkdf: &'a dyn Hkdf,
998    ) -> Self {
999        Self {
1000            expander: hkdf.expander_for_okm(secret),
1001            version,
1002            alg,
1003        }
1004    }
1005
1006    /// Derive packet keys
1007    pub fn packet_key(&self) -> Box<dyn PacketKey> {
1008        let aead_key_len = self.alg.aead_key_len();
1009        let packet_key = hkdf_expand_label_aead_key(
1010            self.expander.as_ref(),
1011            aead_key_len,
1012            self.version.packet_key_label(),
1013            &[],
1014        );
1015
1016        let packet_iv =
1017            hkdf_expand_label(self.expander.as_ref(), self.version.packet_iv_label(), &[]);
1018        self.alg
1019            .packet_key(packet_key, packet_iv)
1020    }
1021
1022    /// Derive header protection keys
1023    pub fn header_protection_key(&self) -> Box<dyn HeaderProtectionKey> {
1024        let header_key = hkdf_expand_label_aead_key(
1025            self.expander.as_ref(),
1026            self.alg.aead_key_len(),
1027            self.version.header_key_label(),
1028            &[],
1029        );
1030        self.alg
1031            .header_protection_key(header_key)
1032    }
1033}
1034
1035/// Produces QUIC initial keys from a TLS 1.3 ciphersuite and a QUIC key generation algorithm.
1036#[non_exhaustive]
1037#[derive(Clone, Copy)]
1038pub struct Suite {
1039    /// The TLS 1.3 ciphersuite used to derive keys.
1040    pub suite: &'static Tls13CipherSuite,
1041    /// The QUIC key generation algorithm used to derive keys.
1042    pub quic: &'static dyn Algorithm,
1043}
1044
1045impl Suite {
1046    /// Produce a set of initial keys given the connection ID, side and version
1047    pub fn keys(&self, client_dst_connection_id: &[u8], side: Side, version: Version) -> Keys {
1048        Keys::initial(
1049            version,
1050            self.suite,
1051            self.quic,
1052            client_dst_connection_id,
1053            side,
1054        )
1055    }
1056}
1057
1058/// Complete set of keys used to communicate with the peer
1059#[expect(clippy::exhaustive_structs)]
1060pub struct Keys {
1061    /// Encrypts outgoing packets
1062    pub local: DirectionalKeys,
1063    /// Decrypts incoming packets
1064    pub remote: DirectionalKeys,
1065}
1066
1067impl Keys {
1068    /// Construct keys for use with initial packets
1069    pub fn initial(
1070        version: Version,
1071        suite: &'static Tls13CipherSuite,
1072        quic: &'static dyn Algorithm,
1073        client_dst_connection_id: &[u8],
1074        side: Side,
1075    ) -> Self {
1076        const CLIENT_LABEL: &[u8] = b"client in";
1077        const SERVER_LABEL: &[u8] = b"server in";
1078        let salt = version.initial_salt();
1079        let hs_secret = suite
1080            .hkdf_provider
1081            .extract_from_secret(Some(salt), client_dst_connection_id);
1082
1083        let secrets = Secrets {
1084            client: hkdf_expand_label_block(hs_secret.as_ref(), CLIENT_LABEL, &[]),
1085            server: hkdf_expand_label_block(hs_secret.as_ref(), SERVER_LABEL, &[]),
1086            suite,
1087            quic,
1088            side,
1089            version,
1090        };
1091        Self::new(&secrets)
1092    }
1093
1094    fn new(secrets: &Secrets) -> Self {
1095        let (local, remote) = secrets.local_remote();
1096        Self {
1097            local: DirectionalKeys::new(secrets.suite, secrets.quic, local, secrets.version),
1098            remote: DirectionalKeys::new(secrets.suite, secrets.quic, remote, secrets.version),
1099        }
1100    }
1101}
1102
/// Key material for use in QUIC packet spaces
///
/// QUIC uses 4 different sets of keys (and progressive key updates for long-running connections):
///
/// * Initial: these can be created from [`Keys::initial()`]
/// * 0-RTT keys: can be retrieved from [`Connection::zero_rtt_keys()`]
/// * Handshake: these are returned from [`Connection::write_hs()`] after `ClientHello` and
///   `ServerHello` messages have been exchanged
/// * 1-RTT keys: these are returned from [`Connection::write_hs()`] after the handshake is done
///
/// Only the last two sets are delivered through this enum, as [`KeyChange::Handshake`] and
/// [`KeyChange::OneRtt`] respectively.
///
/// Once the 1-RTT keys have been exchanged, either side may initiate a key update. Progressive
/// update keys can be obtained from the [`Secrets`] returned in [`KeyChange::OneRtt`]. Note that
/// only packet keys are updated by key updates; header protection keys remain the same.
#[expect(clippy::exhaustive_enums)]
pub enum KeyChange {
    /// Keys for the handshake space
    Handshake {
        /// Header and packet keys for the handshake space
        keys: Keys,
    },
    /// Keys for 1-RTT data
    OneRtt {
        /// Header and packet keys for 1-RTT data
        keys: Keys,
        /// Secrets to derive updated keys from
        ///
        /// Key updates only replace packet keys, so the header protection keys in `keys`
        /// stay in use after an update.
        next: Secrets,
    },
}
1131
/// QUIC protocol version
///
/// Governs version-specific behavior in the TLS layer. That is the salt used to extract
/// initial secrets, plus the HKDF labels used for packet key, packet IV, header protection
/// and key update derivations.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Version {
    /// First stable RFC (QUIC version 1, RFC 9000 / RFC 9001)
    #[default]
    V1,
    /// Anti-ossification variant of V1 (QUIC version 2, RFC 9369)
    V2,
}

impl Version {
    /// Salt passed to HKDF-Extract when deriving the initial secret for this version.
    fn initial_salt(self) -> &'static [u8; 20] {
        // https://www.rfc-editor.org/rfc/rfc9001.html#name-initial-secrets
        static V1_SALT: [u8; 20] = [
            0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
            0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
        ];
        // https://www.rfc-editor.org/rfc/rfc9369.html#name-initial-salt
        static V2_SALT: [u8; 20] = [
            0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26,
            0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9,
        ];

        match self {
            Self::V1 => &V1_SALT,
            Self::V2 => &V2_SALT,
        }
    }

    /// Key derivation label for packet keys.
    pub(crate) fn packet_key_label(&self) -> &'static [u8] {
        self.select_label(b"quic key", b"quicv2 key")
    }

    /// Key derivation label for packet "IV"s.
    pub(crate) fn packet_iv_label(&self) -> &'static [u8] {
        self.select_label(b"quic iv", b"quicv2 iv")
    }

    /// Key derivation label for header protection keys.
    pub(crate) fn header_key_label(&self) -> &'static [u8] {
        self.select_label(b"quic hp", b"quicv2 hp")
    }

    /// Key derivation label for deriving the next generation of secrets on key update.
    fn key_update_label(&self) -> &'static [u8] {
        self.select_label(b"quic ku", b"quicv2 ku")
    }

    /// Return `v1` or `v2`, whichever corresponds to this version.
    fn select_label(&self, v1: &'static [u8], v2: &'static [u8]) -> &'static [u8] {
        match self {
            Self::V1 => v1,
            Self::V2 => v2,
        }
    }
}
1192
#[cfg(all(test, any(target_arch = "aarch64", target_arch = "x86_64")))]
mod tests {
    use super::*;
    use crate::crypto::TLS13_TEST_SUITE;
    use crate::crypto::tls13::OkmBlock;
    use crate::quic::{HeaderProtectionKey, Secrets, Side, Version};

    /// Run one `Secrets::update()` on fixed client/server secrets and compare the result
    /// against known values. The byte slices are compared with `assert_eq!` so that a
    /// mismatch prints the actual bytes instead of just `false`.
    #[test]
    fn key_update_test_vector() {
        // Constant dummy values for reproducibility
        const CLIENT_BEFORE: [u8; 32] = [
            0xb8, 0x76, 0x77, 0x08, 0xf8, 0x77, 0x23, 0x58, 0xa6, 0xea, 0x9f, 0xc4, 0x3e,
            0x4a, 0xdd, 0x2c, 0x96, 0x1b, 0x3f, 0x52, 0x87, 0xa6, 0xd1, 0x46, 0x7e, 0xe0,
            0xae, 0xab, 0x33, 0x72, 0x4d, 0xbf,
        ];
        const SERVER_BEFORE: [u8; 32] = [
            0x42, 0xdc, 0x97, 0x21, 0x40, 0xe0, 0xf2, 0xe3, 0x98, 0x45, 0xb7, 0x67, 0x61,
            0x34, 0x39, 0xdc, 0x67, 0x58, 0xca, 0x43, 0x25, 0x9b, 0x87, 0x85, 0x06, 0x82,
            0x4e, 0xb1, 0xe4, 0x38, 0xd8, 0x55,
        ];
        // Expected secrets after a single key update
        const CLIENT_AFTER: [u8; 32] = [
            0x42, 0xca, 0xc8, 0xc9, 0x1c, 0xd5, 0xeb, 0x40, 0x68, 0x2e, 0x43, 0x2e, 0xdf,
            0x2d, 0x2b, 0xe9, 0xf4, 0x1a, 0x52, 0xca, 0x6b, 0x22, 0xd8, 0xe6, 0xcd, 0xb1,
            0xe8, 0xac, 0xa9, 0x06, 0x1f, 0xce,
        ];
        const SERVER_AFTER: [u8; 32] = [
            0xeb, 0x7f, 0x5e, 0x2a, 0x12, 0x3f, 0x40, 0x7d, 0xb4, 0x99, 0xe3, 0x61, 0xca,
            0xe5, 0x90, 0xd4, 0xd9, 0x92, 0xe1, 0x4b, 0x7a, 0xce, 0x03, 0xc2, 0x44, 0xe0,
            0x42, 0x21, 0x15, 0xb6, 0xd3, 0x8a,
        ];

        let mut secrets = Secrets {
            client: OkmBlock::new(&CLIENT_BEFORE[..]),
            server: OkmBlock::new(&SERVER_BEFORE[..]),
            suite: TLS13_TEST_SUITE,
            quic: &FakeAlgorithm,
            side: Side::Client,
            version: Version::V1,
        };
        secrets.update();

        assert_eq!(secrets.client.as_ref(), &CLIENT_AFTER[..]);
        assert_eq!(secrets.server.as_ref(), &SERVER_AFTER[..]);
    }

    /// Stand-in `Algorithm` for the key update test. Only `aead_key_len` is implemented;
    /// building packet or header protection keys panics via `unimplemented!()`.
    struct FakeAlgorithm;

    impl Algorithm for FakeAlgorithm {
        fn packet_key(&self, _key: AeadKey, _iv: Iv) -> Box<dyn PacketKey> {
            unimplemented!()
        }

        fn header_protection_key(&self, _key: AeadKey) -> Box<dyn HeaderProtectionKey> {
            unimplemented!()
        }

        fn aead_key_len(&self) -> usize {
            16
        }
    }

    /// Compile-time check that boxed packet and header protection keys are `Send + Sync`.
    #[test]
    fn auto_traits() {
        fn assert_auto<T: Send + Sync>() {}
        assert_auto::<Box<dyn PacketKey>>();
        assert_auto::<Box<dyn HeaderProtectionKey>>();
    }
}