rustls/msgs/handshake.rs

1use alloc::boxed::Box;
2use alloc::collections::BTreeSet;
3#[cfg(feature = "log")]
4use alloc::string::String;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::ops::{Deref, DerefMut};
8use core::{fmt, iter};
9
10use pki_types::{CertificateDer, DnsName};
11
12#[cfg(feature = "tls12")]
13use crate::crypto::ActiveKeyExchange;
14use crate::crypto::SecureRandom;
15use crate::enums::{
16    CertificateCompressionAlgorithm, CertificateType, CipherSuite, EchClientHelloType,
17    HandshakeType, ProtocolVersion, SignatureScheme,
18};
19use crate::error::InvalidMessage;
20#[cfg(feature = "tls12")]
21use crate::ffdhe_groups::FfdheGroup;
22use crate::log::warn;
23use crate::msgs::base::{MaybeEmpty, NonEmpty, Payload, PayloadU8, PayloadU16, PayloadU24};
24use crate::msgs::codec::{
25    self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement, TlsListIter,
26};
27use crate::msgs::enums::{
28    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
29    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
30    PskKeyExchangeMode, ServerNameType,
31};
32use crate::rand;
33use crate::sync::Arc;
34use crate::verify::DigitallySignedStruct;
35use crate::x509::wrap_in_sequence;
36
/// Declares a thin newtype over a length-prefixed payload type (`PayloadU8` / `PayloadU16`).
///
/// Intended for TLS message fields where callers only ever need the raw bytes. The
/// generated type derives `Clone` and `Debug`, and gets:
///
/// - `From<Vec<u8>>`, building the wrapped payload with its `new` constructor;
/// - `AsRef<[u8]>`, exposing the payload's byte contents;
/// - `Codec`, delegating both encoding and decoding to the wrapped payload.
macro_rules! wrapped_payload {
    (
        $(#[$attr:meta])*
        $vis:vis struct $newtype:ident, $payload:ident $(<$payload_arg:ty>)?,
    ) => {
        $(#[$attr])*
        #[derive(Clone, Debug)]
        $vis struct $newtype($payload $(<$payload_arg>)?);

        impl From<Vec<u8>> for $newtype {
            fn from(raw: Vec<u8>) -> Self {
                let inner = $payload::new(raw);
                Self(inner)
            }
        }

        impl AsRef<[u8]> for $newtype {
            fn as_ref(&self) -> &[u8] {
                let Self(payload) = self;
                payload.0.as_slice()
            }
        }

        impl Codec<'_> for $newtype {
            fn encode(&self, bytes: &mut Vec<u8>) {
                let Self(payload) = self;
                payload.encode(bytes);
            }

            fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
                $payload::read(reader).map(Self)
            }
        }
    };
}
71
72#[derive(Clone, Copy, Eq, PartialEq)]
73pub(crate) struct Random(pub(crate) [u8; 32]);
74
75impl fmt::Debug for Random {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        super::base::hex(f, &self.0)
78    }
79}
80
81static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
82    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
83    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
84]);
85
86static ZERO_RANDOM: Random = Random([0u8; 32]);
87
88impl Codec<'_> for Random {
89    fn encode(&self, bytes: &mut Vec<u8>) {
90        bytes.extend_from_slice(&self.0);
91    }
92
93    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
94        let Some(bytes) = r.take(32) else {
95            return Err(InvalidMessage::MissingData("Random"));
96        };
97
98        let mut opaque = [0; 32];
99        opaque.clone_from_slice(bytes);
100        Ok(Self(opaque))
101    }
102}
103
104impl Random {
105    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
106        let mut data = [0u8; 32];
107        secure_random.fill(&mut data)?;
108        Ok(Self(data))
109    }
110}
111
112impl From<[u8; 32]> for Random {
113    #[inline]
114    fn from(bytes: [u8; 32]) -> Self {
115        Self(bytes)
116    }
117}
118
119#[derive(Copy, Clone)]
120pub(crate) struct SessionId {
121    len: usize,
122    data: [u8; 32],
123}
124
125impl fmt::Debug for SessionId {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        super::base::hex(f, &self.data[..self.len])
128    }
129}
130
131impl PartialEq for SessionId {
132    fn eq(&self, other: &Self) -> bool {
133        if self.len != other.len {
134            return false;
135        }
136
137        let mut diff = 0u8;
138        for i in 0..self.len {
139            diff |= self.data[i] ^ other.data[i];
140        }
141
142        diff == 0u8
143    }
144}
145
146impl Codec<'_> for SessionId {
147    fn encode(&self, bytes: &mut Vec<u8>) {
148        debug_assert!(self.len <= 32);
149        bytes.push(self.len as u8);
150        bytes.extend_from_slice(self.as_ref());
151    }
152
153    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
154        let len = u8::read(r)? as usize;
155        if len > 32 {
156            return Err(InvalidMessage::TrailingData("SessionID"));
157        }
158
159        let Some(bytes) = r.take(len) else {
160            return Err(InvalidMessage::MissingData("SessionID"));
161        };
162
163        let mut out = [0u8; 32];
164        out[..len].clone_from_slice(&bytes[..len]);
165        Ok(Self { data: out, len })
166    }
167}
168
169impl SessionId {
170    pub(crate) fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
171        let mut data = [0u8; 32];
172        secure_random.fill(&mut data)?;
173        Ok(Self { data, len: 32 })
174    }
175
176    pub(crate) fn empty() -> Self {
177        Self {
178            data: [0u8; 32],
179            len: 0,
180        }
181    }
182
183    #[cfg(feature = "tls12")]
184    pub(crate) fn is_empty(&self) -> bool {
185        self.len == 0
186    }
187}
188
189impl AsRef<[u8]> for SessionId {
190    fn as_ref(&self) -> &[u8] {
191        &self.data[..self.len]
192    }
193}
194
195#[derive(Clone, Debug, PartialEq)]
196pub struct UnknownExtension {
197    pub(crate) typ: ExtensionType,
198    pub(crate) payload: Payload<'static>,
199}
200
201impl UnknownExtension {
202    fn encode(&self, bytes: &mut Vec<u8>) {
203        self.payload.encode(bytes);
204    }
205
206    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
207        let payload = Payload::read(r).into_owned();
208        Self { typ, payload }
209    }
210}
211
212#[derive(Clone, Copy, Debug)]
213pub(crate) struct SupportedEcPointFormats {
214    pub(crate) uncompressed: bool,
215}
216
217impl Codec<'_> for SupportedEcPointFormats {
218    fn encode(&self, bytes: &mut Vec<u8>) {
219        let inner = LengthPrefixedBuffer::new(ECPointFormat::SIZE_LEN, bytes);
220
221        if self.uncompressed {
222            ECPointFormat::Uncompressed.encode(inner.buf);
223        }
224    }
225
226    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
227        let mut uncompressed = false;
228
229        for pf in TlsListIter::<ECPointFormat>::new(r)? {
230            if let ECPointFormat::Uncompressed = pf? {
231                uncompressed = true;
232            }
233        }
234
235        Ok(Self { uncompressed })
236    }
237}
238
239impl Default for SupportedEcPointFormats {
240    fn default() -> Self {
241        Self { uncompressed: true }
242    }
243}
244
245/// RFC8422: `ECPointFormat ec_point_format_list<1..2^8-1>`
246impl TlsListElement for ECPointFormat {
247    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
248        empty_error: InvalidMessage::IllegalEmptyList("ECPointFormats"),
249    };
250}
251
252/// RFC8422: `NamedCurve named_curve_list<2..2^16-1>`
253impl TlsListElement for NamedGroup {
254    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
255        empty_error: InvalidMessage::IllegalEmptyList("NamedGroups"),
256    };
257}
258
259/// RFC8446: `SignatureScheme supported_signature_algorithms<2..2^16-2>;`
260impl TlsListElement for SignatureScheme {
261    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
262        empty_error: InvalidMessage::NoSignatureSchemes,
263    };
264}
265
266#[derive(Clone, Debug)]
267pub(crate) enum ServerNamePayload<'a> {
268    /// A successfully decoded value:
269    SingleDnsName(DnsName<'a>),
270
271    /// A DNS name which was actually an IP address
272    IpAddress,
273
274    /// A successfully decoded, but syntactically-invalid value.
275    Invalid,
276}
277
278impl ServerNamePayload<'_> {
279    fn into_owned(self) -> ServerNamePayload<'static> {
280        match self {
281            Self::SingleDnsName(d) => ServerNamePayload::SingleDnsName(d.to_owned()),
282            Self::IpAddress => ServerNamePayload::IpAddress,
283            Self::Invalid => ServerNamePayload::Invalid,
284        }
285    }
286
287    /// RFC6066: `ServerName server_name_list<1..2^16-1>`
288    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
289        empty_error: InvalidMessage::IllegalEmptyList("ServerNames"),
290    };
291}
292
293/// Simplified encoding/decoding for a `ServerName` extension payload to/from `DnsName`
294///
295/// This is possible because:
296///
297/// - the spec (RFC6066) disallows multiple names for a given name type
298/// - name types other than ServerNameType::HostName are not defined, and they and
299///   any data that follows them cannot be skipped over.
300impl<'a> Codec<'a> for ServerNamePayload<'a> {
301    fn encode(&self, bytes: &mut Vec<u8>) {
302        let server_name_list = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
303
304        let ServerNamePayload::SingleDnsName(dns_name) = self else {
305            return;
306        };
307
308        ServerNameType::HostName.encode(server_name_list.buf);
309        let name_slice = dns_name.as_ref().as_bytes();
310        (name_slice.len() as u16).encode(server_name_list.buf);
311        server_name_list
312            .buf
313            .extend_from_slice(name_slice);
314    }
315
316    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
317        let mut found = None;
318
319        let len = Self::SIZE_LEN.read(r)?;
320        let mut sub = r.sub(len)?;
321
322        while sub.any_left() {
323            let typ = ServerNameType::read(&mut sub)?;
324
325            let payload = match typ {
326                ServerNameType::HostName => HostNamePayload::read(&mut sub)?,
327                _ => {
328                    // Consume remainder of extension bytes.  Since the length of the item
329                    // is an unknown encoding, we cannot continue.
330                    sub.rest();
331                    break;
332                }
333            };
334
335            // "The ServerNameList MUST NOT contain more than one name of
336            // the same name_type." - RFC6066
337            if found.is_some() {
338                warn!("Illegal SNI extension: duplicate host_name received");
339                return Err(InvalidMessage::InvalidServerName);
340            }
341
342            found = match payload {
343                HostNamePayload::HostName(dns_name) => {
344                    Some(Self::SingleDnsName(dns_name.to_owned()))
345                }
346
347                HostNamePayload::IpAddress(_invalid) => {
348                    warn!(
349                        "Illegal SNI extension: ignoring IP address presented as hostname ({_invalid:?})"
350                    );
351                    Some(Self::IpAddress)
352                }
353
354                HostNamePayload::Invalid(_invalid) => {
355                    warn!(
356                        "Illegal SNI hostname received {:?}",
357                        String::from_utf8_lossy(&_invalid.0)
358                    );
359                    Some(Self::Invalid)
360                }
361            };
362        }
363
364        Ok(found.unwrap_or(Self::Invalid))
365    }
366}
367
368impl<'a> From<&DnsName<'a>> for ServerNamePayload<'static> {
369    fn from(value: &DnsName<'a>) -> Self {
370        Self::SingleDnsName(trim_hostname_trailing_dot_for_sni(value))
371    }
372}
373
374#[derive(Clone, Debug)]
375pub(crate) enum HostNamePayload {
376    HostName(DnsName<'static>),
377    IpAddress(PayloadU16<NonEmpty>),
378    Invalid(PayloadU16<NonEmpty>),
379}
380
381impl HostNamePayload {
382    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
383        use pki_types::ServerName;
384        let raw = PayloadU16::<NonEmpty>::read(r)?;
385
386        match ServerName::try_from(raw.0.as_slice()) {
387            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
388            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
389            Ok(_) | Err(_) => Ok(Self::Invalid(raw)),
390        }
391    }
392}
393
394wrapped_payload!(
395    /// RFC7301: `opaque ProtocolName<1..2^8-1>;`
396    pub(crate) struct ProtocolName, PayloadU8<NonEmpty>,
397);
398
399impl PartialEq for ProtocolName {
400    fn eq(&self, other: &Self) -> bool {
401        self.0 == other.0
402    }
403}
404
405impl Deref for ProtocolName {
406    type Target = [u8];
407
408    fn deref(&self) -> &Self::Target {
409        self.as_ref()
410    }
411}
412
413/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
414impl TlsListElement for ProtocolName {
415    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
416        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
417    };
418}
419
420/// RFC7301 encodes a single protocol name as `Vec<ProtocolName>`
421#[derive(Clone, Debug)]
422pub(crate) struct SingleProtocolName(ProtocolName);
423
424impl SingleProtocolName {
425    pub(crate) fn new(single: ProtocolName) -> Self {
426        Self(single)
427    }
428
429    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
430        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
431    };
432}
433
434impl Codec<'_> for SingleProtocolName {
435    fn encode(&self, bytes: &mut Vec<u8>) {
436        let body = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
437        self.0.encode(body.buf);
438    }
439
440    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
441        let len = Self::SIZE_LEN.read(reader)?;
442        let mut sub = reader.sub(len)?;
443
444        let item = ProtocolName::read(&mut sub)?;
445
446        if sub.any_left() {
447            Err(InvalidMessage::TrailingData("SingleProtocolName"))
448        } else {
449            Ok(Self(item))
450        }
451    }
452}
453
454impl AsRef<ProtocolName> for SingleProtocolName {
455    fn as_ref(&self) -> &ProtocolName {
456        &self.0
457    }
458}
459
460// --- TLS 1.3 Key shares ---
461#[derive(Clone, Debug)]
462pub(crate) struct KeyShareEntry {
463    pub(crate) group: NamedGroup,
464    /// RFC8446: `opaque key_exchange<1..2^16-1>;`
465    pub(crate) payload: PayloadU16<NonEmpty>,
466}
467
468impl KeyShareEntry {
469    pub(crate) fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
470        Self {
471            group,
472            payload: PayloadU16::new(payload.into()),
473        }
474    }
475}
476
477impl Codec<'_> for KeyShareEntry {
478    fn encode(&self, bytes: &mut Vec<u8>) {
479        self.group.encode(bytes);
480        self.payload.encode(bytes);
481    }
482
483    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
484        let group = NamedGroup::read(r)?;
485        let payload = PayloadU16::read(r)?;
486
487        Ok(Self { group, payload })
488    }
489}
490
491// --- TLS 1.3 PresharedKey offers ---
492#[derive(Clone, Debug)]
493pub(crate) struct PresharedKeyIdentity {
494    /// RFC8446: `opaque identity<1..2^16-1>;`
495    pub(crate) identity: PayloadU16<NonEmpty>,
496    pub(crate) obfuscated_ticket_age: u32,
497}
498
499impl PresharedKeyIdentity {
500    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
501        Self {
502            identity: PayloadU16::new(id),
503            obfuscated_ticket_age: age,
504        }
505    }
506}
507
508impl Codec<'_> for PresharedKeyIdentity {
509    fn encode(&self, bytes: &mut Vec<u8>) {
510        self.identity.encode(bytes);
511        self.obfuscated_ticket_age.encode(bytes);
512    }
513
514    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
515        Ok(Self {
516            identity: PayloadU16::read(r)?,
517            obfuscated_ticket_age: u32::read(r)?,
518        })
519    }
520}
521
522/// RFC8446: `PskIdentity identities<7..2^16-1>;`
523impl TlsListElement for PresharedKeyIdentity {
524    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
525        empty_error: InvalidMessage::IllegalEmptyList("PskIdentities"),
526    };
527}
528
529wrapped_payload!(
530    /// RFC8446: `opaque PskBinderEntry<32..255>;`
531    pub(crate) struct PresharedKeyBinder, PayloadU8<NonEmpty>,
532);
533
534/// RFC8446: `PskBinderEntry binders<33..2^16-1>;`
535impl TlsListElement for PresharedKeyBinder {
536    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
537        empty_error: InvalidMessage::IllegalEmptyList("PskBinders"),
538    };
539}
540
541#[derive(Clone, Debug)]
542pub(crate) struct PresharedKeyOffer {
543    pub(crate) identities: Vec<PresharedKeyIdentity>,
544    pub(crate) binders: Vec<PresharedKeyBinder>,
545}
546
547impl PresharedKeyOffer {
548    /// Make a new one with one entry.
549    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
550        Self {
551            identities: vec![id],
552            binders: vec![PresharedKeyBinder::from(binder)],
553        }
554    }
555}
556
557impl Codec<'_> for PresharedKeyOffer {
558    fn encode(&self, bytes: &mut Vec<u8>) {
559        self.identities.encode(bytes);
560        self.binders.encode(bytes);
561    }
562
563    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
564        Ok(Self {
565            identities: Vec::read(r)?,
566            binders: Vec::read(r)?,
567        })
568    }
569}
570
571// --- RFC6066 certificate status request ---
572wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
573
574/// RFC6066: `ResponderID responder_id_list<0..2^16-1>;`
575impl TlsListElement for ResponderId {
576    const SIZE_LEN: ListLength = ListLength::U16;
577}
578
579#[derive(Clone, Debug)]
580pub(crate) struct OcspCertificateStatusRequest {
581    pub(crate) responder_ids: Vec<ResponderId>,
582    pub(crate) extensions: PayloadU16,
583}
584
585impl Codec<'_> for OcspCertificateStatusRequest {
586    fn encode(&self, bytes: &mut Vec<u8>) {
587        CertificateStatusType::OCSP.encode(bytes);
588        self.responder_ids.encode(bytes);
589        self.extensions.encode(bytes);
590    }
591
592    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
593        Ok(Self {
594            responder_ids: Vec::read(r)?,
595            extensions: PayloadU16::read(r)?,
596        })
597    }
598}
599
600#[derive(Clone, Debug)]
601pub(crate) enum CertificateStatusRequest {
602    Ocsp(OcspCertificateStatusRequest),
603    Unknown((CertificateStatusType, Payload<'static>)),
604}
605
606impl Codec<'_> for CertificateStatusRequest {
607    fn encode(&self, bytes: &mut Vec<u8>) {
608        match self {
609            Self::Ocsp(r) => r.encode(bytes),
610            Self::Unknown((typ, payload)) => {
611                typ.encode(bytes);
612                payload.encode(bytes);
613            }
614        }
615    }
616
617    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
618        let typ = CertificateStatusType::read(r)?;
619
620        match typ {
621            CertificateStatusType::OCSP => {
622                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
623                Ok(Self::Ocsp(ocsp_req))
624            }
625            _ => {
626                let data = Payload::read(r).into_owned();
627                Ok(Self::Unknown((typ, data)))
628            }
629        }
630    }
631}
632
633impl CertificateStatusRequest {
634    pub(crate) fn build_ocsp() -> Self {
635        let ocsp = OcspCertificateStatusRequest {
636            responder_ids: Vec::new(),
637            extensions: PayloadU16::empty(),
638        };
639        Self::Ocsp(ocsp)
640    }
641}
642
643// ---
644
645/// RFC8446: `PskKeyExchangeMode ke_modes<1..255>;`
646#[derive(Clone, Copy, Debug, Default)]
647pub(crate) struct PskKeyExchangeModes {
648    pub(crate) psk_dhe: bool,
649    pub(crate) psk: bool,
650}
651
652impl Codec<'_> for PskKeyExchangeModes {
653    fn encode(&self, bytes: &mut Vec<u8>) {
654        let inner = LengthPrefixedBuffer::new(PskKeyExchangeMode::SIZE_LEN, bytes);
655        if self.psk_dhe {
656            PskKeyExchangeMode::PSK_DHE_KE.encode(inner.buf);
657        }
658        if self.psk {
659            PskKeyExchangeMode::PSK_KE.encode(inner.buf);
660        }
661    }
662
663    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
664        let mut psk_dhe = false;
665        let mut psk = false;
666
667        for ke in TlsListIter::<PskKeyExchangeMode>::new(reader)? {
668            match ke? {
669                PskKeyExchangeMode::PSK_DHE_KE => psk_dhe = true,
670                PskKeyExchangeMode::PSK_KE => psk = true,
671                _ => continue,
672            };
673        }
674
675        Ok(Self { psk_dhe, psk })
676    }
677}
678
679impl TlsListElement for PskKeyExchangeMode {
680    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
681        empty_error: InvalidMessage::IllegalEmptyList("PskKeyExchangeModes"),
682    };
683}
684
685/// RFC8446: `KeyShareEntry client_shares<0..2^16-1>;`
686impl TlsListElement for KeyShareEntry {
687    const SIZE_LEN: ListLength = ListLength::U16;
688}
689
690/// The body of the `SupportedVersions` extension when it appears in a
691/// `ClientHello`
692///
693/// This is documented as a preference-order vector, but we (as a server)
694/// ignore the preference of the client.
695///
696/// RFC8446: `ProtocolVersion versions<2..254>;`
697#[derive(Clone, Copy, Debug, Default)]
698pub(crate) struct SupportedProtocolVersions {
699    pub(crate) tls13: bool,
700    pub(crate) tls12: bool,
701}
702
703impl SupportedProtocolVersions {
704    /// Return true if `filter` returns true for any enabled version.
705    pub(crate) fn any(&self, filter: impl Fn(ProtocolVersion) -> bool) -> bool {
706        if self.tls13 && filter(ProtocolVersion::TLSv1_3) {
707            return true;
708        }
709        if self.tls12 && filter(ProtocolVersion::TLSv1_2) {
710            return true;
711        }
712        false
713    }
714
715    const LIST_LENGTH: ListLength = ListLength::NonZeroU8 {
716        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
717    };
718}
719
720impl Codec<'_> for SupportedProtocolVersions {
721    fn encode(&self, bytes: &mut Vec<u8>) {
722        let inner = LengthPrefixedBuffer::new(Self::LIST_LENGTH, bytes);
723        if self.tls13 {
724            ProtocolVersion::TLSv1_3.encode(inner.buf);
725        }
726        if self.tls12 {
727            ProtocolVersion::TLSv1_2.encode(inner.buf);
728        }
729    }
730
731    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
732        let mut tls12 = false;
733        let mut tls13 = false;
734
735        for pv in TlsListIter::<ProtocolVersion>::new(reader)? {
736            match pv? {
737                ProtocolVersion::TLSv1_3 => tls13 = true,
738                ProtocolVersion::TLSv1_2 => tls12 = true,
739                _ => continue,
740            };
741        }
742
743        Ok(Self { tls13, tls12 })
744    }
745}
746
747impl TlsListElement for ProtocolVersion {
748    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
749        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
750    };
751}
752
753/// RFC7250: `CertificateType client_certificate_types<1..2^8-1>;`
754///
755/// Ditto `CertificateType server_certificate_types<1..2^8-1>;`
756impl TlsListElement for CertificateType {
757    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
758        empty_error: InvalidMessage::IllegalEmptyList("CertificateTypes"),
759    };
760}
761
762/// RFC8879: `CertificateCompressionAlgorithm algorithms<2..2^8-2>;`
763impl TlsListElement for CertificateCompressionAlgorithm {
764    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
765        empty_error: InvalidMessage::IllegalEmptyList("CertificateCompressionAlgorithms"),
766    };
767}
768
769/// A precursor to `ClientExtensions`, allowing customisation.
770///
771/// This is smaller than `ClientExtensions`, as it only contains the extensions
772/// we need to vary between different protocols (eg, TCP-TLS versus QUIC).
773#[derive(Clone, Default)]
774pub(crate) struct ClientExtensionsInput<'a> {
775    /// QUIC transport parameters
776    pub(crate) transport_parameters: Option<TransportParameters<'a>>,
777
778    /// ALPN protocols
779    pub(crate) protocols: Option<Vec<ProtocolName>>,
780}
781
782impl ClientExtensionsInput<'_> {
783    pub(crate) fn from_alpn(alpn_protocols: Vec<Vec<u8>>) -> ClientExtensionsInput<'static> {
784        let protocols = match alpn_protocols.is_empty() {
785            true => None,
786            false => Some(
787                alpn_protocols
788                    .into_iter()
789                    .map(ProtocolName::from)
790                    .collect::<Vec<_>>(),
791            ),
792        };
793
794        ClientExtensionsInput {
795            transport_parameters: None,
796            protocols,
797        }
798    }
799
800    pub(crate) fn into_owned(self) -> ClientExtensionsInput<'static> {
801        let Self {
802            transport_parameters,
803            protocols,
804        } = self;
805        ClientExtensionsInput {
806            transport_parameters: transport_parameters.map(|x| x.into_owned()),
807            protocols,
808        }
809    }
810}
811
812#[derive(Clone)]
813pub(crate) enum TransportParameters<'a> {
814    /// QUIC transport parameters (RFC9001)
815    Quic(Payload<'a>),
816}
817
818impl TransportParameters<'_> {
819    pub(crate) fn into_owned(self) -> TransportParameters<'static> {
820        match self {
821            Self::Quic(v) => TransportParameters::Quic(v.into_owned()),
822        }
823    }
824}
825
826extension_struct! {
827    /// A representation of extensions present in a `ClientHello` message
828    ///
829    /// All extensions are optional (by definition) so are represented with `Option<T>`.
830    ///
831    /// Some extensions have an empty value and are represented with Option<()>.
832    ///
833    /// Unknown extensions are dropped during parsing.
834    pub(crate) struct ClientExtensions<'a> {
835        /// Requested server name indication (RFC6066)
836        ExtensionType::ServerName =>
837            pub(crate) server_name: Option<ServerNamePayload<'a>>,
838
839        /// Certificate status is requested (RFC6066)
840        ExtensionType::StatusRequest =>
841            pub(crate) certificate_status_request: Option<CertificateStatusRequest>,
842
843        /// Supported groups (RFC4492/RFC8446)
844        ExtensionType::EllipticCurves =>
845            pub(crate) named_groups: Option<Vec<NamedGroup>>,
846
847        /// Supported EC point formats (RFC4492)
848        ExtensionType::ECPointFormats =>
849            pub(crate) ec_point_formats: Option<SupportedEcPointFormats>,
850
851        /// Supported signature schemes (RFC5246/RFC8446)
852        ExtensionType::SignatureAlgorithms =>
853            pub(crate) signature_schemes: Option<Vec<SignatureScheme>>,
854
855        /// Offered ALPN protocols (RFC6066)
856        ExtensionType::ALProtocolNegotiation =>
857            pub(crate) protocols: Option<Vec<ProtocolName>>,
858
859        /// Available client certificate types (RFC7250)
860        ExtensionType::ClientCertificateType =>
861            pub(crate) client_certificate_types: Option<Vec<CertificateType>>,
862
863        /// Acceptable server certificate types (RFC7250)
864        ExtensionType::ServerCertificateType =>
865            pub(crate) server_certificate_types: Option<Vec<CertificateType>>,
866
867        /// Extended master secret is requested (RFC7627)
868        ExtensionType::ExtendedMasterSecret =>
869            pub(crate) extended_master_secret_request: Option<()>,
870
871        /// Offered certificate compression methods (RFC8879)
872        ExtensionType::CompressCertificate =>
873            pub(crate) certificate_compression_algorithms: Option<Vec<CertificateCompressionAlgorithm>>,
874
875        /// Session ticket offer or request (RFC5077/RFC8446)
876        ExtensionType::SessionTicket =>
877            pub(crate) session_ticket: Option<ClientSessionTicket>,
878
879        /// Offered preshared keys (RFC8446)
880        ExtensionType::PreSharedKey =>
881            pub(crate) preshared_key_offer: Option<PresharedKeyOffer>,
882
883        /// Early data is requested (RFC8446)
884        ExtensionType::EarlyData =>
885            pub(crate) early_data_request: Option<()>,
886
887        /// Supported TLS versions (RFC8446)
888        ExtensionType::SupportedVersions =>
889            pub(crate) supported_versions: Option<SupportedProtocolVersions>,
890
891        /// Stateless HelloRetryRequest cookie (RFC8446)
892        ExtensionType::Cookie =>
893            pub(crate) cookie: Option<PayloadU16<NonEmpty>>,
894
895        /// Offered preshared key modes (RFC8446)
896        ExtensionType::PSKKeyExchangeModes =>
897            pub(crate) preshared_key_modes: Option<PskKeyExchangeModes>,
898
899        /// Certificate authority names (RFC8446)
900        ExtensionType::CertificateAuthorities =>
901            pub(crate) certificate_authority_names: Option<Vec<DistinguishedName>>,
902
903        /// Offered key exchange shares (RFC8446)
904        ExtensionType::KeyShare =>
905            pub(crate) key_shares: Option<Vec<KeyShareEntry>>,
906
907        /// QUIC transport parameters (RFC9001)
908        ExtensionType::TransportParameters =>
909            pub(crate) transport_parameters: Option<Payload<'a>>,
910
911        /// Secure renegotiation (RFC5746)
912        ExtensionType::RenegotiationInfo =>
913            pub(crate) renegotiation_info: Option<PayloadU8>,
914
915        /// Encrypted inner client hello (draft-ietf-tls-esni)
916        ExtensionType::EncryptedClientHello =>
917            pub(crate) encrypted_client_hello: Option<EncryptedClientHello>,
918
919        /// Encrypted client hello outer extensions (draft-ietf-tls-esni)
920        ExtensionType::EncryptedClientHelloOuterExtensions =>
921            pub(crate) encrypted_client_hello_outer: Option<Vec<ExtensionType>>,
922    } + {
923        /// Order randomization seed.
924        pub(crate) order_seed: u16,
925
926        /// Extensions that must appear contiguously.
927        pub(crate) contiguous_extensions: Vec<ExtensionType>,
928    }
929}
930
931impl ClientExtensions<'_> {
932    pub(crate) fn into_owned(self) -> ClientExtensions<'static> {
933        let Self {
934            server_name,
935            certificate_status_request,
936            named_groups,
937            ec_point_formats,
938            signature_schemes,
939            protocols,
940            client_certificate_types,
941            server_certificate_types,
942            extended_master_secret_request,
943            certificate_compression_algorithms,
944            session_ticket,
945            preshared_key_offer,
946            early_data_request,
947            supported_versions,
948            cookie,
949            preshared_key_modes,
950            certificate_authority_names,
951            key_shares,
952            transport_parameters,
953            renegotiation_info,
954            encrypted_client_hello,
955            encrypted_client_hello_outer,
956            order_seed,
957            contiguous_extensions,
958        } = self;
959        ClientExtensions {
960            server_name: server_name.map(|x| x.into_owned()),
961            certificate_status_request,
962            named_groups,
963            ec_point_formats,
964            signature_schemes,
965            protocols,
966            client_certificate_types,
967            server_certificate_types,
968            extended_master_secret_request,
969            certificate_compression_algorithms,
970            session_ticket,
971            preshared_key_offer,
972            early_data_request,
973            supported_versions,
974            cookie,
975            preshared_key_modes,
976            certificate_authority_names,
977            key_shares,
978            transport_parameters: transport_parameters.map(|x| x.into_owned()),
979            renegotiation_info,
980            encrypted_client_hello,
981            encrypted_client_hello_outer,
982            order_seed,
983            contiguous_extensions,
984        }
985    }
986
987    pub(crate) fn used_extensions_in_encoding_order(&self) -> Vec<ExtensionType> {
988        let mut exts = self.order_insensitive_extensions_in_random_order();
989        exts.extend(&self.contiguous_extensions);
990
991        if self
992            .encrypted_client_hello_outer
993            .is_some()
994        {
995            exts.push(ExtensionType::EncryptedClientHelloOuterExtensions);
996        }
997        if self.encrypted_client_hello.is_some() {
998            exts.push(ExtensionType::EncryptedClientHello);
999        }
1000        if self.preshared_key_offer.is_some() {
1001            exts.push(ExtensionType::PreSharedKey);
1002        }
1003        exts
1004    }
1005
1006    /// Returns extensions which don't need a specific order, in randomized order.
1007    ///
1008    /// Extensions are encoded in three portions:
1009    ///
1010    /// - First, extensions not otherwise dealt with by other cases.
1011    ///   These are encoded in random order, controlled by `self.order_seed`,
1012    ///   and this is the set of extensions returned by this function.
1013    ///
1014    /// - Second, extensions named in `self.contiguous_extensions`, in the order
1015    ///   given by that field.
1016    ///
1017    /// - Lastly, any ECH and PSK extensions (in that order).  These
1018    ///   are required to be last by the standard.
1019    fn order_insensitive_extensions_in_random_order(&self) -> Vec<ExtensionType> {
1020        let mut order = self.collect_used();
1021
1022        // Remove extensions which have specific order requirements.
1023        order.retain(|ext| {
1024            !(matches!(
1025                ext,
1026                ExtensionType::PreSharedKey
1027                    | ExtensionType::EncryptedClientHello
1028                    | ExtensionType::EncryptedClientHelloOuterExtensions
1029            ) || self.contiguous_extensions.contains(ext))
1030        });
1031
1032        order.sort_by_cached_key(|new_ext| {
1033            let seed = ((self.order_seed as u32) << 16) | (u16::from(*new_ext) as u32);
1034            low_quality_integer_hash(seed)
1035        });
1036
1037        order
1038    }
1039}
1040
1041impl<'a> Codec<'a> for ClientExtensions<'a> {
1042    fn encode(&self, bytes: &mut Vec<u8>) {
1043        let order = self.used_extensions_in_encoding_order();
1044
1045        if order.is_empty() {
1046            return;
1047        }
1048
1049        let body = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1050        for item in order {
1051            self.encode_one(item, body.buf);
1052        }
1053    }
1054
1055    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1056        let mut out = Self::default();
1057
1058        // extensions length can be absent if no extensions
1059        if !r.any_left() {
1060            return Ok(out);
1061        }
1062
1063        let mut checker = DuplicateExtensionChecker::new();
1064
1065        let len = usize::from(u16::read(r)?);
1066        let mut sub = r.sub(len)?;
1067
1068        while sub.any_left() {
1069            let typ = out.read_one(&mut sub, |unknown| checker.check(unknown))?;
1070
1071            // PreSharedKey offer must come last
1072            if typ == ExtensionType::PreSharedKey && sub.any_left() {
1073                return Err(InvalidMessage::PreSharedKeyIsNotFinalExtension);
1074            }
1075        }
1076
1077        Ok(out)
1078    }
1079}
1080
1081fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
1082    let dns_name_str = dns_name.as_ref();
1083
1084    // RFC6066: "The hostname is represented as a byte string using
1085    // ASCII encoding without a trailing dot"
1086    if dns_name_str.ends_with('.') {
1087        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
1088        DnsName::try_from(trimmed)
1089            .unwrap()
1090            .to_owned()
1091    } else {
1092        dns_name.to_owned()
1093    }
1094}
1095
1096#[derive(Clone, Debug)]
1097pub(crate) enum ClientSessionTicket {
1098    Request,
1099    Offer(Payload<'static>),
1100}
1101
1102impl<'a> Codec<'a> for ClientSessionTicket {
1103    fn encode(&self, bytes: &mut Vec<u8>) {
1104        match self {
1105            Self::Request => (),
1106            Self::Offer(p) => p.encode(bytes),
1107        }
1108    }
1109
1110    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1111        Ok(match r.left() {
1112            0 => Self::Request,
1113            _ => Self::Offer(Payload::read(r).into_owned()),
1114        })
1115    }
1116}
1117
/// Caller-supplied inputs relating to the server's extensions.
///
/// Currently this only carries the QUIC transport parameters. The derived `Default`
/// leaves them unset (`None`). Presumably it is consulted when the server assembles
/// its extensions; confirm at call sites.
#[derive(Default)]
pub(crate) struct ServerExtensionsInput<'a> {
    /// QUIC transport parameters
    pub(crate) transport_parameters: Option<TransportParameters<'a>>,
}
1123
extension_struct! {
    /// Extensions sent by a server, as carried by e.g. `ServerHelloPayload`.
    pub(crate) struct ServerExtensions<'a> {
        /// Supported EC point formats (RFC4492)
        ExtensionType::ECPointFormats =>
            pub(crate) ec_point_formats: Option<SupportedEcPointFormats>,

        /// Server name indication acknowledgement (RFC6066)
        ExtensionType::ServerName =>
            pub(crate) server_name_ack: Option<()>,

        /// Session ticket acknowledgement (RFC5077)
        ExtensionType::SessionTicket =>
            pub(crate) session_ticket_ack: Option<()>,

        /// Secure renegotiation indication (RFC5746)
        ExtensionType::RenegotiationInfo =>
            pub(crate) renegotiation_info: Option<PayloadU8>,

        /// Selected ALPN protocol (RFC7301)
        ExtensionType::ALProtocolNegotiation =>
            pub(crate) selected_protocol: Option<SingleProtocolName>,

        /// Key exchange server share (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_share: Option<KeyShareEntry>,

        /// Selected preshared key index (RFC8446)
        ExtensionType::PreSharedKey =>
            pub(crate) preshared_key: Option<u16>,

        /// Required client certificate type (RFC7250)
        ExtensionType::ClientCertificateType =>
            pub(crate) client_certificate_type: Option<CertificateType>,

        /// Selected server certificate type (RFC7250)
        ExtensionType::ServerCertificateType =>
            pub(crate) server_certificate_type: Option<CertificateType>,

        /// Extended master secret is in use (RFC7627)
        ExtensionType::ExtendedMasterSecret =>
            pub(crate) extended_master_secret_ack: Option<()>,

        /// Certificate status acknowledgement (RFC6066)
        ExtensionType::StatusRequest =>
            pub(crate) certificate_status_request_ack: Option<()>,

        /// Selected TLS version (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) selected_version: Option<ProtocolVersion>,

        /// QUIC transport parameters (RFC9001)
        ExtensionType::TransportParameters =>
            pub(crate) transport_parameters: Option<Payload<'a>>,

        /// Early data is accepted (RFC8446)
        ExtensionType::EarlyData =>
            pub(crate) early_data_ack: Option<()>,

        /// Encrypted inner client hello response (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello_ack: Option<ServerEncryptedClientHello>,
    } + {
        /// Type codes of unrecognised extensions seen while decoding; filled in by
        /// `Codec::read` from its `DuplicateExtensionChecker`.
        pub(crate) unknown_extensions: BTreeSet<u16>,
    }
}
1188
1189impl ServerExtensions<'_> {
1190    fn into_owned(self) -> ServerExtensions<'static> {
1191        let Self {
1192            ec_point_formats,
1193            server_name_ack,
1194            session_ticket_ack,
1195            renegotiation_info,
1196            selected_protocol,
1197            key_share,
1198            preshared_key,
1199            client_certificate_type,
1200            server_certificate_type,
1201            extended_master_secret_ack,
1202            certificate_status_request_ack,
1203            selected_version,
1204            transport_parameters,
1205            early_data_ack,
1206            encrypted_client_hello_ack,
1207            unknown_extensions,
1208        } = self;
1209        ServerExtensions {
1210            ec_point_formats,
1211            server_name_ack,
1212            session_ticket_ack,
1213            renegotiation_info,
1214            selected_protocol,
1215            key_share,
1216            preshared_key,
1217            client_certificate_type,
1218            server_certificate_type,
1219            extended_master_secret_ack,
1220            certificate_status_request_ack,
1221            selected_version,
1222            transport_parameters: transport_parameters.map(|x| x.into_owned()),
1223            early_data_ack,
1224            encrypted_client_hello_ack,
1225            unknown_extensions,
1226        }
1227    }
1228}
1229
1230impl<'a> Codec<'a> for ServerExtensions<'a> {
1231    fn encode(&self, bytes: &mut Vec<u8>) {
1232        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1233
1234        for ext in Self::ALL_EXTENSIONS {
1235            self.encode_one(*ext, extensions.buf);
1236        }
1237    }
1238
1239    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1240        let mut out = Self::default();
1241        let mut checker = DuplicateExtensionChecker::new();
1242
1243        let len = usize::from(u16::read(r)?);
1244        let mut sub = r.sub(len)?;
1245
1246        while sub.any_left() {
1247            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
1248        }
1249
1250        out.unknown_extensions = checker.0;
1251        Ok(out)
1252    }
1253}
1254
/// A ClientHello message body.
///
/// The extensions are boxed. `Deref`/`DerefMut` to `ClientExtensions` expose them
/// directly on the payload.
#[derive(Clone, Debug)]
pub(crate) struct ClientHelloPayload {
    /// Version field of the hello, as sent or received.
    pub(crate) client_version: ProtocolVersion,
    pub(crate) random: Random,
    /// Session id. The ECH inner-hello encoding always writes this as empty.
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suites: Vec<CipherSuite>,
    pub(crate) compression_methods: Vec<Compression>,
    pub(crate) extensions: Box<ClientExtensions<'static>>,
}
1264
impl ClientHelloPayload {
    /// Encodes this hello as an ECH inner ClientHello (`Encoding::EchInnerHello`).
    ///
    /// The extensions listed in `to_compress` are elided and referenced from the
    /// `encrypted_client_hello_outer` marker instead.
    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
        let mut bytes = Vec::new();
        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
        bytes
    }

    /// Appends the ClientHello body to `bytes`, shaped according to `purpose`.
    ///
    /// Field order: client version, random, session id, cipher suites, compression
    /// methods, extensions.
    ///
    /// For `Encoding::EchInnerHello`:
    /// - the session id is written as empty;
    /// - if `to_compress` is non-empty, those extensions are cleared from a copy of
    ///   the extensions, and `encrypted_client_hello_outer` is set to list them.
    ///
    /// Every other purpose, or an empty `to_compress`, encodes the extensions unchanged.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
        self.client_version.encode(bytes);
        self.random.encode(bytes);

        match purpose {
            // SessionID is required to be empty in the encoded inner client hello.
            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
            _ => self.session_id.encode(bytes),
        }

        self.cipher_suites.encode(bytes);
        self.compression_methods.encode(bytes);

        let to_compress = match purpose {
            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
            _ => {
                // Nothing to compress: encode the extensions as they are and stop.
                self.extensions.encode(bytes);
                return;
            }
        };

        // Work on a copy so that `self` is left untouched.
        let mut compressed = self.extensions.clone();

        // First, eliminate the full-fat versions of the extensions
        for e in &to_compress {
            compressed.clear(*e);
        }

        // Replace with the marker noting which extensions were elided.
        compressed.encrypted_client_hello_outer = Some(to_compress);

        // And encode as normal.
        compressed.encode(bytes);
    }

    /// Returns `true` if a `key_shares` extension is present and `has_duplicates`
    /// reports a repeated group code in it. Returns `false` if the extension is absent.
    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
        self.key_shares
            .as_ref()
            .map(|entries| {
                has_duplicates::<_, _, u16>(
                    entries
                        .iter()
                        .map(|kse| u16::from(kse.group)),
                )
            })
            .unwrap_or_default()
    }

    /// Returns `true` if a certificate compression extension is present and
    /// `has_duplicates` reports a repeated algorithm in it. Returns `false` if the
    /// extension is absent.
    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
        if let Some(algs) = &self.certificate_compression_algorithms {
            has_duplicates::<_, _, u16>(algs.iter().cloned())
        } else {
            false
        }
    }
}
1328
1329impl Codec<'_> for ClientHelloPayload {
1330    fn encode(&self, bytes: &mut Vec<u8>) {
1331        self.payload_encode(bytes, Encoding::Standard)
1332    }
1333
1334    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1335        let ret = Self {
1336            client_version: ProtocolVersion::read(r)?,
1337            random: Random::read(r)?,
1338            session_id: SessionId::read(r)?,
1339            cipher_suites: Vec::read(r)?,
1340            compression_methods: Vec::read(r)?,
1341            extensions: Box::new(ClientExtensions::read(r)?.into_owned()),
1342        };
1343
1344        match r.any_left() {
1345            true => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
1346            false => Ok(ret),
1347        }
1348    }
1349}
1350
1351impl Deref for ClientHelloPayload {
1352    type Target = ClientExtensions<'static>;
1353    fn deref(&self) -> &Self::Target {
1354        &self.extensions
1355    }
1356}
1357
1358impl DerefMut for ClientHelloPayload {
1359    fn deref_mut(&mut self) -> &mut Self::Target {
1360        &mut self.extensions
1361    }
1362}
1363
/// RFC8446: `CipherSuite cipher_suites<2..2^16-2>;`
///
/// Uses a u16 length prefix via `ListLength::NonZeroU16`. An empty list is reported
/// as `IllegalEmptyList("CipherSuites")`.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("CipherSuites"),
    };
}

/// RFC5246: `CompressionMethod compression_methods<1..2^8-1>;`
///
/// Uses a u8 length prefix via `ListLength::NonZeroU8`. An empty list is reported
/// as `IllegalEmptyList("Compressions")`.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("Compressions"),
    };
}

/// draft-ietf-tls-esni-17: `ExtensionType OuterExtensions<2..254>;`
///
/// Uses a u8 length prefix via `ListLength::NonZeroU8`. An empty list is reported
/// as `IllegalEmptyList("ExtensionTypes")`.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ExtensionTypes"),
    };
}
1384
extension_struct! {
    /// A representation of extensions present in a `HelloRetryRequest` message
    pub(crate) struct HelloRetryRequestExtensions<'a> {
        /// Group the client should use for its new key share (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_share: Option<NamedGroup>,

        /// Cookie for the client to echo back (RFC8446); must be non-empty
        ExtensionType::Cookie =>
            pub(crate) cookie: Option<PayloadU16<NonEmpty>>,

        /// Selected protocol version (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) supported_versions: Option<ProtocolVersion>,

        /// ECH acceptance confirmation payload (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello: Option<Payload<'a>>,
    } + {
        /// Order in which the extensions were decoded. When set, `encode` emits the
        /// extensions in this order, so a decoded message re-encodes byte-for-byte.
        pub(crate) order: Option<Vec<ExtensionType>>,
    }
}
1404
1405impl HelloRetryRequestExtensions<'_> {
1406    fn into_owned(self) -> HelloRetryRequestExtensions<'static> {
1407        let Self {
1408            key_share,
1409            cookie,
1410            supported_versions,
1411            encrypted_client_hello,
1412            order,
1413        } = self;
1414        HelloRetryRequestExtensions {
1415            key_share,
1416            cookie,
1417            supported_versions,
1418            encrypted_client_hello: encrypted_client_hello.map(|x| x.into_owned()),
1419            order,
1420        }
1421    }
1422}
1423
1424impl<'a> Codec<'a> for HelloRetryRequestExtensions<'a> {
1425    fn encode(&self, bytes: &mut Vec<u8>) {
1426        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1427
1428        for ext in self
1429            .order
1430            .as_deref()
1431            .unwrap_or(Self::ALL_EXTENSIONS)
1432        {
1433            self.encode_one(*ext, extensions.buf);
1434        }
1435    }
1436
1437    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1438        let mut out = Self::default();
1439
1440        // we must record order, so re-encoding round trips.  this is needed,
1441        // unfortunately, for ECH HRR confirmation
1442        let mut order = vec![];
1443
1444        let len = usize::from(u16::read(r)?);
1445        let mut sub = r.sub(len)?;
1446
1447        while sub.any_left() {
1448            let typ = out.read_one(&mut sub, |_unk| {
1449                Err(InvalidMessage::UnknownHelloRetryRequestExtension)
1450            })?;
1451
1452            order.push(typ);
1453        }
1454
1455        out.order = Some(order);
1456        Ok(out)
1457    }
1458}
1459
/// A HelloRetryRequest message.
///
/// No random is stored here: encoding always writes `HELLO_RETRY_REQUEST_RANDOM`.
/// `Deref`/`DerefMut` expose the extensions directly on the message.
#[derive(Clone, Debug)]
pub(crate) struct HelloRetryRequest {
    pub(crate) legacy_version: ProtocolVersion,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) extensions: HelloRetryRequestExtensions<'static>,
}
1467
1468impl Codec<'_> for HelloRetryRequest {
1469    fn encode(&self, bytes: &mut Vec<u8>) {
1470        self.payload_encode(bytes, Encoding::Standard)
1471    }
1472
1473    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1474        let session_id = SessionId::read(r)?;
1475        let cipher_suite = CipherSuite::read(r)?;
1476        let compression = Compression::read(r)?;
1477
1478        if compression != Compression::Null {
1479            return Err(InvalidMessage::UnsupportedCompression);
1480        }
1481
1482        Ok(Self {
1483            legacy_version: ProtocolVersion::Unknown(0),
1484            session_id,
1485            cipher_suite,
1486            extensions: HelloRetryRequestExtensions::read(r)?.into_owned(),
1487        })
1488    }
1489}
1490
impl HelloRetryRequest {
    /// Appends the HelloRetryRequest body to `bytes`.
    ///
    /// Writes, in order: `legacy_version`, `HELLO_RETRY_REQUEST_RANDOM`, the session id,
    /// the cipher suite and `Compression::Null`, then the extensions.
    ///
    /// For `Encoding::EchConfirmation` with an ECH extension present, that
    /// extension's payload is replaced by eight zero bytes. In every other case the
    /// extensions are encoded unchanged.
    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
        self.legacy_version.encode(bytes);
        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
        self.session_id.encode(bytes);
        self.cipher_suite.encode(bytes);
        Compression::Null.encode(bytes);

        match purpose {
            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
            // must have its payload replaced by 8 zero bytes.
            //
            // See draft-ietf-tls-esni-18 7.2.1:
            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
            Encoding::EchConfirmation
                if self
                    .extensions
                    .encrypted_client_hello
                    .is_some() =>
            {
                let hrr_confirmation = [0u8; 8];
                // A temporary copy with the ECH payload swapped; `self` is unchanged.
                HelloRetryRequestExtensions {
                    encrypted_client_hello: Some(Payload::Borrowed(&hrr_confirmation)),
                    ..self.extensions.clone()
                }
                .encode(bytes);
            }
            _ => self.extensions.encode(bytes),
        }
    }
}
1522
1523impl Deref for HelloRetryRequest {
1524    type Target = HelloRetryRequestExtensions<'static>;
1525    fn deref(&self) -> &Self::Target {
1526        &self.extensions
1527    }
1528}
1529
1530impl DerefMut for HelloRetryRequest {
1531    fn deref_mut(&mut self) -> &mut Self::Target {
1532        &mut self.extensions
1533    }
1534}
1535
/// A ServerHello message body.
///
/// The extensions are boxed. `Deref`/`DerefMut` to `ServerExtensions` expose them
/// directly on the payload.
#[derive(Clone, Debug)]
pub(crate) struct ServerHelloPayload {
    pub(crate) legacy_version: ProtocolVersion,
    pub(crate) random: Random,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) compression_method: Compression,
    pub(crate) extensions: Box<ServerExtensions<'static>>,
}
1545
1546impl Codec<'_> for ServerHelloPayload {
1547    fn encode(&self, bytes: &mut Vec<u8>) {
1548        self.payload_encode(bytes, Encoding::Standard)
1549    }
1550
1551    // minus version and random, which have already been read.
1552    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1553        let session_id = SessionId::read(r)?;
1554        let suite = CipherSuite::read(r)?;
1555        let compression = Compression::read(r)?;
1556
1557        // RFC5246:
1558        // "The presence of extensions can be detected by determining whether
1559        //  there are bytes following the compression_method field at the end of
1560        //  the ServerHello."
1561        let extensions = Box::new(
1562            if r.any_left() {
1563                ServerExtensions::read(r)?
1564            } else {
1565                ServerExtensions::default()
1566            }
1567            .into_owned(),
1568        );
1569
1570        let ret = Self {
1571            legacy_version: ProtocolVersion::Unknown(0),
1572            random: ZERO_RANDOM,
1573            session_id,
1574            cipher_suite: suite,
1575            compression_method: compression,
1576            extensions,
1577        };
1578
1579        r.expect_empty("ServerHelloPayload")
1580            .map(|_| ret)
1581    }
1582}
1583
impl ServerHelloPayload {
    /// Appends the ServerHello body to `bytes`.
    ///
    /// Writes, in order: version, random, session id, cipher suite, compression
    /// method, extensions. The result does not otherwise depend on `encoding`.
    /// `Encoding::EchConfirmation` is not supported; this is checked only in debug
    /// builds, via `debug_assert!`.
    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
        debug_assert!(
            !matches!(encoding, Encoding::EchConfirmation),
            "we cannot compute an ECH confirmation on a received ServerHello"
        );

        self.legacy_version.encode(bytes);
        self.random.encode(bytes);
        self.session_id.encode(bytes);
        self.cipher_suite.encode(bytes);
        self.compression_method.encode(bytes);
        self.extensions.encode(bytes);
    }
}
1599
1600impl Deref for ServerHelloPayload {
1601    type Target = ServerExtensions<'static>;
1602    fn deref(&self) -> &Self::Target {
1603        &self.extensions
1604    }
1605}
1606
1607impl DerefMut for ServerHelloPayload {
1608    fn deref_mut(&mut self) -> &mut Self::Target {
1609        &mut self.extensions
1610    }
1611}
1612
1613#[derive(Clone, Default, Debug)]
1614pub(crate) struct CertificateChain<'a>(pub(crate) Vec<CertificateDer<'a>>);
1615
1616impl CertificateChain<'_> {
1617    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1618        CertificateChain(
1619            self.0
1620                .into_iter()
1621                .map(|c| c.into_owned())
1622                .collect(),
1623        )
1624    }
1625}
1626
1627impl<'a> Codec<'a> for CertificateChain<'a> {
1628    fn encode(&self, bytes: &mut Vec<u8>) {
1629        Vec::encode(&self.0, bytes)
1630    }
1631
1632    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1633        Vec::read(r).map(Self)
1634    }
1635}
1636
1637impl<'a> Deref for CertificateChain<'a> {
1638    type Target = [CertificateDer<'a>];
1639
1640    fn deref(&self) -> &[CertificateDer<'a>] {
1641        &self.0
1642    }
1643}
1644
/// A list of `CertificateDer` uses a u24 length prefix. A length greater than
/// `CERTIFICATE_MAX_SIZE_LIMIT` is rejected with `CertificatePayloadTooLarge`.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}

/// TLS has a 16MB size limit on any handshake message,
/// plus a 16MB limit on any given certificate.
///
/// We reduce that to 64KB (`0x1_0000` bytes) to limit the amount of memory
/// allocation that is directly controllable by the peer.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1658
extension_struct! {
    /// Extensions attached to a single `CertificateEntry`.
    pub(crate) struct CertificateExtensions<'a> {
        /// OCSP status response for this certificate (RFC8446 §4.4.2.1)
        ExtensionType::StatusRequest =>
            pub(crate) status: Option<CertificateStatus<'a>>,
    }
}
1665
1666impl CertificateExtensions<'_> {
1667    fn into_owned(self) -> CertificateExtensions<'static> {
1668        CertificateExtensions {
1669            status: self.status.map(|s| s.into_owned()),
1670        }
1671    }
1672}
1673
1674impl<'a> Codec<'a> for CertificateExtensions<'a> {
1675    fn encode(&self, bytes: &mut Vec<u8>) {
1676        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1677
1678        for ext in Self::ALL_EXTENSIONS {
1679            self.encode_one(*ext, extensions.buf);
1680        }
1681    }
1682
1683    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1684        let mut out = Self::default();
1685
1686        let len = usize::from(u16::read(r)?);
1687        let mut sub = r.sub(len)?;
1688
1689        while sub.any_left() {
1690            out.read_one(&mut sub, |_unk| {
1691                Err(InvalidMessage::UnknownCertificateExtension)
1692            })?;
1693        }
1694
1695        Ok(out)
1696    }
1697}
1698
1699#[derive(Debug)]
1700pub(crate) struct CertificateEntry<'a> {
1701    pub(crate) cert: CertificateDer<'a>,
1702    pub(crate) extensions: CertificateExtensions<'a>,
1703}
1704
1705impl<'a> Codec<'a> for CertificateEntry<'a> {
1706    fn encode(&self, bytes: &mut Vec<u8>) {
1707        self.cert.encode(bytes);
1708        self.extensions.encode(bytes);
1709    }
1710
1711    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1712        Ok(Self {
1713            cert: CertificateDer::read(r)?,
1714            extensions: CertificateExtensions::read(r)?.into_owned(),
1715        })
1716    }
1717}
1718
1719impl<'a> CertificateEntry<'a> {
1720    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1721        Self {
1722            cert,
1723            extensions: CertificateExtensions::default(),
1724        }
1725    }
1726
1727    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1728        CertificateEntry {
1729            cert: self.cert.into_owned(),
1730            extensions: self.extensions.into_owned(),
1731        }
1732    }
1733}
1734
1735impl TlsListElement for CertificateEntry<'_> {
1736    const SIZE_LEN: ListLength = ListLength::U24 {
1737        max: CERTIFICATE_MAX_SIZE_LIMIT,
1738        error: InvalidMessage::CertificatePayloadTooLarge,
1739    };
1740}
1741
1742#[derive(Debug)]
1743pub(crate) struct CertificatePayloadTls13<'a> {
1744    pub(crate) context: PayloadU8,
1745    pub(crate) entries: Vec<CertificateEntry<'a>>,
1746}
1747
1748impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1749    fn encode(&self, bytes: &mut Vec<u8>) {
1750        self.context.encode(bytes);
1751        self.entries.encode(bytes);
1752    }
1753
1754    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1755        Ok(Self {
1756            context: PayloadU8::read(r)?,
1757            entries: Vec::read(r)?,
1758        })
1759    }
1760}
1761
1762impl<'a> CertificatePayloadTls13<'a> {
1763    pub(crate) fn new(
1764        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1765        ocsp_response: Option<&'a [u8]>,
1766    ) -> Self {
1767        Self {
1768            context: PayloadU8::empty(),
1769            entries: certs
1770                // zip certificate iterator with `ocsp_response` followed by
1771                // an infinite-length iterator of `None`.
1772                .zip(
1773                    ocsp_response
1774                        .into_iter()
1775                        .map(Some)
1776                        .chain(iter::repeat(None)),
1777                )
1778                .map(|(cert, ocsp)| {
1779                    let mut e = CertificateEntry::new(cert.clone());
1780                    if let Some(ocsp) = ocsp {
1781                        e.extensions.status = Some(CertificateStatus::new(ocsp));
1782                    }
1783                    e
1784                })
1785                .collect(),
1786        }
1787    }
1788
1789    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1790        CertificatePayloadTls13 {
1791            context: self.context,
1792            entries: self
1793                .entries
1794                .into_iter()
1795                .map(CertificateEntry::into_owned)
1796                .collect(),
1797        }
1798    }
1799
1800    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1801        let Some(entry) = self.entries.first() else {
1802            return vec![];
1803        };
1804        entry
1805            .extensions
1806            .status
1807            .as_ref()
1808            .map(|status| {
1809                status
1810                    .ocsp_response
1811                    .0
1812                    .clone()
1813                    .into_vec()
1814            })
1815            .unwrap_or_default()
1816    }
1817
1818    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1819        CertificateChain(
1820            self.entries
1821                .into_iter()
1822                .map(|e| e.cert)
1823                .collect(),
1824        )
1825    }
1826}
1827
/// Describes supported key exchange mechanisms.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman Key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}

/// Every `KeyExchangeAlgorithm` variant, listed with ECDHE first.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1842
1843// We don't support arbitrary curves.  It's a terrible
1844// idea and unnecessary attack surface.  Please,
1845// get a grip.
1846#[derive(Debug)]
1847pub(crate) struct EcParameters {
1848    pub(crate) curve_type: ECCurveType,
1849    pub(crate) named_group: NamedGroup,
1850}
1851
1852impl Codec<'_> for EcParameters {
1853    fn encode(&self, bytes: &mut Vec<u8>) {
1854        self.curve_type.encode(bytes);
1855        self.named_group.encode(bytes);
1856    }
1857
1858    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1859        let ct = ECCurveType::read(r)?;
1860        if ct != ECCurveType::NamedCurve {
1861            return Err(InvalidMessage::UnsupportedCurveType);
1862        }
1863
1864        let grp = NamedGroup::read(r)?;
1865
1866        Ok(Self {
1867            curve_type: ct,
1868            named_group: grp,
1869        })
1870    }
1871}
1872
#[cfg(feature = "tls12")]
/// Decoding for TLS1.2 key exchange messages whose wire format depends on the
/// negotiated `KeyExchangeAlgorithm`.
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key_exchange `algo`
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1878
/// The client's public value in a TLS1.2 ClientKeyExchange, by key exchange type.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    Ecdh(ClientEcdhParams),
    Dh(ClientDhParams),
}

#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Raw bytes of the client's public key or point.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Ecdh(ClientEcdhParams { public }) => &public.0,
            Self::Dh(ClientDhParams { public }) => &public.0,
        }
    }

    /// Encodes whichever variant is held, using that variant's own codec.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Ecdh(params) => params.encode(buf),
            Self::Dh(params) => params.encode(buf),
        }
    }
}

#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1913
/// The client's ECDHE public point, as carried in `ClientKeyExchange`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// RFC4492: `opaque point <1..2^8-1>;`
    pub(crate) public: PayloadU8<NonEmpty>,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the point as a non-empty `PayloadU8`.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::<NonEmpty>::read(r).map(|public| Self { public })
    }
}

/// The client's finite-field DH public value, as carried in `ClientKeyExchange`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// RFC5246: `opaque dh_Yc<1..2^16-1>;`
    pub(crate) public: PayloadU16<NonEmpty>,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public value as a non-empty `PayloadU16`.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU16::<NonEmpty>::read(r).map(|public| Self { public })
    }
}
1952
1953#[derive(Debug)]
1954pub(crate) struct ServerEcdhParams {
1955    pub(crate) curve_params: EcParameters,
1956    /// RFC4492: `opaque point <1..2^8-1>;`
1957    pub(crate) public: PayloadU8<NonEmpty>,
1958}
1959
1960impl ServerEcdhParams {
1961    #[cfg(feature = "tls12")]
1962    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1963        Self {
1964            curve_params: EcParameters {
1965                curve_type: ECCurveType::NamedCurve,
1966                named_group: kx.group(),
1967            },
1968            public: PayloadU8::new(kx.pub_key().to_vec()),
1969        }
1970    }
1971}
1972
1973impl Codec<'_> for ServerEcdhParams {
1974    fn encode(&self, bytes: &mut Vec<u8>) {
1975        self.curve_params.encode(bytes);
1976        self.public.encode(bytes);
1977    }
1978
1979    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1980        let cp = EcParameters::read(r)?;
1981        let pb = PayloadU8::read(r)?;
1982
1983        Ok(Self {
1984            curve_params: cp,
1985            public: pb,
1986        })
1987    }
1988}
1989
1990#[derive(Debug)]
1991#[allow(non_snake_case)]
1992pub(crate) struct ServerDhParams {
1993    /// RFC5246: `opaque dh_p<1..2^16-1>;`
1994    pub(crate) dh_p: PayloadU16<NonEmpty>,
1995    /// RFC5246: `opaque dh_g<1..2^16-1>;`
1996    pub(crate) dh_g: PayloadU16<NonEmpty>,
1997    /// RFC5246: `opaque dh_Ys<1..2^16-1>;`
1998    pub(crate) dh_Ys: PayloadU16<NonEmpty>,
1999}
2000
2001impl ServerDhParams {
2002    #[cfg(feature = "tls12")]
2003    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2004        let Some(params) = kx.ffdhe_group() else {
2005            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
2006        };
2007
2008        Self {
2009            dh_p: PayloadU16::new(params.p.to_vec()),
2010            dh_g: PayloadU16::new(params.g.to_vec()),
2011            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
2012        }
2013    }
2014
2015    #[cfg(feature = "tls12")]
2016    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
2017        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
2018    }
2019}
2020
2021impl Codec<'_> for ServerDhParams {
2022    fn encode(&self, bytes: &mut Vec<u8>) {
2023        self.dh_p.encode(bytes);
2024        self.dh_g.encode(bytes);
2025        self.dh_Ys.encode(bytes);
2026    }
2027
2028    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2029        Ok(Self {
2030            dh_p: PayloadU16::read(r)?,
2031            dh_g: PayloadU16::read(r)?,
2032            dh_Ys: PayloadU16::read(r)?,
2033        })
2034    }
2035}
2036
2037#[allow(dead_code)]
2038#[derive(Debug)]
2039pub(crate) enum ServerKeyExchangeParams {
2040    Ecdh(ServerEcdhParams),
2041    Dh(ServerDhParams),
2042}
2043
2044impl ServerKeyExchangeParams {
2045    #[cfg(feature = "tls12")]
2046    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2047        match kx.group().key_exchange_algorithm() {
2048            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
2049            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
2050        }
2051    }
2052
2053    #[cfg(feature = "tls12")]
2054    pub(crate) fn pub_key(&self) -> &[u8] {
2055        match self {
2056            Self::Ecdh(ecdh) => &ecdh.public.0,
2057            Self::Dh(dh) => &dh.dh_Ys.0,
2058        }
2059    }
2060
2061    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2062        match self {
2063            Self::Ecdh(ecdh) => ecdh.encode(buf),
2064            Self::Dh(dh) => dh.encode(buf),
2065        }
2066    }
2067}
2068
2069#[cfg(feature = "tls12")]
2070impl KxDecode<'_> for ServerKeyExchangeParams {
2071    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
2072        use KeyExchangeAlgorithm::*;
2073        Ok(match algo {
2074            ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
2075            DHE => Self::Dh(ServerDhParams::read(r)?),
2076        })
2077    }
2078}
2079
2080#[derive(Debug)]
2081pub(crate) struct ServerKeyExchange {
2082    pub(crate) params: ServerKeyExchangeParams,
2083    pub(crate) dss: DigitallySignedStruct,
2084}
2085
2086impl ServerKeyExchange {
2087    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2088        self.params.encode(buf);
2089        self.dss.encode(buf);
2090    }
2091}
2092
2093#[derive(Debug)]
2094pub(crate) enum ServerKeyExchangePayload {
2095    Known(ServerKeyExchange),
2096    Unknown(Payload<'static>),
2097}
2098
2099impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2100    fn from(value: ServerKeyExchange) -> Self {
2101        Self::Known(value)
2102    }
2103}
2104
2105impl Codec<'_> for ServerKeyExchangePayload {
2106    fn encode(&self, bytes: &mut Vec<u8>) {
2107        match self {
2108            Self::Known(x) => x.encode(bytes),
2109            Self::Unknown(x) => x.encode(bytes),
2110        }
2111    }
2112
2113    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2114        // read as Unknown, fully parse when we know the
2115        // KeyExchangeAlgorithm
2116        Ok(Self::Unknown(Payload::read(r).into_owned()))
2117    }
2118}
2119
2120impl ServerKeyExchangePayload {
2121    #[cfg(feature = "tls12")]
2122    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2123        if let Self::Unknown(unk) = self {
2124            let mut rd = Reader::init(unk.bytes());
2125
2126            let result = ServerKeyExchange {
2127                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2128                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2129            };
2130
2131            if !rd.any_left() {
2132                return Some(result);
2133            };
2134        }
2135
2136        None
2137    }
2138}
2139
/// RFC5246: `ClientCertificateType certificate_types<1..2^8-1>;`
///
/// The list uses a one-byte length prefix and must be non-empty. The error reported for an
/// empty list is `InvalidMessage::IllegalEmptyList("ClientCertificateTypes")`.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ClientCertificateTypes"),
    };
}
2146
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    ///
    /// The raw bytes are available through `AsRef<[u8]>`. A value can be built from raw
    /// bytes with `From<Vec<u8>>`, or with [`DistinguishedName::in_sequence`] when the
    /// bytes lack the outer `SEQUENCE` encoding.
    ///
    /// The TLS encoding is defined in RFC5246: `opaque DistinguishedName<1..2^16-1>;`
    pub struct DistinguishedName,
    PayloadU16<NonEmpty>,
);
2165
2166impl DistinguishedName {
2167    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2168    ///
2169    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2170    ///
2171    /// ```ignore
2172    /// use x509_parser::prelude::FromDer;
2173    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2174    /// ```
2175    pub fn in_sequence(bytes: &[u8]) -> Self {
2176        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2177    }
2178}
2179
/// RFC8446: `DistinguishedName authorities<3..2^16-1>;` however,
/// RFC5246: `DistinguishedName certificate_authorities<0..2^16-1>;`
///
/// The more permissive TLS 1.2 bound is applied: the list has a two-byte length prefix
/// and, unlike the `NonZero*` list lengths used elsewhere, may be empty.
impl TlsListElement for DistinguishedName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2185
2186#[derive(Debug)]
2187pub(crate) struct CertificateRequestPayload {
2188    pub(crate) certtypes: Vec<ClientCertificateType>,
2189    pub(crate) sigschemes: Vec<SignatureScheme>,
2190    pub(crate) canames: Vec<DistinguishedName>,
2191}
2192
2193impl Codec<'_> for CertificateRequestPayload {
2194    fn encode(&self, bytes: &mut Vec<u8>) {
2195        self.certtypes.encode(bytes);
2196        self.sigschemes.encode(bytes);
2197        self.canames.encode(bytes);
2198    }
2199
2200    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2201        let certtypes = Vec::read(r)?;
2202        let sigschemes = Vec::read(r)?;
2203        let canames = Vec::read(r)?;
2204
2205        if sigschemes.is_empty() {
2206            warn!("meaningless CertificateRequest message");
2207            Err(InvalidMessage::NoSignatureSchemes)
2208        } else {
2209            Ok(Self {
2210                certtypes,
2211                sigschemes,
2212                canames,
2213            })
2214        }
2215    }
2216}
2217
// Extensions of a TLS 1.3 `CertificateRequest`. Each field is an `Option`, keyed by the
// `ExtensionType` written to the left of its `=>`.
extension_struct! {
    pub(crate) struct CertificateRequestExtensions {
        // `signature_algorithms`; the `Codec::read` impl rejects an empty list.
        ExtensionType::SignatureAlgorithms =>
            pub(crate) signature_algorithms: Option<Vec<SignatureScheme>>,

        // `certificate_authorities`.
        ExtensionType::CertificateAuthorities =>
            pub(crate) authority_names: Option<Vec<DistinguishedName>>,

        // `compress_certificate` (RFC 8879).
        ExtensionType::CompressCertificate =>
            pub(crate) certificate_compression_algorithms: Option<Vec<CertificateCompressionAlgorithm>>,
    }
}
2230
2231impl Codec<'_> for CertificateRequestExtensions {
2232    fn encode(&self, bytes: &mut Vec<u8>) {
2233        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2234
2235        for ext in Self::ALL_EXTENSIONS {
2236            self.encode_one(*ext, extensions.buf);
2237        }
2238    }
2239
2240    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2241        let mut out = Self::default();
2242
2243        let mut checker = DuplicateExtensionChecker::new();
2244
2245        let len = usize::from(u16::read(r)?);
2246        let mut sub = r.sub(len)?;
2247
2248        while sub.any_left() {
2249            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
2250        }
2251
2252        if out
2253            .signature_algorithms
2254            .as_ref()
2255            .map(|algs| algs.is_empty())
2256            .unwrap_or_default()
2257        {
2258            return Err(InvalidMessage::NoSignatureSchemes);
2259        }
2260
2261        Ok(out)
2262    }
2263}
2264
2265#[derive(Debug)]
2266pub(crate) struct CertificateRequestPayloadTls13 {
2267    pub(crate) context: PayloadU8,
2268    pub(crate) extensions: CertificateRequestExtensions,
2269}
2270
2271impl Codec<'_> for CertificateRequestPayloadTls13 {
2272    fn encode(&self, bytes: &mut Vec<u8>) {
2273        self.context.encode(bytes);
2274        self.extensions.encode(bytes);
2275    }
2276
2277    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2278        let context = PayloadU8::read(r)?;
2279        let extensions = CertificateRequestExtensions::read(r)?;
2280
2281        Ok(Self {
2282            context,
2283            extensions,
2284        })
2285    }
2286}
2287
2288// -- NewSessionTicket --
2289#[derive(Debug)]
2290pub(crate) struct NewSessionTicketPayload {
2291    pub(crate) lifetime_hint: u32,
2292    // Tickets can be large (KB), so we deserialise this straight
2293    // into an Arc, so it can be passed directly into the client's
2294    // session object without copying.
2295    pub(crate) ticket: Arc<PayloadU16>,
2296}
2297
2298impl NewSessionTicketPayload {
2299    #[cfg(feature = "tls12")]
2300    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2301        Self {
2302            lifetime_hint,
2303            ticket: Arc::new(PayloadU16::new(ticket)),
2304        }
2305    }
2306}
2307
2308impl Codec<'_> for NewSessionTicketPayload {
2309    fn encode(&self, bytes: &mut Vec<u8>) {
2310        self.lifetime_hint.encode(bytes);
2311        self.ticket.encode(bytes);
2312    }
2313
2314    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2315        let lifetime = u32::read(r)?;
2316        let ticket = Arc::new(PayloadU16::read(r)?);
2317
2318        Ok(Self {
2319            lifetime_hint: lifetime,
2320            ticket,
2321        })
2322    }
2323}
2324
// -- NewSessionTicket (TLS 1.3) --
//
// Extensions carried by a TLS 1.3 `NewSessionTicket` (`NewSessionTicketPayloadTls13`).
extension_struct! {
    pub(crate) struct NewSessionTicketExtensions {
        // The `early_data` extension, carrying a `u32` value.
        ExtensionType::EarlyData =>
            pub(crate) max_early_data_size: Option<u32>,
    }
}
2332
2333impl Codec<'_> for NewSessionTicketExtensions {
2334    fn encode(&self, bytes: &mut Vec<u8>) {
2335        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2336
2337        for ext in Self::ALL_EXTENSIONS {
2338            self.encode_one(*ext, extensions.buf);
2339        }
2340    }
2341
2342    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2343        let mut out = Self::default();
2344
2345        let mut checker = DuplicateExtensionChecker::new();
2346
2347        let len = usize::from(u16::read(r)?);
2348        let mut sub = r.sub(len)?;
2349
2350        while sub.any_left() {
2351            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
2352        }
2353
2354        Ok(out)
2355    }
2356}
2357
2358#[derive(Debug)]
2359pub(crate) struct NewSessionTicketPayloadTls13 {
2360    pub(crate) lifetime: u32,
2361    pub(crate) age_add: u32,
2362    pub(crate) nonce: PayloadU8,
2363    pub(crate) ticket: Arc<PayloadU16>,
2364    pub(crate) extensions: NewSessionTicketExtensions,
2365}
2366
2367impl NewSessionTicketPayloadTls13 {
2368    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2369        Self {
2370            lifetime,
2371            age_add,
2372            nonce: PayloadU8::new(nonce),
2373            ticket: Arc::new(PayloadU16::new(ticket)),
2374            extensions: NewSessionTicketExtensions::default(),
2375        }
2376    }
2377}
2378
2379impl Codec<'_> for NewSessionTicketPayloadTls13 {
2380    fn encode(&self, bytes: &mut Vec<u8>) {
2381        self.lifetime.encode(bytes);
2382        self.age_add.encode(bytes);
2383        self.nonce.encode(bytes);
2384        self.ticket.encode(bytes);
2385        self.extensions.encode(bytes);
2386    }
2387
2388    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2389        let lifetime = u32::read(r)?;
2390        let age_add = u32::read(r)?;
2391        let nonce = PayloadU8::read(r)?;
2392        // nb. RFC8446: `opaque ticket<1..2^16-1>;`
2393        let ticket = Arc::new(match PayloadU16::<NonEmpty>::read(r) {
2394            Err(InvalidMessage::IllegalEmptyValue) => Err(InvalidMessage::EmptyTicketValue),
2395            Err(err) => Err(err),
2396            Ok(pl) => Ok(PayloadU16::new(pl.0)),
2397        }?);
2398        let extensions = NewSessionTicketExtensions::read(r)?;
2399
2400        Ok(Self {
2401            lifetime,
2402            age_add,
2403            nonce,
2404            ticket,
2405            extensions,
2406        })
2407    }
2408}
2409
2410// -- RFC6066 certificate status types
2411
2412/// Only supports OCSP
2413#[derive(Clone, Debug)]
2414pub(crate) struct CertificateStatus<'a> {
2415    pub(crate) ocsp_response: PayloadU24<'a>,
2416}
2417
2418impl<'a> Codec<'a> for CertificateStatus<'a> {
2419    fn encode(&self, bytes: &mut Vec<u8>) {
2420        CertificateStatusType::OCSP.encode(bytes);
2421        self.ocsp_response.encode(bytes);
2422    }
2423
2424    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2425        let typ = CertificateStatusType::read(r)?;
2426
2427        match typ {
2428            CertificateStatusType::OCSP => Ok(Self {
2429                ocsp_response: PayloadU24::read(r)?,
2430            }),
2431            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2432        }
2433    }
2434}
2435
2436impl<'a> CertificateStatus<'a> {
2437    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2438        CertificateStatus {
2439            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2440        }
2441    }
2442
2443    #[cfg(feature = "tls12")]
2444    pub(crate) fn into_inner(self) -> Vec<u8> {
2445        self.ocsp_response.0.into_vec()
2446    }
2447
2448    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2449        CertificateStatus {
2450            ocsp_response: self.ocsp_response.into_owned(),
2451        }
2452    }
2453}
2454
2455// -- RFC8879 compressed certificates
2456
2457#[derive(Debug)]
2458pub(crate) struct CompressedCertificatePayload<'a> {
2459    pub(crate) alg: CertificateCompressionAlgorithm,
2460    pub(crate) uncompressed_len: u32,
2461    pub(crate) compressed: PayloadU24<'a>,
2462}
2463
2464impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2465    fn encode(&self, bytes: &mut Vec<u8>) {
2466        self.alg.encode(bytes);
2467        codec::u24(self.uncompressed_len).encode(bytes);
2468        self.compressed.encode(bytes);
2469    }
2470
2471    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2472        Ok(Self {
2473            alg: CertificateCompressionAlgorithm::read(r)?,
2474            uncompressed_len: codec::u24::read(r)?.0,
2475            compressed: PayloadU24::read(r)?,
2476        })
2477    }
2478}
2479
2480impl CompressedCertificatePayload<'_> {
2481    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2482        CompressedCertificatePayload {
2483            compressed: self.compressed.into_owned(),
2484            ..self
2485        }
2486    }
2487
2488    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2489        CompressedCertificatePayload {
2490            alg: self.alg,
2491            uncompressed_len: self.uncompressed_len,
2492            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2493        }
2494    }
2495}
2496
/// The body of a handshake message, with one variant per message kind.
///
/// Message types whose encoding differs between TLS 1.3 and earlier versions have two
/// variants (e.g. `Certificate` / `CertificateTls13`). Which one is produced depends on the
/// protocol version given to `HandshakeMessagePayload::read_version`.
#[derive(Debug)]
pub(crate) enum HandshakePayload<'a> {
    /// Has an empty body.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// Appears on the wire as a `ServerHello` whose `random` is the HelloRetryRequest magic value.
    HelloRetryRequest(HelloRetryRequest),
    /// Non-TLS 1.3 `Certificate` body.
    Certificate(CertificateChain<'a>),
    /// TLS 1.3 `Certificate` body.
    CertificateTls13(CertificatePayloadTls13<'a>),
    CompressedCertificate(CompressedCertificatePayload<'a>),
    /// When read, the body is held raw until the key exchange algorithm is known.
    ServerKeyExchange(ServerKeyExchangePayload),
    /// Non-TLS 1.3 `CertificateRequest` body.
    CertificateRequest(CertificateRequestPayload),
    /// TLS 1.3 `CertificateRequest` body.
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    /// Has an empty body.
    ServerHelloDone,
    /// Has an empty body.
    EndOfEarlyData,
    /// Kept as raw bytes when read.
    ClientKeyExchange(Payload<'a>),
    /// Non-TLS 1.3 `NewSessionTicket` body.
    NewSessionTicket(NewSessionTicketPayload),
    /// TLS 1.3 `NewSessionTicket` body.
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Box<ServerExtensions<'a>>),
    KeyUpdate(KeyUpdateRequest),
    /// Kept as raw bytes.
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    /// Synthetic message built locally; rejected if it arrives from the wire.
    MessageHash(Payload<'a>),
    /// A message this module does not decode, with its type and raw body.
    Unknown((HandshakeType, Payload<'a>)),
}
2522
2523impl HandshakePayload<'_> {
2524    fn encode(&self, bytes: &mut Vec<u8>) {
2525        use self::HandshakePayload::*;
2526        match self {
2527            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2528            ClientHello(x) => x.encode(bytes),
2529            ServerHello(x) => x.encode(bytes),
2530            HelloRetryRequest(x) => x.encode(bytes),
2531            Certificate(x) => x.encode(bytes),
2532            CertificateTls13(x) => x.encode(bytes),
2533            CompressedCertificate(x) => x.encode(bytes),
2534            ServerKeyExchange(x) => x.encode(bytes),
2535            ClientKeyExchange(x) => x.encode(bytes),
2536            CertificateRequest(x) => x.encode(bytes),
2537            CertificateRequestTls13(x) => x.encode(bytes),
2538            CertificateVerify(x) => x.encode(bytes),
2539            NewSessionTicket(x) => x.encode(bytes),
2540            NewSessionTicketTls13(x) => x.encode(bytes),
2541            EncryptedExtensions(x) => x.encode(bytes),
2542            KeyUpdate(x) => x.encode(bytes),
2543            Finished(x) => x.encode(bytes),
2544            CertificateStatus(x) => x.encode(bytes),
2545            MessageHash(x) => x.encode(bytes),
2546            Unknown((_, x)) => x.encode(bytes),
2547        }
2548    }
2549
2550    pub(crate) fn handshake_type(&self) -> HandshakeType {
2551        use self::HandshakePayload::*;
2552        match self {
2553            HelloRequest => HandshakeType::HelloRequest,
2554            ClientHello(_) => HandshakeType::ClientHello,
2555            ServerHello(_) => HandshakeType::ServerHello,
2556            HelloRetryRequest(_) => HandshakeType::HelloRetryRequest,
2557            Certificate(_) | CertificateTls13(_) => HandshakeType::Certificate,
2558            CompressedCertificate(_) => HandshakeType::CompressedCertificate,
2559            ServerKeyExchange(_) => HandshakeType::ServerKeyExchange,
2560            CertificateRequest(_) | CertificateRequestTls13(_) => HandshakeType::CertificateRequest,
2561            CertificateVerify(_) => HandshakeType::CertificateVerify,
2562            ServerHelloDone => HandshakeType::ServerHelloDone,
2563            EndOfEarlyData => HandshakeType::EndOfEarlyData,
2564            ClientKeyExchange(_) => HandshakeType::ClientKeyExchange,
2565            NewSessionTicket(_) | NewSessionTicketTls13(_) => HandshakeType::NewSessionTicket,
2566            EncryptedExtensions(_) => HandshakeType::EncryptedExtensions,
2567            KeyUpdate(_) => HandshakeType::KeyUpdate,
2568            Finished(_) => HandshakeType::Finished,
2569            CertificateStatus(_) => HandshakeType::CertificateStatus,
2570            MessageHash(_) => HandshakeType::MessageHash,
2571            Unknown((t, _)) => *t,
2572        }
2573    }
2574
2575    fn wire_handshake_type(&self) -> HandshakeType {
2576        match self.handshake_type() {
2577            // A `HelloRetryRequest` appears on the wire as a `ServerHello` with a magic `random` value.
2578            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2579            other => other,
2580        }
2581    }
2582
2583    fn into_owned(self) -> HandshakePayload<'static> {
2584        use HandshakePayload::*;
2585
2586        match self {
2587            HelloRequest => HelloRequest,
2588            ClientHello(x) => ClientHello(x),
2589            ServerHello(x) => ServerHello(x),
2590            HelloRetryRequest(x) => HelloRetryRequest(x),
2591            Certificate(x) => Certificate(x.into_owned()),
2592            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2593            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2594            ServerKeyExchange(x) => ServerKeyExchange(x),
2595            CertificateRequest(x) => CertificateRequest(x),
2596            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2597            CertificateVerify(x) => CertificateVerify(x),
2598            ServerHelloDone => ServerHelloDone,
2599            EndOfEarlyData => EndOfEarlyData,
2600            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2601            NewSessionTicket(x) => NewSessionTicket(x),
2602            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2603            EncryptedExtensions(x) => EncryptedExtensions(Box::new(x.into_owned())),
2604            KeyUpdate(x) => KeyUpdate(x),
2605            Finished(x) => Finished(x.into_owned()),
2606            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2607            MessageHash(x) => MessageHash(x.into_owned()),
2608            Unknown((t, x)) => Unknown((t, x.into_owned())),
2609        }
2610    }
2611}
2612
2613#[derive(Debug)]
2614pub struct HandshakeMessagePayload<'a>(pub(crate) HandshakePayload<'a>);
2615
2616impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2617    fn encode(&self, bytes: &mut Vec<u8>) {
2618        self.payload_encode(bytes, Encoding::Standard);
2619    }
2620
2621    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2622        Self::read_version(r, ProtocolVersion::TLSv1_2)
2623    }
2624}
2625
impl<'a> HandshakeMessagePayload<'a> {
    /// Decodes one handshake message from `r`: a `HandshakeType`, a 24-bit body length,
    /// then the body.
    ///
    /// `vers` selects the body format for `Certificate`, `CertificateRequest` and
    /// `NewSessionTicket`. The TLS 1.3 format is used only when
    /// `vers == ProtocolVersion::TLSv1_3`.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the header or body cannot be decoded;
    /// - bytes remain in the body after its payload is parsed;
    /// - the type is `MessageHash` or `HelloRetryRequest`, which are not accepted as wire types.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        let typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // All body parsing below uses `sub`, a sub-reader for the `len`-byte body.
        let mut sub = r.sub(len)?;

        // Arm order matters: each version-guarded arm must precede its unguarded fallback.
        let payload = match typ {
            // A HelloRequest with a non-empty body does not match here and falls through
            // to the catch-all `Unknown` arm at the end.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // HelloRetryRequest shares the ServerHello wire type and is identified by
                // its magic `random`. So version and random are read first, and the rest
                // of the body goes to whichever parser matches.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                // Kept raw: its layout depends on the key exchange algorithm, which is not
                // known here (presumably decoded later via `KxDecode`).
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Box::new(ServerExtensions::read(&mut sub)?))
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            // Anything else is kept raw together with its type.
            _ => HandshakePayload::Unknown((typ, Payload::read(&mut sub))),
        };

        // Reject body bytes that the payload parser left unread.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self(payload))
    }

    /// Returns this message's encoding with the trailing PSK binders removed.
    ///
    /// It truncates `total_binder_length()` bytes from the end of `get_encoding()`.
    // NOTE(review): this relies on the binders being the final bytes of the encoding,
    // presumably because the pre_shared_key extension comes last in a ClientHello
    // (RFC 8446 section 4.2.11) — confirm.
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();
        let ret_len = ret.len() - self.total_binder_length();
        ret.truncate(ret_len);
        ret
    }

    /// Returns the byte length of the encoded PSK binders list.
    ///
    /// This is non-zero only for a `ClientHello` that carries a pre-shared key offer;
    /// every other message gives 0.
    pub(crate) fn total_binder_length(&self) -> usize {
        match &self.0 {
            HandshakePayload::ClientHello(ch) => match &ch.preshared_key_offer {
                Some(offer) => {
                    // Measure by encoding the binders list into a scratch buffer.
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        }
    }

    /// Writes the wire type, a 24-bit length and the payload body to `bytes`.
    ///
    /// `encoding` only affects `ServerHello` and `HelloRetryRequest` payloads. A
    /// `HelloRetryRequest` is written with the `ServerHello` wire type.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
        // output type, length, and encoded payload
        self.0
            .wire_handshake_type()
            .encode(bytes);

        // NOTE(review): `nested` presumably back-fills the 24-bit length once it goes
        // out of scope, and `max: usize::MAX` appears to impose no extra size cap at this
        // layer — confirm in `LengthPrefixedBuffer`.
        let nested = LengthPrefixedBuffer::new(
            ListLength::U24 {
                max: usize::MAX,
                error: InvalidMessage::MessageTooLarge,
            },
            bytes,
        );

        match &self.0 {
            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
            // differently based on the purpose of the encoding.
            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
            HandshakePayload::HelloRetryRequest(payload) => {
                payload.payload_encode(nested.buf, encoding)
            }

            // All other payload types are encoded the same regardless of purpose.
            _ => self.0.encode(nested.buf),
        }
    }

    /// Wraps a copy of `hash` in a `MessageHash` payload.
    ///
    /// `read_version` refuses `MessageHash` from the wire, so this is built locally.
    /// Presumably it is the synthetic transcript message of RFC 8446 section 4.4.1.
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self(HandshakePayload::MessageHash(Payload::new(hash.to_vec())))
    }

    /// Converts into a `'static` message, taking ownership of any borrowed data.
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
        HandshakeMessagePayload(self.0.into_owned())
    }
}
2782
2783#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2784pub struct HpkeSymmetricCipherSuite {
2785    pub kdf_id: HpkeKdf,
2786    pub aead_id: HpkeAead,
2787}
2788
2789impl Codec<'_> for HpkeSymmetricCipherSuite {
2790    fn encode(&self, bytes: &mut Vec<u8>) {
2791        self.kdf_id.encode(bytes);
2792        self.aead_id.encode(bytes);
2793    }
2794
2795    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2796        Ok(Self {
2797            kdf_id: HpkeKdf::read(r)?,
2798            aead_id: HpkeAead::read(r)?,
2799        })
2800    }
2801}
2802
2803/// draft-ietf-tls-esni-24: `HpkeSymmetricCipherSuite cipher_suites<4..2^16-4>;`
2804impl TlsListElement for HpkeSymmetricCipherSuite {
2805    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
2806        empty_error: InvalidMessage::IllegalEmptyList("HpkeSymmetricCipherSuites"),
2807    };
2808}
2809
2810#[derive(Clone, Debug, PartialEq)]
2811pub struct HpkeKeyConfig {
2812    pub config_id: u8,
2813    pub kem_id: HpkeKem,
2814    /// draft-ietf-tls-esni-24: `opaque HpkePublicKey<1..2^16-1>;`
2815    pub public_key: PayloadU16<NonEmpty>,
2816    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2817}
2818
2819impl Codec<'_> for HpkeKeyConfig {
2820    fn encode(&self, bytes: &mut Vec<u8>) {
2821        self.config_id.encode(bytes);
2822        self.kem_id.encode(bytes);
2823        self.public_key.encode(bytes);
2824        self.symmetric_cipher_suites
2825            .encode(bytes);
2826    }
2827
2828    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2829        Ok(Self {
2830            config_id: u8::read(r)?,
2831            kem_id: HpkeKem::read(r)?,
2832            public_key: PayloadU16::read(r)?,
2833            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2834        })
2835    }
2836}
2837
2838#[derive(Clone, Debug, PartialEq)]
2839pub struct EchConfigContents {
2840    pub key_config: HpkeKeyConfig,
2841    pub maximum_name_length: u8,
2842    pub public_name: DnsName<'static>,
2843    pub extensions: Vec<EchConfigExtension>,
2844}
2845
2846impl EchConfigContents {
2847    /// Returns true if there is more than one extension of a given
2848    /// type.
2849    pub(crate) fn has_duplicate_extension(&self) -> bool {
2850        has_duplicates::<_, _, u16>(
2851            self.extensions
2852                .iter()
2853                .map(|ext| ext.ext_type()),
2854        )
2855    }
2856
2857    /// Returns true if there is at least one mandatory unsupported extension.
2858    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2859        self.extensions
2860            .iter()
2861            // An extension is considered mandatory if the high bit of its type is set.
2862            .any(|ext| {
2863                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2864                    && u16::from(ext.ext_type()) & 0x8000 != 0
2865            })
2866    }
2867}
2868
2869impl Codec<'_> for EchConfigContents {
2870    fn encode(&self, bytes: &mut Vec<u8>) {
2871        self.key_config.encode(bytes);
2872        self.maximum_name_length.encode(bytes);
2873        let dns_name = &self.public_name.borrow();
2874        PayloadU8::<MaybeEmpty>::encode_slice(dns_name.as_ref().as_ref(), bytes);
2875        self.extensions.encode(bytes);
2876    }
2877
2878    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2879        Ok(Self {
2880            key_config: HpkeKeyConfig::read(r)?,
2881            maximum_name_length: u8::read(r)?,
2882            public_name: {
2883                DnsName::try_from(
2884                    PayloadU8::<MaybeEmpty>::read(r)?
2885                        .0
2886                        .as_slice(),
2887                )
2888                .map_err(|_| InvalidMessage::InvalidServerName)?
2889                .to_owned()
2890            },
2891            extensions: Vec::read(r)?,
2892        })
2893    }
2894}
2895
2896/// An encrypted client hello (ECH) config.
2897#[derive(Clone, Debug, PartialEq)]
2898pub enum EchConfigPayload {
2899    /// A recognized V18 ECH configuration.
2900    V18(EchConfigContents),
2901    /// An unknown version ECH configuration.
2902    Unknown {
2903        version: EchVersion,
2904        contents: PayloadU16,
2905    },
2906}
2907
2908impl TlsListElement for EchConfigPayload {
2909    const SIZE_LEN: ListLength = ListLength::U16;
2910}
2911
2912impl Codec<'_> for EchConfigPayload {
2913    fn encode(&self, bytes: &mut Vec<u8>) {
2914        match self {
2915            Self::V18(c) => {
2916                // Write the version, the length, and the contents.
2917                EchVersion::V18.encode(bytes);
2918                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2919                c.encode(inner.buf);
2920            }
2921            Self::Unknown { version, contents } => {
2922                // Unknown configuration versions are opaque.
2923                version.encode(bytes);
2924                contents.encode(bytes);
2925            }
2926        }
2927    }
2928
2929    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2930        let version = EchVersion::read(r)?;
2931        let length = u16::read(r)?;
2932        let mut contents = r.sub(length as usize)?;
2933
2934        Ok(match version {
2935            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2936            _ => {
2937                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2938                let data = PayloadU16::new(contents.rest().into());
2939                Self::Unknown {
2940                    version,
2941                    contents: data,
2942                }
2943            }
2944        })
2945    }
2946}
2947
2948#[derive(Clone, Debug, PartialEq)]
2949pub enum EchConfigExtension {
2950    Unknown(UnknownExtension),
2951}
2952
2953impl EchConfigExtension {
2954    pub(crate) fn ext_type(&self) -> ExtensionType {
2955        match self {
2956            Self::Unknown(r) => r.typ,
2957        }
2958    }
2959}
2960
2961impl Codec<'_> for EchConfigExtension {
2962    fn encode(&self, bytes: &mut Vec<u8>) {
2963        self.ext_type().encode(bytes);
2964
2965        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2966        match self {
2967            Self::Unknown(r) => r.encode(nested.buf),
2968        }
2969    }
2970
2971    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2972        let typ = ExtensionType::read(r)?;
2973        let len = u16::read(r)? as usize;
2974        let mut sub = r.sub(len)?;
2975
2976        #[allow(clippy::match_single_binding)] // Future-proofing.
2977        let ext = match typ {
2978            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2979        };
2980
2981        sub.expect_empty("EchConfigExtension")
2982            .map(|_| ext)
2983    }
2984}
2985
2986impl TlsListElement for EchConfigExtension {
2987    const SIZE_LEN: ListLength = ListLength::U16;
2988}
2989
2990/// Representation of the `ECHClientHello` client extension specified in
2991/// [draft-ietf-tls-esni Section 5].
2992///
2993/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
2994#[derive(Clone, Debug)]
2995pub(crate) enum EncryptedClientHello {
2996    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
2997    Outer(EncryptedClientHelloOuter),
2998    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
2999    ///
3000    /// This variant has no payload.
3001    Inner,
3002}
3003
3004impl Codec<'_> for EncryptedClientHello {
3005    fn encode(&self, bytes: &mut Vec<u8>) {
3006        match self {
3007            Self::Outer(payload) => {
3008                EchClientHelloType::ClientHelloOuter.encode(bytes);
3009                payload.encode(bytes);
3010            }
3011            Self::Inner => {
3012                EchClientHelloType::ClientHelloInner.encode(bytes);
3013                // Empty payload.
3014            }
3015        }
3016    }
3017
3018    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3019        match EchClientHelloType::read(r)? {
3020            EchClientHelloType::ClientHelloOuter => {
3021                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3022            }
3023            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3024            _ => Err(InvalidMessage::InvalidContentType),
3025        }
3026    }
3027}
3028
3029/// Representation of the ECHClientHello extension with type outer specified in
3030/// [draft-ietf-tls-esni Section 5].
3031///
3032/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3033#[derive(Clone, Debug)]
3034pub(crate) struct EncryptedClientHelloOuter {
3035    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3036    /// ECHConfigContents.cipher_suites list.
3037    pub cipher_suite: HpkeSymmetricCipherSuite,
3038    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3039    pub config_id: u8,
3040    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3041    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3042    pub enc: PayloadU16,
3043    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3044    pub payload: PayloadU16<NonEmpty>,
3045}
3046
3047impl Codec<'_> for EncryptedClientHelloOuter {
3048    fn encode(&self, bytes: &mut Vec<u8>) {
3049        self.cipher_suite.encode(bytes);
3050        self.config_id.encode(bytes);
3051        self.enc.encode(bytes);
3052        self.payload.encode(bytes);
3053    }
3054
3055    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3056        Ok(Self {
3057            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3058            config_id: u8::read(r)?,
3059            enc: PayloadU16::read(r)?,
3060            payload: PayloadU16::read(r)?,
3061        })
3062    }
3063}
3064
3065/// Representation of the ECHEncryptedExtensions extension specified in
3066/// [draft-ietf-tls-esni Section 5].
3067///
3068/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3069#[derive(Clone, Debug)]
3070pub(crate) struct ServerEncryptedClientHello {
3071    pub(crate) retry_configs: Vec<EchConfigPayload>,
3072}
3073
3074impl Codec<'_> for ServerEncryptedClientHello {
3075    fn encode(&self, bytes: &mut Vec<u8>) {
3076        self.retry_configs.encode(bytes);
3077    }
3078
3079    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3080        Ok(Self {
3081            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3082        })
3083    }
3084}
3085
3086/// The method of encoding to use for a handshake message.
3087///
3088/// In some cases a handshake message may be encoded differently depending on the purpose
3089/// the encoded message is being used for.
3090pub(crate) enum Encoding {
3091    /// Standard RFC 8446 encoding.
3092    Standard,
3093    /// Encoding for ECH confirmation for HRR.
3094    EchConfirmation,
3095    /// Encoding for ECH inner client hello.
3096    EchInnerHello { to_compress: Vec<ExtensionType> },
3097}
3098
/// Returns `true` as soon as some item, after conversion into `T`, has already been seen.
///
/// Items are tracked in an ordered set, and iteration stops at the first repeat.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::<T>::new();
    // `insert` returns false when the value was already present.
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3110
3111struct DuplicateExtensionChecker(BTreeSet<u16>);
3112
3113impl DuplicateExtensionChecker {
3114    fn new() -> Self {
3115        Self(BTreeSet::new())
3116    }
3117
3118    fn check(&mut self, typ: ExtensionType) -> Result<(), InvalidMessage> {
3119        let u = u16::from(typ);
3120        match self.0.insert(u) {
3121            true => Ok(()),
3122            false => Err(InvalidMessage::DuplicateExtension(u)),
3123        }
3124    }
3125}
3126
/// A fast, non-cryptographic 32-bit integer mixing function.
///
/// The add/xor/shift constants match Robert Jenkins' 32-bit integer hash. All additions wrap,
/// so the function never overflows.
fn low_quality_integer_hash(x: u32) -> u32 {
    let x = x.wrapping_add(0x7ed55d16).wrapping_add(x << 12);
    let x = (x ^ 0xc761c23c) ^ (x >> 19);
    let x = x.wrapping_add(0x165667b1).wrapping_add(x << 5);
    let x = x.wrapping_add(0xd3a2646c) ^ (x << 9);
    let x = x.wrapping_add(0xfd7046c5).wrapping_add(x << 3);
    (x ^ 0xb55a4f09) ^ (x >> 16)
}
3142
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ech_config_dupe_exts() {
        // The same optional (high bit clear) unknown extension, added twice.
        let unknown_ext = EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(0x42),
            payload: Payload::new(vec![0x42]),
        });
        let mut config = config_template();
        config
            .extensions
            .extend([unknown_ext.clone(), unknown_ext]);

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    #[test]
    fn test_ech_config_mandatory_exts() {
        // 0x42 with the high bit set marks the unknown extension as mandatory.
        let mandatory_type = ExtensionType::Unknown(0x8000 | 0x42);
        let mut config = config_template();
        config
            .extensions
            .push(EchConfigExtension::Unknown(UnknownExtension {
                typ: mandatory_type,
                payload: Payload::new(vec![0x42]),
            }));

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// A baseline config with a single cipher suite and no extensions.
    fn config_template() -> EchConfigContents {
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16::new(b"xxx".into()),
            symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
                kdf_id: HpkeKdf::HKDF_SHA256,
                aead_id: HpkeAead::AES_128_GCM,
            }],
        };

        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: Vec::new(),
        }
    }
}