rustls/msgs/
handshake.rs

1use alloc::boxed::Box;
2use alloc::collections::BTreeSet;
3#[cfg(feature = "log")]
4use alloc::string::String;
5use alloc::vec;
6use alloc::vec::Vec;
7use core::ops::{Deref, DerefMut};
8use core::{fmt, iter};
9
10use pki_types::{CertificateDer, DnsName};
11
12use crate::crypto::{ActiveKeyExchange, SecureRandom};
13use crate::enums::{
14    CertificateCompressionAlgorithm, CertificateType, CipherSuite, EchClientHelloType,
15    HandshakeType, ProtocolVersion, SignatureScheme,
16};
17use crate::error::InvalidMessage;
18use crate::ffdhe_groups::FfdheGroup;
19use crate::log::warn;
20use crate::msgs::base::{MaybeEmpty, NonEmpty, Payload, PayloadU8, PayloadU16, PayloadU24};
21use crate::msgs::codec::{
22    self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement, TlsListIter,
23};
24use crate::msgs::enums::{
25    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
26    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
27    PskKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::sync::Arc;
31use crate::verify::DigitallySignedStruct;
32use crate::x509::wrap_in_sequence;
33
/// Declares a newtype around one of the length-prefixed payload types.
///
/// Several TLS message fields are nothing more than an opaque, length-prefixed
/// byte string (a `PayloadU8` or `PayloadU16`). This macro gives each such field
/// its own named type and generates the boilerplate it needs: construction from a
/// `Vec<u8>`, byte access through `AsRef<[u8]>`, and a `Codec` implementation that
/// delegates to the wrapped payload type.
macro_rules! wrapped_payload(
  ($(#[$comment:meta])* $vis:vis struct $name:ident, $inner:ident$(<$inner_ty:ty>)?,) => {
    $(#[$comment])*
    #[derive(Clone, Debug)]
    $vis struct $name($inner$(<$inner_ty>)?);

    impl From<Vec<u8>> for $name {
        fn from(bytes: Vec<u8>) -> Self {
            let wrapped = $inner::new(bytes);
            $name(wrapped)
        }
    }

    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            let payload = &self.0;
            payload.0.as_slice()
        }
    }

    impl Codec<'_> for $name {
        fn encode(&self, bytes: &mut Vec<u8>) {
            // The wrapped payload writes its own length prefix.
            self.0.encode(bytes);
        }

        fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
            $inner::read(r).map($name)
        }
    }
  }
);
68
69#[derive(Clone, Copy, Eq, PartialEq)]
70pub(crate) struct Random(pub(crate) [u8; 32]);
71
72impl fmt::Debug for Random {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        super::base::hex(f, &self.0)
75    }
76}
77
/// The fixed `Random` value that marks a `ServerHello` as a HelloRetryRequest
/// (RFC8446 section 4.1.3).
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);

/// A `Random` whose 32 bytes are all zero.
static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85impl Codec<'_> for Random {
86    fn encode(&self, bytes: &mut Vec<u8>) {
87        bytes.extend_from_slice(&self.0);
88    }
89
90    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91        let Some(bytes) = r.take(32) else {
92            return Err(InvalidMessage::MissingData("Random"));
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
101impl Random {
102    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103        let mut data = [0u8; 32];
104        secure_random.fill(&mut data)?;
105        Ok(Self(data))
106    }
107}
108
109impl From<[u8; 32]> for Random {
110    #[inline]
111    fn from(bytes: [u8; 32]) -> Self {
112        Self(bytes)
113    }
114}
115
116#[derive(Copy, Clone)]
117pub(crate) struct SessionId {
118    len: usize,
119    data: [u8; 32],
120}
121
122impl fmt::Debug for SessionId {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        super::base::hex(f, &self.data[..self.len])
125    }
126}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(self.as_ref());
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let Some(bytes) = r.take(len) else {
157            return Err(InvalidMessage::MissingData("SessionID"));
158        };
159
160        let mut out = [0u8; 32];
161        out[..len].clone_from_slice(&bytes[..len]);
162        Ok(Self { data: out, len })
163    }
164}
165
166impl SessionId {
167    pub(crate) fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168        let mut data = [0u8; 32];
169        secure_random.fill(&mut data)?;
170        Ok(Self { data, len: 32 })
171    }
172
173    pub(crate) fn empty() -> Self {
174        Self {
175            data: [0u8; 32],
176            len: 0,
177        }
178    }
179
180    pub(crate) fn is_empty(&self) -> bool {
181        self.len == 0
182    }
183}
184
185impl AsRef<[u8]> for SessionId {
186    fn as_ref(&self) -> &[u8] {
187        &self.data[..self.len]
188    }
189}
190
191#[derive(Clone, Debug, PartialEq)]
192pub(crate) struct UnknownExtension {
193    pub(crate) typ: ExtensionType,
194    pub(crate) payload: Payload<'static>,
195}
196
197impl UnknownExtension {
198    fn encode(&self, bytes: &mut Vec<u8>) {
199        self.payload.encode(bytes);
200    }
201
202    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
203        let payload = Payload::read(r).into_owned();
204        Self { typ, payload }
205    }
206}
207
208#[derive(Clone, Copy, Debug)]
209pub(crate) struct SupportedEcPointFormats {
210    pub(crate) uncompressed: bool,
211}
212
213impl Codec<'_> for SupportedEcPointFormats {
214    fn encode(&self, bytes: &mut Vec<u8>) {
215        let inner = LengthPrefixedBuffer::new(ECPointFormat::SIZE_LEN, bytes);
216
217        if self.uncompressed {
218            ECPointFormat::Uncompressed.encode(inner.buf);
219        }
220    }
221
222    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
223        let mut uncompressed = false;
224
225        for pf in TlsListIter::<ECPointFormat>::new(r)? {
226            if let ECPointFormat::Uncompressed = pf? {
227                uncompressed = true;
228            }
229        }
230
231        Ok(Self { uncompressed })
232    }
233}
234
235impl Default for SupportedEcPointFormats {
236    fn default() -> Self {
237        Self { uncompressed: true }
238    }
239}
240
/// RFC8422: `ECPointFormat ec_point_format_list<1..2^8-1>`
///
/// Uses `ListLength::NonZeroU8`. The `empty_error` is the error reported for an
/// empty list.
impl TlsListElement for ECPointFormat {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ECPointFormats"),
    };
}

/// RFC8422: `NamedCurve named_curve_list<2..2^16-1>`
///
/// Uses `ListLength::NonZeroU16`. An empty list reports
/// `IllegalEmptyList("NamedGroups")`.
impl TlsListElement for NamedGroup {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("NamedGroups"),
    };
}

/// RFC8446: `SignatureScheme supported_signature_algorithms<2..2^16-2>;`
///
/// An empty list reports the dedicated `InvalidMessage::NoSignatureSchemes`
/// rather than the generic `IllegalEmptyList`.
impl TlsListElement for SignatureScheme {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::NoSignatureSchemes,
    };
}
261
262#[derive(Clone, Debug)]
263pub(crate) enum ServerNamePayload<'a> {
264    /// A successfully decoded value:
265    SingleDnsName(DnsName<'a>),
266
267    /// A DNS name which was actually an IP address
268    IpAddress,
269
270    /// A successfully decoded, but syntactically-invalid value.
271    Invalid,
272}
273
274impl ServerNamePayload<'_> {
275    fn into_owned(self) -> ServerNamePayload<'static> {
276        match self {
277            Self::SingleDnsName(d) => ServerNamePayload::SingleDnsName(d.to_owned()),
278            Self::IpAddress => ServerNamePayload::IpAddress,
279            Self::Invalid => ServerNamePayload::Invalid,
280        }
281    }
282
283    /// RFC6066: `ServerName server_name_list<1..2^16-1>`
284    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
285        empty_error: InvalidMessage::IllegalEmptyList("ServerNames"),
286    };
287}
288
/// Simplified encoding/decoding for a `ServerName` extension payload to/from `DnsName`
///
/// This is possible because:
///
/// - the spec (RFC6066) disallows multiple names for a given name type
/// - name types other than ServerNameType::HostName are not defined, and they and
///   any data that follows them cannot be skipped over.
impl<'a> Codec<'a> for ServerNamePayload<'a> {
    /// Encodes a `server_name_list` holding at most one `host_name` entry.
    ///
    /// Only `SingleDnsName` writes an entry. For `IpAddress` and `Invalid`, the
    /// function returns right after creating the length-prefixed list, so
    /// presumably an entry-less list is emitted. NOTE(review): confirm that is
    /// intended, given `SIZE_LEN` is a `NonZeroU16` length.
    fn encode(&self, bytes: &mut Vec<u8>) {
        // Presumably the list length is written when `server_name_list` is
        // dropped at the end of this function, including on the early return.
        let server_name_list = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);

        let ServerNamePayload::SingleDnsName(dns_name) = self else {
            return;
        };

        // One entry: the name type, then the u16-length-prefixed name bytes.
        ServerNameType::HostName.encode(server_name_list.buf);
        let name_slice = dns_name.as_ref().as_bytes();
        // NOTE(review): `as u16` truncates silently; this assumes a validated
        // `DnsName` is always far shorter than 65536 bytes.
        (name_slice.len() as u16).encode(server_name_list.buf);
        server_name_list
            .buf
            .extend_from_slice(name_slice);
    }

    /// Decodes a `server_name_list`, keeping its single `host_name` entry.
    ///
    /// Returns `Invalid` when no `host_name` entry was found. An entry with an
    /// unrecognised name type stops decoding and discards the rest of the list.
    ///
    /// # Errors
    /// - `InvalidMessage::InvalidServerName` if a second `host_name` entry is present.
    /// - Any error from the list length, the sub-reader, the name type, or
    ///   `HostNamePayload::read`.
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
        let mut found = None;

        let len = Self::SIZE_LEN.read(r)?;
        let mut sub = r.sub(len)?;

        while sub.any_left() {
            let typ = ServerNameType::read(&mut sub)?;

            let payload = match typ {
                ServerNameType::HostName => HostNamePayload::read(&mut sub)?,
                _ => {
                    // Consume remainder of extension bytes.  Since the length of the item
                    // is an unknown encoding, we cannot continue.
                    sub.rest();
                    break;
                }
            };

            // "The ServerNameList MUST NOT contain more than one name of
            // the same name_type." - RFC6066
            if found.is_some() {
                warn!("Illegal SNI extension: duplicate host_name received");
                return Err(InvalidMessage::InvalidServerName);
            }

            // The `_invalid` bindings are only read by `warn!`; the leading
            // underscore presumably avoids unused warnings when logging is disabled.
            found = match payload {
                // NOTE(review): `HostNamePayload::HostName` already holds a
                // `DnsName<'static>`, so this `to_owned()` may be redundant.
                HostNamePayload::HostName(dns_name) => {
                    Some(Self::SingleDnsName(dns_name.to_owned()))
                }

                HostNamePayload::IpAddress(_invalid) => {
                    warn!(
                        "Illegal SNI extension: ignoring IP address presented as hostname ({_invalid:?})"
                    );
                    Some(Self::IpAddress)
                }

                HostNamePayload::Invalid(_invalid) => {
                    warn!(
                        "Illegal SNI hostname received {:?}",
                        String::from_utf8_lossy(&_invalid.0)
                    );
                    Some(Self::Invalid)
                }
            };
        }

        Ok(found.unwrap_or(Self::Invalid))
    }
}
363
364impl<'a> From<&DnsName<'a>> for ServerNamePayload<'static> {
365    fn from(value: &DnsName<'a>) -> Self {
366        Self::SingleDnsName(trim_hostname_trailing_dot_for_sni(value))
367    }
368}
369
370#[derive(Clone, Debug)]
371pub(crate) enum HostNamePayload {
372    HostName(DnsName<'static>),
373    IpAddress(PayloadU16<NonEmpty>),
374    Invalid(PayloadU16<NonEmpty>),
375}
376
377impl HostNamePayload {
378    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
379        use pki_types::ServerName;
380        let raw = PayloadU16::<NonEmpty>::read(r)?;
381
382        match ServerName::try_from(raw.0.as_slice()) {
383            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
384            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
385            Ok(_) | Err(_) => Ok(Self::Invalid(raw)),
386        }
387    }
388}
389
390wrapped_payload!(
391    /// RFC7301: `opaque ProtocolName<1..2^8-1>;`
392    pub(crate) struct ProtocolName, PayloadU8<NonEmpty>,
393);
394
395impl PartialEq for ProtocolName {
396    fn eq(&self, other: &Self) -> bool {
397        self.0 == other.0
398    }
399}
400
401impl Deref for ProtocolName {
402    type Target = [u8];
403
404    fn deref(&self) -> &Self::Target {
405        self.as_ref()
406    }
407}
408
409/// RFC7301: `ProtocolName protocol_name_list<2..2^16-1>`
410impl TlsListElement for ProtocolName {
411    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
412        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
413    };
414}
415
416/// RFC7301 encodes a single protocol name as `Vec<ProtocolName>`
417#[derive(Clone, Debug)]
418pub(crate) struct SingleProtocolName(ProtocolName);
419
420impl SingleProtocolName {
421    pub(crate) fn new(single: ProtocolName) -> Self {
422        Self(single)
423    }
424
425    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
426        empty_error: InvalidMessage::IllegalEmptyList("ProtocolNames"),
427    };
428}
429
430impl Codec<'_> for SingleProtocolName {
431    fn encode(&self, bytes: &mut Vec<u8>) {
432        let body = LengthPrefixedBuffer::new(Self::SIZE_LEN, bytes);
433        self.0.encode(body.buf);
434    }
435
436    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
437        let len = Self::SIZE_LEN.read(reader)?;
438        let mut sub = reader.sub(len)?;
439
440        let item = ProtocolName::read(&mut sub)?;
441
442        if sub.any_left() {
443            Err(InvalidMessage::TrailingData("SingleProtocolName"))
444        } else {
445            Ok(Self(item))
446        }
447    }
448}
449
450impl AsRef<ProtocolName> for SingleProtocolName {
451    fn as_ref(&self) -> &ProtocolName {
452        &self.0
453    }
454}
455
456// --- TLS 1.3 Key shares ---
457#[derive(Clone, Debug)]
458pub(crate) struct KeyShareEntry {
459    pub(crate) group: NamedGroup,
460    /// RFC8446: `opaque key_exchange<1..2^16-1>;`
461    pub(crate) payload: PayloadU16<NonEmpty>,
462}
463
464impl KeyShareEntry {
465    pub(crate) fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
466        Self {
467            group,
468            payload: PayloadU16::new(payload.into()),
469        }
470    }
471}
472
473impl Codec<'_> for KeyShareEntry {
474    fn encode(&self, bytes: &mut Vec<u8>) {
475        self.group.encode(bytes);
476        self.payload.encode(bytes);
477    }
478
479    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
480        let group = NamedGroup::read(r)?;
481        let payload = PayloadU16::read(r)?;
482
483        Ok(Self { group, payload })
484    }
485}
486
487// --- TLS 1.3 PresharedKey offers ---
488#[derive(Clone, Debug)]
489pub(crate) struct PresharedKeyIdentity {
490    /// RFC8446: `opaque identity<1..2^16-1>;`
491    pub(crate) identity: PayloadU16<NonEmpty>,
492    pub(crate) obfuscated_ticket_age: u32,
493}
494
495impl PresharedKeyIdentity {
496    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
497        Self {
498            identity: PayloadU16::new(id),
499            obfuscated_ticket_age: age,
500        }
501    }
502}
503
504impl Codec<'_> for PresharedKeyIdentity {
505    fn encode(&self, bytes: &mut Vec<u8>) {
506        self.identity.encode(bytes);
507        self.obfuscated_ticket_age.encode(bytes);
508    }
509
510    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
511        Ok(Self {
512            identity: PayloadU16::read(r)?,
513            obfuscated_ticket_age: u32::read(r)?,
514        })
515    }
516}
517
518/// RFC8446: `PskIdentity identities<7..2^16-1>;`
519impl TlsListElement for PresharedKeyIdentity {
520    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
521        empty_error: InvalidMessage::IllegalEmptyList("PskIdentities"),
522    };
523}
524
525wrapped_payload!(
526    /// RFC8446: `opaque PskBinderEntry<32..255>;`
527    pub(crate) struct PresharedKeyBinder, PayloadU8<NonEmpty>,
528);
529
530/// RFC8446: `PskBinderEntry binders<33..2^16-1>;`
531impl TlsListElement for PresharedKeyBinder {
532    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
533        empty_error: InvalidMessage::IllegalEmptyList("PskBinders"),
534    };
535}
536
537#[derive(Clone, Debug)]
538pub(crate) struct PresharedKeyOffer {
539    pub(crate) identities: Vec<PresharedKeyIdentity>,
540    pub(crate) binders: Vec<PresharedKeyBinder>,
541}
542
543impl PresharedKeyOffer {
544    /// Make a new one with one entry.
545    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
546        Self {
547            identities: vec![id],
548            binders: vec![PresharedKeyBinder::from(binder)],
549        }
550    }
551}
552
553impl Codec<'_> for PresharedKeyOffer {
554    fn encode(&self, bytes: &mut Vec<u8>) {
555        self.identities.encode(bytes);
556        self.binders.encode(bytes);
557    }
558
559    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
560        Ok(Self {
561            identities: Vec::read(r)?,
562            binders: Vec::read(r)?,
563        })
564    }
565}
566
567// --- RFC6066 certificate status request ---
568wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
569
570/// RFC6066: `ResponderID responder_id_list<0..2^16-1>;`
571impl TlsListElement for ResponderId {
572    const SIZE_LEN: ListLength = ListLength::U16;
573}
574
575#[derive(Clone, Debug)]
576pub(crate) struct OcspCertificateStatusRequest {
577    pub(crate) responder_ids: Vec<ResponderId>,
578    pub(crate) extensions: PayloadU16,
579}
580
581impl Codec<'_> for OcspCertificateStatusRequest {
582    fn encode(&self, bytes: &mut Vec<u8>) {
583        CertificateStatusType::OCSP.encode(bytes);
584        self.responder_ids.encode(bytes);
585        self.extensions.encode(bytes);
586    }
587
588    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
589        Ok(Self {
590            responder_ids: Vec::read(r)?,
591            extensions: PayloadU16::read(r)?,
592        })
593    }
594}
595
596#[derive(Clone, Debug)]
597pub(crate) enum CertificateStatusRequest {
598    Ocsp(OcspCertificateStatusRequest),
599    Unknown((CertificateStatusType, Payload<'static>)),
600}
601
602impl Codec<'_> for CertificateStatusRequest {
603    fn encode(&self, bytes: &mut Vec<u8>) {
604        match self {
605            Self::Ocsp(r) => r.encode(bytes),
606            Self::Unknown((typ, payload)) => {
607                typ.encode(bytes);
608                payload.encode(bytes);
609            }
610        }
611    }
612
613    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
614        let typ = CertificateStatusType::read(r)?;
615
616        match typ {
617            CertificateStatusType::OCSP => {
618                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
619                Ok(Self::Ocsp(ocsp_req))
620            }
621            _ => {
622                let data = Payload::read(r).into_owned();
623                Ok(Self::Unknown((typ, data)))
624            }
625        }
626    }
627}
628
629impl CertificateStatusRequest {
630    pub(crate) fn build_ocsp() -> Self {
631        let ocsp = OcspCertificateStatusRequest {
632            responder_ids: Vec::new(),
633            extensions: PayloadU16::empty(),
634        };
635        Self::Ocsp(ocsp)
636    }
637}
638
639// ---
640
641/// RFC8446: `PskKeyExchangeMode ke_modes<1..255>;`
642#[derive(Clone, Copy, Debug, Default)]
643pub(crate) struct PskKeyExchangeModes {
644    pub(crate) psk_dhe: bool,
645    pub(crate) psk: bool,
646}
647
648impl Codec<'_> for PskKeyExchangeModes {
649    fn encode(&self, bytes: &mut Vec<u8>) {
650        let inner = LengthPrefixedBuffer::new(PskKeyExchangeMode::SIZE_LEN, bytes);
651        if self.psk_dhe {
652            PskKeyExchangeMode::PSK_DHE_KE.encode(inner.buf);
653        }
654        if self.psk {
655            PskKeyExchangeMode::PSK_KE.encode(inner.buf);
656        }
657    }
658
659    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
660        let mut psk_dhe = false;
661        let mut psk = false;
662
663        for ke in TlsListIter::<PskKeyExchangeMode>::new(reader)? {
664            match ke? {
665                PskKeyExchangeMode::PSK_DHE_KE => psk_dhe = true,
666                PskKeyExchangeMode::PSK_KE => psk = true,
667                _ => continue,
668            };
669        }
670
671        Ok(Self { psk_dhe, psk })
672    }
673}
674
675impl TlsListElement for PskKeyExchangeMode {
676    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
677        empty_error: InvalidMessage::IllegalEmptyList("PskKeyExchangeModes"),
678    };
679}
680
/// RFC8446: `KeyShareEntry client_shares<0..2^16-1>;`
///
/// The lower bound is zero, so an empty list is allowed. `ListLength::U16`
/// carries no `empty_error`, unlike the `NonZero*` lengths.
impl TlsListElement for KeyShareEntry {
    const SIZE_LEN: ListLength = ListLength::U16;
}
685
686/// The body of the `SupportedVersions` extension when it appears in a
687/// `ClientHello`
688///
689/// This is documented as a preference-order vector, but we (as a server)
690/// ignore the preference of the client.
691///
692/// RFC8446: `ProtocolVersion versions<2..254>;`
693#[derive(Clone, Copy, Debug, Default)]
694pub(crate) struct SupportedProtocolVersions {
695    pub(crate) tls13: bool,
696    pub(crate) tls12: bool,
697}
698
699impl SupportedProtocolVersions {
700    /// Return true if `filter` returns true for any enabled version.
701    pub(crate) fn any(&self, filter: impl Fn(ProtocolVersion) -> bool) -> bool {
702        if self.tls13 && filter(ProtocolVersion::TLSv1_3) {
703            return true;
704        }
705        if self.tls12 && filter(ProtocolVersion::TLSv1_2) {
706            return true;
707        }
708        false
709    }
710
711    const LIST_LENGTH: ListLength = ListLength::NonZeroU8 {
712        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
713    };
714}
715
716impl Codec<'_> for SupportedProtocolVersions {
717    fn encode(&self, bytes: &mut Vec<u8>) {
718        let inner = LengthPrefixedBuffer::new(Self::LIST_LENGTH, bytes);
719        if self.tls13 {
720            ProtocolVersion::TLSv1_3.encode(inner.buf);
721        }
722        if self.tls12 {
723            ProtocolVersion::TLSv1_2.encode(inner.buf);
724        }
725    }
726
727    fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
728        let mut tls12 = false;
729        let mut tls13 = false;
730
731        for pv in TlsListIter::<ProtocolVersion>::new(reader)? {
732            match pv? {
733                ProtocolVersion::TLSv1_3 => tls13 = true,
734                ProtocolVersion::TLSv1_2 => tls12 = true,
735                _ => continue,
736            };
737        }
738
739        Ok(Self { tls13, tls12 })
740    }
741}
742
743impl TlsListElement for ProtocolVersion {
744    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
745        empty_error: InvalidMessage::IllegalEmptyList("ProtocolVersions"),
746    };
747}
748
/// RFC7250: `CertificateType client_certificate_types<1..2^8-1>;`
///
/// Ditto `CertificateType server_certificate_types<1..2^8-1>;`
///
/// `empty_error` is the error reported for an empty list.
impl TlsListElement for CertificateType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("CertificateTypes"),
    };
}

/// RFC8879: `CertificateCompressionAlgorithm algorithms<2..2^8-2>;`
///
/// `empty_error` is the error reported for an empty list.
impl TlsListElement for CertificateCompressionAlgorithm {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("CertificateCompressionAlgorithms"),
    };
}
764
765/// A precursor to `ClientExtensions`, allowing customisation.
766///
767/// This is smaller than `ClientExtensions`, as it only contains the extensions
768/// we need to vary between different protocols (eg, TCP-TLS versus QUIC).
769#[derive(Clone, Default)]
770pub(crate) struct ClientExtensionsInput<'a> {
771    /// QUIC transport parameters
772    pub(crate) transport_parameters: Option<TransportParameters<'a>>,
773
774    /// ALPN protocols
775    pub(crate) protocols: Option<Vec<ProtocolName>>,
776}
777
778impl ClientExtensionsInput<'_> {
779    pub(crate) fn from_alpn(alpn_protocols: Vec<Vec<u8>>) -> ClientExtensionsInput<'static> {
780        let protocols = match alpn_protocols.is_empty() {
781            true => None,
782            false => Some(
783                alpn_protocols
784                    .into_iter()
785                    .map(ProtocolName::from)
786                    .collect::<Vec<_>>(),
787            ),
788        };
789
790        ClientExtensionsInput {
791            transport_parameters: None,
792            protocols,
793        }
794    }
795
796    pub(crate) fn into_owned(self) -> ClientExtensionsInput<'static> {
797        let Self {
798            transport_parameters,
799            protocols,
800        } = self;
801        ClientExtensionsInput {
802            transport_parameters: transport_parameters.map(|x| x.into_owned()),
803            protocols,
804        }
805    }
806}
807
808#[derive(Clone)]
809pub(crate) enum TransportParameters<'a> {
810    /// QUIC transport parameters (RFC9001)
811    Quic(Payload<'a>),
812}
813
814impl TransportParameters<'_> {
815    pub(crate) fn into_owned(self) -> TransportParameters<'static> {
816        match self {
817            Self::Quic(v) => TransportParameters::Quic(v.into_owned()),
818        }
819    }
820}
821
// `extension_struct!` is defined elsewhere. Each `ExtensionType::X => field`
// pair appears to associate an extension type with the field that holds its
// decoded body. NOTE(review): confirm against the macro definition.
extension_struct! {
    /// A representation of extensions present in a `ClientHello` message
    ///
    /// All extensions are optional (by definition) so are represented with `Option<T>`.
    ///
    /// Some extensions have an empty value and are represented with Option<()>.
    ///
    /// Unknown extensions are dropped during parsing.
    pub(crate) struct ClientExtensions<'a> {
        /// Requested server name indication (RFC6066)
        ExtensionType::ServerName =>
            pub(crate) server_name: Option<ServerNamePayload<'a>>,

        /// Certificate status is requested (RFC6066)
        ExtensionType::StatusRequest =>
            pub(crate) certificate_status_request: Option<CertificateStatusRequest>,

        /// Supported groups (RFC4492/RFC8446)
        ExtensionType::EllipticCurves =>
            pub(crate) named_groups: Option<Vec<NamedGroup>>,

        /// Supported EC point formats (RFC4492)
        ExtensionType::ECPointFormats =>
            pub(crate) ec_point_formats: Option<SupportedEcPointFormats>,

        /// Supported signature schemes (RFC5246/RFC8446)
        ExtensionType::SignatureAlgorithms =>
            pub(crate) signature_schemes: Option<Vec<SignatureScheme>>,

        /// Offered ALPN protocols (RFC7301)
        ExtensionType::ALProtocolNegotiation =>
            pub(crate) protocols: Option<Vec<ProtocolName>>,

        /// Available client certificate types (RFC7250)
        ExtensionType::ClientCertificateType =>
            pub(crate) client_certificate_types: Option<Vec<CertificateType>>,

        /// Acceptable server certificate types (RFC7250)
        ExtensionType::ServerCertificateType =>
            pub(crate) server_certificate_types: Option<Vec<CertificateType>>,

        /// Extended master secret is requested (RFC7627)
        ExtensionType::ExtendedMasterSecret =>
            pub(crate) extended_master_secret_request: Option<()>,

        /// Offered certificate compression methods (RFC8879)
        ExtensionType::CompressCertificate =>
            pub(crate) certificate_compression_algorithms: Option<Vec<CertificateCompressionAlgorithm>>,

        /// Session ticket offer or request (RFC5077/RFC8446)
        ExtensionType::SessionTicket =>
            pub(crate) session_ticket: Option<ClientSessionTicket>,

        /// Offered preshared keys (RFC8446)
        ExtensionType::PreSharedKey =>
            pub(crate) preshared_key_offer: Option<PresharedKeyOffer>,

        /// Early data is requested (RFC8446)
        ExtensionType::EarlyData =>
            pub(crate) early_data_request: Option<()>,

        /// Supported TLS versions (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) supported_versions: Option<SupportedProtocolVersions>,

        /// Stateless HelloRetryRequest cookie (RFC8446)
        ExtensionType::Cookie =>
            pub(crate) cookie: Option<PayloadU16<NonEmpty>>,

        /// Offered preshared key modes (RFC8446)
        ExtensionType::PSKKeyExchangeModes =>
            pub(crate) preshared_key_modes: Option<PskKeyExchangeModes>,

        /// Certificate authority names (RFC8446)
        ExtensionType::CertificateAuthorities =>
            pub(crate) certificate_authority_names: Option<Vec<DistinguishedName>>,

        /// Offered key exchange shares (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_shares: Option<Vec<KeyShareEntry>>,

        /// QUIC transport parameters (RFC9001)
        ExtensionType::TransportParameters =>
            pub(crate) transport_parameters: Option<Payload<'a>>,

        /// Secure renegotiation (RFC5746)
        ExtensionType::RenegotiationInfo =>
            pub(crate) renegotiation_info: Option<PayloadU8>,

        /// Encrypted inner client hello (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello: Option<EncryptedClientHello>,

        /// Encrypted client hello outer extensions (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHelloOuterExtensions =>
            pub(crate) encrypted_client_hello_outer: Option<Vec<ExtensionType>>,
    } + {
        // The fields below are not tied to any `ExtensionType`.

        /// Seed controlling the randomized order of order-insensitive extensions.
        pub(crate) order_seed: u16,

        /// Extensions that must appear contiguously, encoded in the order listed here.
        pub(crate) contiguous_extensions: Vec<ExtensionType>,
    }
}
926
927impl ClientExtensions<'_> {
    /// Converts into a `ClientExtensions<'static>`.
    ///
    /// Only `server_name` and `transport_parameters` borrow with lifetime `'a`,
    /// so only those two fields go through their own `into_owned`. Every other
    /// field is moved across unchanged.
    ///
    /// `self` is destructured without `..`. Adding a field to `ClientExtensions`
    /// therefore stops this function compiling until the new field is handled here.
    pub(crate) fn into_owned(self) -> ClientExtensions<'static> {
        let Self {
            server_name,
            certificate_status_request,
            named_groups,
            ec_point_formats,
            signature_schemes,
            protocols,
            client_certificate_types,
            server_certificate_types,
            extended_master_secret_request,
            certificate_compression_algorithms,
            session_ticket,
            preshared_key_offer,
            early_data_request,
            supported_versions,
            cookie,
            preshared_key_modes,
            certificate_authority_names,
            key_shares,
            transport_parameters,
            renegotiation_info,
            encrypted_client_hello,
            encrypted_client_hello_outer,
            order_seed,
            contiguous_extensions,
        } = self;
        ClientExtensions {
            // Borrowed `'a` data: convert to owned.
            server_name: server_name.map(|x| x.into_owned()),
            certificate_status_request,
            named_groups,
            ec_point_formats,
            signature_schemes,
            protocols,
            client_certificate_types,
            server_certificate_types,
            extended_master_secret_request,
            certificate_compression_algorithms,
            session_ticket,
            preshared_key_offer,
            early_data_request,
            supported_versions,
            cookie,
            preshared_key_modes,
            certificate_authority_names,
            key_shares,
            // Borrowed `'a` data: convert to owned.
            transport_parameters: transport_parameters.map(|x| x.into_owned()),
            renegotiation_info,
            encrypted_client_hello,
            encrypted_client_hello_outer,
            order_seed,
            contiguous_extensions,
        }
    }
982
983    pub(crate) fn used_extensions_in_encoding_order(&self) -> Vec<ExtensionType> {
984        let mut exts = self.order_insensitive_extensions_in_random_order();
985        exts.extend(&self.contiguous_extensions);
986
987        if self
988            .encrypted_client_hello_outer
989            .is_some()
990        {
991            exts.push(ExtensionType::EncryptedClientHelloOuterExtensions);
992        }
993        if self.encrypted_client_hello.is_some() {
994            exts.push(ExtensionType::EncryptedClientHello);
995        }
996        if self.preshared_key_offer.is_some() {
997            exts.push(ExtensionType::PreSharedKey);
998        }
999        exts
1000    }
1001
1002    /// Returns extensions which don't need a specific order, in randomized order.
1003    ///
1004    /// Extensions are encoded in three portions:
1005    ///
1006    /// - First, extensions not otherwise dealt with by other cases.
1007    ///   These are encoded in random order, controlled by `self.order_seed`,
1008    ///   and this is the set of extensions returned by this function.
1009    ///
1010    /// - Second, extensions named in `self.contiguous_extensions`, in the order
1011    ///   given by that field.
1012    ///
1013    /// - Lastly, any ECH and PSK extensions (in that order).  These
1014    ///   are required to be last by the standard.
1015    fn order_insensitive_extensions_in_random_order(&self) -> Vec<ExtensionType> {
1016        let mut order = self.collect_used();
1017
1018        // Remove extensions which have specific order requirements.
1019        order.retain(|ext| {
1020            !(matches!(
1021                ext,
1022                ExtensionType::PreSharedKey
1023                    | ExtensionType::EncryptedClientHello
1024                    | ExtensionType::EncryptedClientHelloOuterExtensions
1025            ) || self.contiguous_extensions.contains(ext))
1026        });
1027
1028        order.sort_by_cached_key(|new_ext| {
1029            let seed = ((self.order_seed as u32) << 16) | (u16::from(*new_ext) as u32);
1030            low_quality_integer_hash(seed)
1031        });
1032
1033        order
1034    }
1035}
1036
1037impl<'a> Codec<'a> for ClientExtensions<'a> {
1038    fn encode(&self, bytes: &mut Vec<u8>) {
1039        let order = self.used_extensions_in_encoding_order();
1040
1041        if order.is_empty() {
1042            return;
1043        }
1044
1045        let body = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1046        for item in order {
1047            self.encode_one(item, body.buf);
1048        }
1049    }
1050
1051    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1052        let mut out = Self::default();
1053
1054        // extensions length can be absent if no extensions
1055        if !r.any_left() {
1056            return Ok(out);
1057        }
1058
1059        let mut checker = DuplicateExtensionChecker::new();
1060
1061        let len = usize::from(u16::read(r)?);
1062        let mut sub = r.sub(len)?;
1063
1064        while sub.any_left() {
1065            let typ = out.read_one(&mut sub, |unknown| checker.check(unknown))?;
1066
1067            // PreSharedKey offer must come last
1068            if typ == ExtensionType::PreSharedKey && sub.any_left() {
1069                return Err(InvalidMessage::PreSharedKeyIsNotFinalExtension);
1070            }
1071        }
1072
1073        Ok(out)
1074    }
1075}
1076
1077fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
1078    let dns_name_str = dns_name.as_ref();
1079
1080    // RFC6066: "The hostname is represented as a byte string using
1081    // ASCII encoding without a trailing dot"
1082    if dns_name_str.ends_with('.') {
1083        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
1084        DnsName::try_from(trimmed)
1085            .unwrap()
1086            .to_owned()
1087    } else {
1088        dns_name.to_owned()
1089    }
1090}
1091
1092#[derive(Clone, Debug)]
1093pub(crate) enum ClientSessionTicket {
1094    Request,
1095    Offer(Payload<'static>),
1096}
1097
1098impl<'a> Codec<'a> for ClientSessionTicket {
1099    fn encode(&self, bytes: &mut Vec<u8>) {
1100        match self {
1101            Self::Request => (),
1102            Self::Offer(p) => p.encode(bytes),
1103        }
1104    }
1105
1106    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1107        Ok(match r.left() {
1108            0 => Self::Request,
1109            _ => Self::Offer(Payload::read(r).into_owned()),
1110        })
1111    }
1112}
1113
/// Optional inputs that, judging by the name, feed the construction of the
/// server's extensions.
///
/// NOTE(review): only QUIC transport parameters are carried today; the
/// consumer of this struct is outside this chunk, so confirm intended use.
#[derive(Default)]
pub(crate) struct ServerExtensionsInput<'a> {
    /// QUIC transport parameters
    pub(crate) transport_parameters: Option<TransportParameters<'a>>,
}
1119
extension_struct! {
    /// Extensions carried by a server hello (the extension block of
    /// `ServerHelloPayload`).
    ///
    /// NOTE(review): `encode` walks `ALL_EXTENSIONS`, which is presumably
    /// generated in field declaration order. If so, reordering these fields
    /// changes the wire encoding; confirm in `extension_struct!`.
    pub(crate) struct ServerExtensions<'a> {
        /// Supported EC point formats (RFC4492)
        ExtensionType::ECPointFormats =>
            pub(crate) ec_point_formats: Option<SupportedEcPointFormats>,

        /// Server name indication acknowledgement (RFC6066)
        ExtensionType::ServerName =>
            pub(crate) server_name_ack: Option<()>,

        /// Session ticket acknowledgement (RFC5077)
        ExtensionType::SessionTicket =>
            pub(crate) session_ticket_ack: Option<()>,

        /// Secure renegotiation information (RFC5746)
        ExtensionType::RenegotiationInfo =>
            pub(crate) renegotiation_info: Option<PayloadU8>,

        /// Selected ALPN protocol (RFC7301)
        ExtensionType::ALProtocolNegotiation =>
            pub(crate) selected_protocol: Option<SingleProtocolName>,

        /// Key exchange server share (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_share: Option<KeyShareEntry>,

        /// Selected preshared key index (RFC8446)
        ExtensionType::PreSharedKey =>
            pub(crate) preshared_key: Option<u16>,

        /// Required client certificate type (RFC7250)
        ExtensionType::ClientCertificateType =>
            pub(crate) client_certificate_type: Option<CertificateType>,

        /// Selected server certificate type (RFC7250)
        ExtensionType::ServerCertificateType =>
            pub(crate) server_certificate_type: Option<CertificateType>,

        /// Extended master secret is in use (RFC7627)
        ExtensionType::ExtendedMasterSecret =>
            pub(crate) extended_master_secret_ack: Option<()>,

        /// Certificate status acknowledgement (RFC6066)
        ExtensionType::StatusRequest =>
            pub(crate) certificate_status_request_ack: Option<()>,

        /// Selected TLS version (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) selected_version: Option<ProtocolVersion>,

        /// QUIC transport parameters (RFC9001)
        ExtensionType::TransportParameters =>
            pub(crate) transport_parameters: Option<Payload<'a>>,

        /// Early data is accepted (RFC8446)
        ExtensionType::EarlyData =>
            pub(crate) early_data_ack: Option<()>,

        /// Encrypted inner client hello response (draft-ietf-tls-esni)
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello_ack: Option<ServerEncryptedClientHello>,
    } + {
        /// Extension type codes collected by the `DuplicateExtensionChecker`
        /// while decoding. `read` passes it only unrecognised extensions, so
        /// this is presumably the set of unknown types seen.
        pub(crate) unknown_extensions: BTreeSet<u16>,
    }
}
1184
1185impl ServerExtensions<'_> {
1186    fn into_owned(self) -> ServerExtensions<'static> {
1187        let Self {
1188            ec_point_formats,
1189            server_name_ack,
1190            session_ticket_ack,
1191            renegotiation_info,
1192            selected_protocol,
1193            key_share,
1194            preshared_key,
1195            client_certificate_type,
1196            server_certificate_type,
1197            extended_master_secret_ack,
1198            certificate_status_request_ack,
1199            selected_version,
1200            transport_parameters,
1201            early_data_ack,
1202            encrypted_client_hello_ack,
1203            unknown_extensions,
1204        } = self;
1205        ServerExtensions {
1206            ec_point_formats,
1207            server_name_ack,
1208            session_ticket_ack,
1209            renegotiation_info,
1210            selected_protocol,
1211            key_share,
1212            preshared_key,
1213            client_certificate_type,
1214            server_certificate_type,
1215            extended_master_secret_ack,
1216            certificate_status_request_ack,
1217            selected_version,
1218            transport_parameters: transport_parameters.map(|x| x.into_owned()),
1219            early_data_ack,
1220            encrypted_client_hello_ack,
1221            unknown_extensions,
1222        }
1223    }
1224}
1225
1226impl<'a> Codec<'a> for ServerExtensions<'a> {
1227    fn encode(&self, bytes: &mut Vec<u8>) {
1228        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1229
1230        for ext in Self::ALL_EXTENSIONS {
1231            self.encode_one(*ext, extensions.buf);
1232        }
1233    }
1234
1235    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1236        let mut out = Self::default();
1237        let mut checker = DuplicateExtensionChecker::new();
1238
1239        let len = usize::from(u16::read(r)?);
1240        let mut sub = r.sub(len)?;
1241
1242        while sub.any_left() {
1243            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
1244        }
1245
1246        out.unknown_extensions = checker.0;
1247        Ok(out)
1248    }
1249}
1250
/// The body of a ClientHello message.
#[derive(Clone, Debug)]
pub(crate) struct ClientHelloPayload {
    /// Version field; the first thing encoded by `payload_encode`.
    pub(crate) client_version: ProtocolVersion,
    pub(crate) random: Random,
    /// Legacy session ID; `payload_encode` writes it as empty for the ECH
    /// inner hello.
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suites: Vec<CipherSuite>,
    pub(crate) compression_methods: Vec<Compression>,
    /// All extensions; also reachable directly on the payload via `Deref`.
    pub(crate) extensions: Box<ClientExtensions<'static>>,
}
1260
1261impl ClientHelloPayload {
1262    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
1263        let mut bytes = Vec::new();
1264        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
1265        bytes
1266    }
1267
1268    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1269        self.client_version.encode(bytes);
1270        self.random.encode(bytes);
1271
1272        match purpose {
1273            // SessionID is required to be empty in the encoded inner client hello.
1274            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
1275            _ => self.session_id.encode(bytes),
1276        }
1277
1278        self.cipher_suites.encode(bytes);
1279        self.compression_methods.encode(bytes);
1280
1281        let to_compress = match purpose {
1282            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
1283            _ => {
1284                self.extensions.encode(bytes);
1285                return;
1286            }
1287        };
1288
1289        let mut compressed = self.extensions.clone();
1290
1291        // First, eliminate the full-fat versions of the extensions
1292        for e in &to_compress {
1293            compressed.clear(*e);
1294        }
1295
1296        // Replace with the marker noting which extensions were elided.
1297        compressed.encrypted_client_hello_outer = Some(to_compress);
1298
1299        // And encode as normal.
1300        compressed.encode(bytes);
1301    }
1302
1303    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1304        self.key_shares
1305            .as_ref()
1306            .map(|entries| {
1307                has_duplicates::<_, _, u16>(
1308                    entries
1309                        .iter()
1310                        .map(|kse| u16::from(kse.group)),
1311                )
1312            })
1313            .unwrap_or_default()
1314    }
1315
1316    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1317        if let Some(algs) = &self.certificate_compression_algorithms {
1318            has_duplicates::<_, _, u16>(algs.iter().cloned())
1319        } else {
1320            false
1321        }
1322    }
1323}
1324
1325impl Codec<'_> for ClientHelloPayload {
1326    fn encode(&self, bytes: &mut Vec<u8>) {
1327        self.payload_encode(bytes, Encoding::Standard)
1328    }
1329
1330    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1331        let ret = Self {
1332            client_version: ProtocolVersion::read(r)?,
1333            random: Random::read(r)?,
1334            session_id: SessionId::read(r)?,
1335            cipher_suites: Vec::read(r)?,
1336            compression_methods: Vec::read(r)?,
1337            extensions: Box::new(ClientExtensions::read(r)?.into_owned()),
1338        };
1339
1340        match r.any_left() {
1341            true => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
1342            false => Ok(ret),
1343        }
1344    }
1345}
1346
1347impl Deref for ClientHelloPayload {
1348    type Target = ClientExtensions<'static>;
1349    fn deref(&self) -> &Self::Target {
1350        &self.extensions
1351    }
1352}
1353
1354impl DerefMut for ClientHelloPayload {
1355    fn deref_mut(&mut self) -> &mut Self::Target {
1356        &mut self.extensions
1357    }
1358}
1359
/// RFC8446: `CipherSuite cipher_suites<2..2^16-2>;`
///
/// Uses a `NonZeroU16` list length, configured to report
/// `IllegalEmptyList("CipherSuites")` for an empty list.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
        empty_error: InvalidMessage::IllegalEmptyList("CipherSuites"),
    };
}

/// RFC5246: `CompressionMethod compression_methods<1..2^8-1>;`
///
/// Uses a `NonZeroU8` list length, configured to report
/// `IllegalEmptyList("Compressions")` for an empty list.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("Compressions"),
    };
}

/// draft-ietf-tls-esni-17: `ExtensionType OuterExtensions<2..254>;`
///
/// Uses a `NonZeroU8` list length, configured to report
/// `IllegalEmptyList("ExtensionTypes")` for an empty list.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ExtensionTypes"),
    };
}
1380
extension_struct! {
    /// A representation of extensions present in a `HelloRetryRequest` message
    pub(crate) struct HelloRetryRequestExtensions<'a> {
        /// Group the server selected for the client's retried key share (RFC8446)
        ExtensionType::KeyShare =>
            pub(crate) key_share: Option<NamedGroup>,

        /// Cookie the client echoes in its next ClientHello (RFC8446)
        ExtensionType::Cookie =>
            pub(crate) cookie: Option<PayloadU16<NonEmpty>>,

        /// Selected TLS version (RFC8446)
        ExtensionType::SupportedVersions =>
            pub(crate) supported_versions: Option<ProtocolVersion>,

        /// ECH confirmation payload (draft-ietf-tls-esni); replaced by eight
        /// zero bytes when encoding for ECH confirmation
        ExtensionType::EncryptedClientHello =>
            pub(crate) encrypted_client_hello: Option<Payload<'a>>,
    } + {
        /// Records the order extensions were decoded in, and controls encoding
        /// order (`ALL_EXTENSIONS` order is used when `None`).
        pub(crate) order: Option<Vec<ExtensionType>>,
    }
}
1400
1401impl HelloRetryRequestExtensions<'_> {
1402    fn into_owned(self) -> HelloRetryRequestExtensions<'static> {
1403        let Self {
1404            key_share,
1405            cookie,
1406            supported_versions,
1407            encrypted_client_hello,
1408            order,
1409        } = self;
1410        HelloRetryRequestExtensions {
1411            key_share,
1412            cookie,
1413            supported_versions,
1414            encrypted_client_hello: encrypted_client_hello.map(|x| x.into_owned()),
1415            order,
1416        }
1417    }
1418}
1419
1420impl<'a> Codec<'a> for HelloRetryRequestExtensions<'a> {
1421    fn encode(&self, bytes: &mut Vec<u8>) {
1422        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1423
1424        for ext in self
1425            .order
1426            .as_deref()
1427            .unwrap_or(Self::ALL_EXTENSIONS)
1428        {
1429            self.encode_one(*ext, extensions.buf);
1430        }
1431    }
1432
1433    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1434        let mut out = Self::default();
1435
1436        // we must record order, so re-encoding round trips.  this is needed,
1437        // unfortunately, for ECH HRR confirmation
1438        let mut order = vec![];
1439
1440        let len = usize::from(u16::read(r)?);
1441        let mut sub = r.sub(len)?;
1442
1443        while sub.any_left() {
1444            let typ = out.read_one(&mut sub, |_unk| {
1445                Err(InvalidMessage::UnknownHelloRetryRequestExtension)
1446            })?;
1447
1448            order.push(typ);
1449        }
1450
1451        out.order = Some(order);
1452        Ok(out)
1453    }
1454}
1455
/// A HelloRetryRequest message body.
///
/// No random is stored: `payload_encode` always writes
/// `HELLO_RETRY_REQUEST_RANDOM`, and always encodes null compression.
#[derive(Clone, Debug)]
pub(crate) struct HelloRetryRequest {
    /// Set to the placeholder `ProtocolVersion::Unknown(0)` by `read`.
    pub(crate) legacy_version: ProtocolVersion,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    /// Extensions; also reachable directly on the message via `Deref`.
    pub(crate) extensions: HelloRetryRequestExtensions<'static>,
}
1463
1464impl Codec<'_> for HelloRetryRequest {
1465    fn encode(&self, bytes: &mut Vec<u8>) {
1466        self.payload_encode(bytes, Encoding::Standard)
1467    }
1468
1469    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1470        let session_id = SessionId::read(r)?;
1471        let cipher_suite = CipherSuite::read(r)?;
1472        let compression = Compression::read(r)?;
1473
1474        if compression != Compression::Null {
1475            return Err(InvalidMessage::UnsupportedCompression);
1476        }
1477
1478        Ok(Self {
1479            legacy_version: ProtocolVersion::Unknown(0),
1480            session_id,
1481            cipher_suite,
1482            extensions: HelloRetryRequestExtensions::read(r)?.into_owned(),
1483        })
1484    }
1485}
1486
1487impl HelloRetryRequest {
1488    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1489        self.legacy_version.encode(bytes);
1490        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1491        self.session_id.encode(bytes);
1492        self.cipher_suite.encode(bytes);
1493        Compression::Null.encode(bytes);
1494
1495        match purpose {
1496            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1497            // must have its payload replaced by 8 zero bytes.
1498            //
1499            // See draft-ietf-tls-esni-18 7.2.1:
1500            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1501            Encoding::EchConfirmation
1502                if self
1503                    .extensions
1504                    .encrypted_client_hello
1505                    .is_some() =>
1506            {
1507                let hrr_confirmation = [0u8; 8];
1508                HelloRetryRequestExtensions {
1509                    encrypted_client_hello: Some(Payload::Borrowed(&hrr_confirmation)),
1510                    ..self.extensions.clone()
1511                }
1512                .encode(bytes);
1513            }
1514            _ => self.extensions.encode(bytes),
1515        }
1516    }
1517}
1518
1519impl Deref for HelloRetryRequest {
1520    type Target = HelloRetryRequestExtensions<'static>;
1521    fn deref(&self) -> &Self::Target {
1522        &self.extensions
1523    }
1524}
1525
1526impl DerefMut for HelloRetryRequest {
1527    fn deref_mut(&mut self) -> &mut Self::Target {
1528        &mut self.extensions
1529    }
1530}
1531
/// A ServerHello message body.
#[derive(Clone, Debug)]
pub(crate) struct ServerHelloPayload {
    /// `read` does not decode this; it sets `ProtocolVersion::Unknown(0)`.
    pub(crate) legacy_version: ProtocolVersion,
    /// `read` does not decode this; it sets `ZERO_RANDOM`.
    pub(crate) random: Random,
    pub(crate) session_id: SessionId,
    pub(crate) cipher_suite: CipherSuite,
    pub(crate) compression_method: Compression,
    /// Extensions (default when the message carried none); also reachable
    /// directly on the payload via `Deref`.
    pub(crate) extensions: Box<ServerExtensions<'static>>,
}
1541
1542impl Codec<'_> for ServerHelloPayload {
1543    fn encode(&self, bytes: &mut Vec<u8>) {
1544        self.payload_encode(bytes, Encoding::Standard)
1545    }
1546
1547    // minus version and random, which have already been read.
1548    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1549        let session_id = SessionId::read(r)?;
1550        let suite = CipherSuite::read(r)?;
1551        let compression = Compression::read(r)?;
1552
1553        // RFC5246:
1554        // "The presence of extensions can be detected by determining whether
1555        //  there are bytes following the compression_method field at the end of
1556        //  the ServerHello."
1557        let extensions = Box::new(
1558            if r.any_left() {
1559                ServerExtensions::read(r)?
1560            } else {
1561                ServerExtensions::default()
1562            }
1563            .into_owned(),
1564        );
1565
1566        let ret = Self {
1567            legacy_version: ProtocolVersion::Unknown(0),
1568            random: ZERO_RANDOM,
1569            session_id,
1570            cipher_suite: suite,
1571            compression_method: compression,
1572            extensions,
1573        };
1574
1575        r.expect_empty("ServerHelloPayload")
1576            .map(|_| ret)
1577    }
1578}
1579
1580impl ServerHelloPayload {
1581    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1582        debug_assert!(
1583            !matches!(encoding, Encoding::EchConfirmation),
1584            "we cannot compute an ECH confirmation on a received ServerHello"
1585        );
1586
1587        self.legacy_version.encode(bytes);
1588        self.random.encode(bytes);
1589        self.session_id.encode(bytes);
1590        self.cipher_suite.encode(bytes);
1591        self.compression_method.encode(bytes);
1592        self.extensions.encode(bytes);
1593    }
1594}
1595
1596impl Deref for ServerHelloPayload {
1597    type Target = ServerExtensions<'static>;
1598    fn deref(&self) -> &Self::Target {
1599        &self.extensions
1600    }
1601}
1602
1603impl DerefMut for ServerHelloPayload {
1604    fn deref_mut(&mut self) -> &mut Self::Target {
1605        &mut self.extensions
1606    }
1607}
1608
/// An ordered list of DER-encoded certificates.
#[derive(Clone, Default, Debug)]
pub(crate) struct CertificateChain<'a>(pub(crate) Vec<CertificateDer<'a>>);
1611
1612impl CertificateChain<'_> {
1613    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1614        CertificateChain(
1615            self.0
1616                .into_iter()
1617                .map(|c| c.into_owned())
1618                .collect(),
1619        )
1620    }
1621}
1622
1623impl<'a> Codec<'a> for CertificateChain<'a> {
1624    fn encode(&self, bytes: &mut Vec<u8>) {
1625        Vec::encode(&self.0, bytes)
1626    }
1627
1628    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1629        Vec::read(r).map(Self)
1630    }
1631}
1632
1633impl<'a> Deref for CertificateChain<'a> {
1634    type Target = [CertificateDer<'a>];
1635
1636    fn deref(&self) -> &[CertificateDer<'a>] {
1637        &self.0
1638    }
1639}
1640
/// Certificate lists use a `U24` length, bounded by
/// `CERTIFICATE_MAX_SIZE_LIMIT` and configured to report
/// `InvalidMessage::CertificatePayloadTooLarge` beyond that.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}

/// TLS has a 16MB (2^24 - 1 byte) size limit on any handshake message,
/// plus a 16MB limit on any given certificate.
///
/// We contract that to 64KB (0x1_0000 = 65536 bytes) to limit the amount of
/// memory allocation that is directly controllable by the peer.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1654
extension_struct! {
    /// Extensions attached to a single TLS 1.3 `CertificateEntry`.
    pub(crate) struct CertificateExtensions<'a> {
        /// OCSP status response for this certificate
        ExtensionType::StatusRequest =>
            pub(crate) status: Option<CertificateStatus<'a>>,
    }
}
1661
1662impl CertificateExtensions<'_> {
1663    fn into_owned(self) -> CertificateExtensions<'static> {
1664        CertificateExtensions {
1665            status: self.status.map(|s| s.into_owned()),
1666        }
1667    }
1668}
1669
1670impl<'a> Codec<'a> for CertificateExtensions<'a> {
1671    fn encode(&self, bytes: &mut Vec<u8>) {
1672        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1673
1674        for ext in Self::ALL_EXTENSIONS {
1675            self.encode_one(*ext, extensions.buf);
1676        }
1677    }
1678
1679    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1680        let mut out = Self::default();
1681
1682        let len = usize::from(u16::read(r)?);
1683        let mut sub = r.sub(len)?;
1684
1685        while sub.any_left() {
1686            out.read_one(&mut sub, |_unk| {
1687                Err(InvalidMessage::UnknownCertificateExtension)
1688            })?;
1689        }
1690
1691        Ok(out)
1692    }
1693}
1694
/// One entry of a TLS 1.3 Certificate message: a certificate and its
/// per-certificate extensions.
#[derive(Debug)]
pub(crate) struct CertificateEntry<'a> {
    pub(crate) cert: CertificateDer<'a>,
    pub(crate) extensions: CertificateExtensions<'a>,
}
1700
1701impl<'a> Codec<'a> for CertificateEntry<'a> {
1702    fn encode(&self, bytes: &mut Vec<u8>) {
1703        self.cert.encode(bytes);
1704        self.extensions.encode(bytes);
1705    }
1706
1707    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1708        Ok(Self {
1709            cert: CertificateDer::read(r)?,
1710            extensions: CertificateExtensions::read(r)?.into_owned(),
1711        })
1712    }
1713}
1714
1715impl<'a> CertificateEntry<'a> {
1716    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1717        Self {
1718            cert,
1719            extensions: CertificateExtensions::default(),
1720        }
1721    }
1722
1723    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1724        CertificateEntry {
1725            cert: self.cert.into_owned(),
1726            extensions: self.extensions.into_owned(),
1727        }
1728    }
1729}
1730
/// Lists of certificate entries use a `U24` length, bounded by
/// `CERTIFICATE_MAX_SIZE_LIMIT` and configured to report
/// `InvalidMessage::CertificatePayloadTooLarge` beyond that.
impl TlsListElement for CertificateEntry<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1737
/// The body of a TLS 1.3 Certificate message.
#[derive(Debug)]
pub(crate) struct CertificatePayloadTls13<'a> {
    /// Request context (RFC8446 `certificate_request_context`); `new` leaves
    /// it empty.
    pub(crate) context: PayloadU8,
    /// Certificate entries; `end_entity_ocsp` treats the first one as the
    /// end entity.
    pub(crate) entries: Vec<CertificateEntry<'a>>,
}
1743
1744impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1745    fn encode(&self, bytes: &mut Vec<u8>) {
1746        self.context.encode(bytes);
1747        self.entries.encode(bytes);
1748    }
1749
1750    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1751        Ok(Self {
1752            context: PayloadU8::read(r)?,
1753            entries: Vec::read(r)?,
1754        })
1755    }
1756}
1757
1758impl<'a> CertificatePayloadTls13<'a> {
1759    pub(crate) fn new(
1760        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1761        ocsp_response: Option<&'a [u8]>,
1762    ) -> Self {
1763        Self {
1764            context: PayloadU8::empty(),
1765            entries: certs
1766                // zip certificate iterator with `ocsp_response` followed by
1767                // an infinite-length iterator of `None`.
1768                .zip(
1769                    ocsp_response
1770                        .into_iter()
1771                        .map(Some)
1772                        .chain(iter::repeat(None)),
1773                )
1774                .map(|(cert, ocsp)| {
1775                    let mut e = CertificateEntry::new(cert.clone());
1776                    if let Some(ocsp) = ocsp {
1777                        e.extensions.status = Some(CertificateStatus::new(ocsp));
1778                    }
1779                    e
1780                })
1781                .collect(),
1782        }
1783    }
1784
1785    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1786        CertificatePayloadTls13 {
1787            context: self.context,
1788            entries: self
1789                .entries
1790                .into_iter()
1791                .map(CertificateEntry::into_owned)
1792                .collect(),
1793        }
1794    }
1795
1796    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1797        let Some(entry) = self.entries.first() else {
1798            return vec![];
1799        };
1800        entry
1801            .extensions
1802            .status
1803            .as_ref()
1804            .map(|status| {
1805                status
1806                    .ocsp_response
1807                    .0
1808                    .clone()
1809                    .into_vec()
1810            })
1811            .unwrap_or_default()
1812    }
1813
1814    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1815        CertificateChain(
1816            self.entries
1817                .into_iter()
1818                .map(|e| e.cert)
1819                .collect(),
1820        )
1821    }
1822}
1823
/// Describes supported key exchange mechanisms.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman Key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}

/// Every `KeyExchangeAlgorithm` variant, with ECDHE listed before DHE.
///
/// NOTE(review): whether callers treat this order as a preference order is
/// not visible here; confirm before reordering.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1838
// Only named curves are supported. Explicitly-specified ("arbitrary") curve
// parameters are rejected when decoding, because supporting them would add
// unnecessary attack surface.
#[derive(Debug)]
pub(crate) struct EcParameters {
    /// Always `ECCurveType::NamedCurve` for values produced by `read`.
    pub(crate) curve_type: ECCurveType,
    /// The named curve (group) in use.
    pub(crate) named_group: NamedGroup,
}
1847
1848impl Codec<'_> for EcParameters {
1849    fn encode(&self, bytes: &mut Vec<u8>) {
1850        self.curve_type.encode(bytes);
1851        self.named_group.encode(bytes);
1852    }
1853
1854    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1855        let ct = ECCurveType::read(r)?;
1856        if ct != ECCurveType::NamedCurve {
1857            return Err(InvalidMessage::UnsupportedCurveType);
1858        }
1859
1860        let grp = NamedGroup::read(r)?;
1861
1862        Ok(Self {
1863            curve_type: ct,
1864            named_group: grp,
1865        })
1866    }
1867}
1868
/// Decoding for key-exchange messages whose wire layout depends on the
/// negotiated `KeyExchangeAlgorithm`.
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key_exchange `algo`
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1873
/// A client's key exchange parameters, one variant per key exchange algorithm.
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// Produced when decoding with `KeyExchangeAlgorithm::ECDHE`.
    Ecdh(ClientEcdhParams),
    /// Produced when decoding with `KeyExchangeAlgorithm::DHE`.
    Dh(ClientDhParams),
}
1879
1880impl ClientKeyExchangeParams {
1881    pub(crate) fn pub_key(&self) -> &[u8] {
1882        match self {
1883            Self::Ecdh(ecdh) => &ecdh.public.0,
1884            Self::Dh(dh) => &dh.public.0,
1885        }
1886    }
1887
1888    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
1889        match self {
1890            Self::Ecdh(ecdh) => ecdh.encode(buf),
1891            Self::Dh(dh) => dh.encode(buf),
1892        }
1893    }
1894}
1895
1896impl KxDecode<'_> for ClientKeyExchangeParams {
1897    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
1898        use KeyExchangeAlgorithm::*;
1899        Ok(match algo {
1900            ECDHE => Self::Ecdh(ClientEcdhParams::read(r)?),
1901            DHE => Self::Dh(ClientDhParams::read(r)?),
1902        })
1903    }
1904}
1905
1906#[derive(Debug)]
1907pub(crate) struct ClientEcdhParams {
1908    /// RFC4492: `opaque point <1..2^8-1>;`
1909    pub(crate) public: PayloadU8<NonEmpty>,
1910}
1911
1912impl Codec<'_> for ClientEcdhParams {
1913    fn encode(&self, bytes: &mut Vec<u8>) {
1914        self.public.encode(bytes);
1915    }
1916
1917    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1918        let pb = PayloadU8::read(r)?;
1919        Ok(Self { public: pb })
1920    }
1921}
1922
1923#[derive(Debug)]
1924pub(crate) struct ClientDhParams {
1925    /// RFC5246: `opaque dh_Yc<1..2^16-1>;`
1926    pub(crate) public: PayloadU16<NonEmpty>,
1927}
1928
1929impl Codec<'_> for ClientDhParams {
1930    fn encode(&self, bytes: &mut Vec<u8>) {
1931        self.public.encode(bytes);
1932    }
1933
1934    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1935        Ok(Self {
1936            public: PayloadU16::read(r)?,
1937        })
1938    }
1939}
1940
1941#[derive(Debug)]
1942pub(crate) struct ServerEcdhParams {
1943    pub(crate) curve_params: EcParameters,
1944    /// RFC4492: `opaque point <1..2^8-1>;`
1945    pub(crate) public: PayloadU8<NonEmpty>,
1946}
1947
1948impl ServerEcdhParams {
1949    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1950        Self {
1951            curve_params: EcParameters {
1952                curve_type: ECCurveType::NamedCurve,
1953                named_group: kx.group(),
1954            },
1955            public: PayloadU8::new(kx.pub_key().to_vec()),
1956        }
1957    }
1958}
1959
1960impl Codec<'_> for ServerEcdhParams {
1961    fn encode(&self, bytes: &mut Vec<u8>) {
1962        self.curve_params.encode(bytes);
1963        self.public.encode(bytes);
1964    }
1965
1966    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1967        let cp = EcParameters::read(r)?;
1968        let pb = PayloadU8::read(r)?;
1969
1970        Ok(Self {
1971            curve_params: cp,
1972            public: pb,
1973        })
1974    }
1975}
1976
1977#[derive(Debug)]
1978#[allow(non_snake_case)]
1979pub(crate) struct ServerDhParams {
1980    /// RFC5246: `opaque dh_p<1..2^16-1>;`
1981    pub(crate) dh_p: PayloadU16<NonEmpty>,
1982    /// RFC5246: `opaque dh_g<1..2^16-1>;`
1983    pub(crate) dh_g: PayloadU16<NonEmpty>,
1984    /// RFC5246: `opaque dh_Ys<1..2^16-1>;`
1985    pub(crate) dh_Ys: PayloadU16<NonEmpty>,
1986}
1987
1988impl ServerDhParams {
1989    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1990        let Some(params) = kx.ffdhe_group() else {
1991            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
1992        };
1993
1994        Self {
1995            dh_p: PayloadU16::new(params.p.to_vec()),
1996            dh_g: PayloadU16::new(params.g.to_vec()),
1997            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
1998        }
1999    }
2000
2001    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
2002        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
2003    }
2004}
2005
2006impl Codec<'_> for ServerDhParams {
2007    fn encode(&self, bytes: &mut Vec<u8>) {
2008        self.dh_p.encode(bytes);
2009        self.dh_g.encode(bytes);
2010        self.dh_Ys.encode(bytes);
2011    }
2012
2013    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2014        Ok(Self {
2015            dh_p: PayloadU16::read(r)?,
2016            dh_g: PayloadU16::read(r)?,
2017            dh_Ys: PayloadU16::read(r)?,
2018        })
2019    }
2020}
2021
2022#[allow(dead_code)]
2023#[derive(Debug)]
2024pub(crate) enum ServerKeyExchangeParams {
2025    Ecdh(ServerEcdhParams),
2026    Dh(ServerDhParams),
2027}
2028
2029impl ServerKeyExchangeParams {
2030    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
2031        match kx.group().key_exchange_algorithm() {
2032            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
2033            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
2034        }
2035    }
2036
2037    pub(crate) fn pub_key(&self) -> &[u8] {
2038        match self {
2039            Self::Ecdh(ecdh) => &ecdh.public.0,
2040            Self::Dh(dh) => &dh.dh_Ys.0,
2041        }
2042    }
2043
2044    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2045        match self {
2046            Self::Ecdh(ecdh) => ecdh.encode(buf),
2047            Self::Dh(dh) => dh.encode(buf),
2048        }
2049    }
2050}
2051
2052impl KxDecode<'_> for ServerKeyExchangeParams {
2053    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
2054        use KeyExchangeAlgorithm::*;
2055        Ok(match algo {
2056            ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
2057            DHE => Self::Dh(ServerDhParams::read(r)?),
2058        })
2059    }
2060}
2061
2062#[derive(Debug)]
2063pub(crate) struct ServerKeyExchange {
2064    pub(crate) params: ServerKeyExchangeParams,
2065    pub(crate) dss: DigitallySignedStruct,
2066}
2067
2068impl ServerKeyExchange {
2069    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
2070        self.params.encode(buf);
2071        self.dss.encode(buf);
2072    }
2073}
2074
2075#[derive(Debug)]
2076pub(crate) enum ServerKeyExchangePayload {
2077    Known(ServerKeyExchange),
2078    Unknown(Payload<'static>),
2079}
2080
2081impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2082    fn from(value: ServerKeyExchange) -> Self {
2083        Self::Known(value)
2084    }
2085}
2086
2087impl Codec<'_> for ServerKeyExchangePayload {
2088    fn encode(&self, bytes: &mut Vec<u8>) {
2089        match self {
2090            Self::Known(x) => x.encode(bytes),
2091            Self::Unknown(x) => x.encode(bytes),
2092        }
2093    }
2094
2095    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2096        // read as Unknown, fully parse when we know the
2097        // KeyExchangeAlgorithm
2098        Ok(Self::Unknown(Payload::read(r).into_owned()))
2099    }
2100}
2101
2102impl ServerKeyExchangePayload {
2103    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2104        if let Self::Unknown(unk) = self {
2105            let mut rd = Reader::init(unk.bytes());
2106
2107            let result = ServerKeyExchange {
2108                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2109                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2110            };
2111
2112            if !rd.any_left() {
2113                return Some(result);
2114            };
2115        }
2116
2117        None
2118    }
2119}
2120
/// RFC5246: `ClientCertificateType certificate_types<1..2^8-1>;`
///
/// The list must be non-empty; an empty list is reported as
/// `IllegalEmptyList("ClientCertificateTypes")`.
impl TlsListElement for ClientCertificateType {
    const SIZE_LEN: ListLength = ListLength::NonZeroU8 {
        empty_error: InvalidMessage::IllegalEmptyList("ClientCertificateTypes"),
    };
}
2127
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    /// The raw bytes are available through its `AsRef<[u8]>` implementation.
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    ///
    /// The TLS encoding is defined in RFC5246: `opaque DistinguishedName<1..2^16-1>;`
    pub struct DistinguishedName,
    PayloadU16<NonEmpty>,
);
2146
2147impl DistinguishedName {
2148    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2149    ///
2150    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2151    ///
2152    /// ```ignore
2153    /// use x509_parser::prelude::FromDer;
2154    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2155    /// ```
2156    pub fn in_sequence(bytes: &[u8]) -> Self {
2157        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2158    }
2159}
2160
2161/// RFC8446: `DistinguishedName authorities<3..2^16-1>;` however,
2162/// RFC5246: `DistinguishedName certificate_authorities<0..2^16-1>;`
2163impl TlsListElement for DistinguishedName {
2164    const SIZE_LEN: ListLength = ListLength::U16;
2165}
2166
2167#[derive(Debug)]
2168pub(crate) struct CertificateRequestPayload {
2169    pub(crate) certtypes: Vec<ClientCertificateType>,
2170    pub(crate) sigschemes: Vec<SignatureScheme>,
2171    pub(crate) canames: Vec<DistinguishedName>,
2172}
2173
2174impl Codec<'_> for CertificateRequestPayload {
2175    fn encode(&self, bytes: &mut Vec<u8>) {
2176        self.certtypes.encode(bytes);
2177        self.sigschemes.encode(bytes);
2178        self.canames.encode(bytes);
2179    }
2180
2181    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2182        let certtypes = Vec::read(r)?;
2183        let sigschemes = Vec::read(r)?;
2184        let canames = Vec::read(r)?;
2185
2186        if sigschemes.is_empty() {
2187            warn!("meaningless CertificateRequest message");
2188            Err(InvalidMessage::NoSignatureSchemes)
2189        } else {
2190            Ok(Self {
2191                certtypes,
2192                sigschemes,
2193                canames,
2194            })
2195        }
2196    }
2197}
2198
// Extensions of a TLS 1.3 `CertificateRequest` (see `CertificateRequestPayloadTls13`),
// one optional field per supported extension type. Presumably a field is `None`
// when that extension is absent (decoding starts from `Self::default()`).
extension_struct! {
    pub(crate) struct CertificateRequestExtensions {
        ExtensionType::SignatureAlgorithms =>
            pub(crate) signature_algorithms: Option<Vec<SignatureScheme>>,

        ExtensionType::CertificateAuthorities =>
            pub(crate) authority_names: Option<Vec<DistinguishedName>>,

        ExtensionType::CompressCertificate =>
            pub(crate) certificate_compression_algorithms: Option<Vec<CertificateCompressionAlgorithm>>,
    }
}
2211
2212impl Codec<'_> for CertificateRequestExtensions {
2213    fn encode(&self, bytes: &mut Vec<u8>) {
2214        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2215
2216        for ext in Self::ALL_EXTENSIONS {
2217            self.encode_one(*ext, extensions.buf);
2218        }
2219    }
2220
2221    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2222        let mut out = Self::default();
2223
2224        let mut checker = DuplicateExtensionChecker::new();
2225
2226        let len = usize::from(u16::read(r)?);
2227        let mut sub = r.sub(len)?;
2228
2229        while sub.any_left() {
2230            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
2231        }
2232
2233        if out
2234            .signature_algorithms
2235            .as_ref()
2236            .map(|algs| algs.is_empty())
2237            .unwrap_or_default()
2238        {
2239            return Err(InvalidMessage::NoSignatureSchemes);
2240        }
2241
2242        Ok(out)
2243    }
2244}
2245
2246#[derive(Debug)]
2247pub(crate) struct CertificateRequestPayloadTls13 {
2248    pub(crate) context: PayloadU8,
2249    pub(crate) extensions: CertificateRequestExtensions,
2250}
2251
2252impl Codec<'_> for CertificateRequestPayloadTls13 {
2253    fn encode(&self, bytes: &mut Vec<u8>) {
2254        self.context.encode(bytes);
2255        self.extensions.encode(bytes);
2256    }
2257
2258    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2259        let context = PayloadU8::read(r)?;
2260        let extensions = CertificateRequestExtensions::read(r)?;
2261
2262        Ok(Self {
2263            context,
2264            extensions,
2265        })
2266    }
2267}
2268
2269// -- NewSessionTicket --
2270#[derive(Debug)]
2271pub(crate) struct NewSessionTicketPayload {
2272    pub(crate) lifetime_hint: u32,
2273    // Tickets can be large (KB), so we deserialise this straight
2274    // into an Arc, so it can be passed directly into the client's
2275    // session object without copying.
2276    pub(crate) ticket: Arc<PayloadU16>,
2277}
2278
2279impl NewSessionTicketPayload {
2280    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2281        Self {
2282            lifetime_hint,
2283            ticket: Arc::new(PayloadU16::new(ticket)),
2284        }
2285    }
2286}
2287
2288impl Codec<'_> for NewSessionTicketPayload {
2289    fn encode(&self, bytes: &mut Vec<u8>) {
2290        self.lifetime_hint.encode(bytes);
2291        self.ticket.encode(bytes);
2292    }
2293
2294    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2295        let lifetime = u32::read(r)?;
2296        let ticket = Arc::new(PayloadU16::read(r)?);
2297
2298        Ok(Self {
2299            lifetime_hint: lifetime,
2300            ticket,
2301        })
2302    }
2303}
2304
// -- NewSessionTicket (TLS 1.3) --
// Extensions of a TLS 1.3 `NewSessionTicket` (see `NewSessionTicketPayloadTls13`).
// Only `early_data` is modelled; its value is the advertised maximum early data size.
extension_struct! {
    pub(crate) struct NewSessionTicketExtensions {
        ExtensionType::EarlyData =>
            pub(crate) max_early_data_size: Option<u32>,
    }
}
2312
2313impl Codec<'_> for NewSessionTicketExtensions {
2314    fn encode(&self, bytes: &mut Vec<u8>) {
2315        let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2316
2317        for ext in Self::ALL_EXTENSIONS {
2318            self.encode_one(*ext, extensions.buf);
2319        }
2320    }
2321
2322    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2323        let mut out = Self::default();
2324
2325        let mut checker = DuplicateExtensionChecker::new();
2326
2327        let len = usize::from(u16::read(r)?);
2328        let mut sub = r.sub(len)?;
2329
2330        while sub.any_left() {
2331            out.read_one(&mut sub, |unknown| checker.check(unknown))?;
2332        }
2333
2334        Ok(out)
2335    }
2336}
2337
2338#[derive(Debug)]
2339pub(crate) struct NewSessionTicketPayloadTls13 {
2340    pub(crate) lifetime: u32,
2341    pub(crate) age_add: u32,
2342    pub(crate) nonce: PayloadU8,
2343    pub(crate) ticket: Arc<PayloadU16>,
2344    pub(crate) extensions: NewSessionTicketExtensions,
2345}
2346
2347impl NewSessionTicketPayloadTls13 {
2348    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: [u8; 32], ticket: Vec<u8>) -> Self {
2349        Self {
2350            lifetime,
2351            age_add,
2352            nonce: PayloadU8::new(nonce.to_vec()),
2353            ticket: Arc::new(PayloadU16::new(ticket)),
2354            extensions: NewSessionTicketExtensions::default(),
2355        }
2356    }
2357}
2358
2359impl Codec<'_> for NewSessionTicketPayloadTls13 {
2360    fn encode(&self, bytes: &mut Vec<u8>) {
2361        self.lifetime.encode(bytes);
2362        self.age_add.encode(bytes);
2363        self.nonce.encode(bytes);
2364        self.ticket.encode(bytes);
2365        self.extensions.encode(bytes);
2366    }
2367
2368    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2369        let lifetime = u32::read(r)?;
2370        let age_add = u32::read(r)?;
2371        let nonce = PayloadU8::read(r)?;
2372        // nb. RFC8446: `opaque ticket<1..2^16-1>;`
2373        let ticket = Arc::new(match PayloadU16::<NonEmpty>::read(r) {
2374            Err(InvalidMessage::IllegalEmptyValue) => Err(InvalidMessage::EmptyTicketValue),
2375            Err(err) => Err(err),
2376            Ok(pl) => Ok(PayloadU16::new(pl.0)),
2377        }?);
2378        let extensions = NewSessionTicketExtensions::read(r)?;
2379
2380        Ok(Self {
2381            lifetime,
2382            age_add,
2383            nonce,
2384            ticket,
2385            extensions,
2386        })
2387    }
2388}
2389
2390// -- RFC6066 certificate status types
2391
2392/// Only supports OCSP
2393#[derive(Clone, Debug)]
2394pub(crate) struct CertificateStatus<'a> {
2395    pub(crate) ocsp_response: PayloadU24<'a>,
2396}
2397
2398impl<'a> Codec<'a> for CertificateStatus<'a> {
2399    fn encode(&self, bytes: &mut Vec<u8>) {
2400        CertificateStatusType::OCSP.encode(bytes);
2401        self.ocsp_response.encode(bytes);
2402    }
2403
2404    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2405        let typ = CertificateStatusType::read(r)?;
2406
2407        match typ {
2408            CertificateStatusType::OCSP => Ok(Self {
2409                ocsp_response: PayloadU24::read(r)?,
2410            }),
2411            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2412        }
2413    }
2414}
2415
2416impl<'a> CertificateStatus<'a> {
2417    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2418        CertificateStatus {
2419            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2420        }
2421    }
2422
2423    pub(crate) fn into_inner(self) -> Vec<u8> {
2424        self.ocsp_response.0.into_vec()
2425    }
2426
2427    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2428        CertificateStatus {
2429            ocsp_response: self.ocsp_response.into_owned(),
2430        }
2431    }
2432}
2433
2434// -- RFC8879 compressed certificates
2435
2436#[derive(Debug)]
2437pub(crate) struct CompressedCertificatePayload<'a> {
2438    pub(crate) alg: CertificateCompressionAlgorithm,
2439    pub(crate) uncompressed_len: u32,
2440    pub(crate) compressed: PayloadU24<'a>,
2441}
2442
2443impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2444    fn encode(&self, bytes: &mut Vec<u8>) {
2445        self.alg.encode(bytes);
2446        codec::u24(self.uncompressed_len).encode(bytes);
2447        self.compressed.encode(bytes);
2448    }
2449
2450    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2451        Ok(Self {
2452            alg: CertificateCompressionAlgorithm::read(r)?,
2453            uncompressed_len: codec::u24::read(r)?.0,
2454            compressed: PayloadU24::read(r)?,
2455        })
2456    }
2457}
2458
2459impl CompressedCertificatePayload<'_> {
2460    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2461        CompressedCertificatePayload {
2462            compressed: self.compressed.into_owned(),
2463            ..self
2464        }
2465    }
2466
2467    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2468        CompressedCertificatePayload {
2469            alg: self.alg,
2470            uncompressed_len: self.uncompressed_len,
2471            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2472        }
2473    }
2474}
2475
/// The decoded body of a handshake message, one variant per message kind.
///
/// Message types whose layout differs between TLS 1.2 and TLS 1.3 have
/// separate variants (e.g. `Certificate` / `CertificateTls13`) that report the
/// same `HandshakeType`. A `HelloRetryRequest` is sent on the wire as a
/// `ServerHello`.
#[derive(Debug)]
pub(crate) enum HandshakePayload<'a> {
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    HelloRetryRequest(HelloRetryRequest),
    /// TLS 1.2 `Certificate` body.
    Certificate(CertificateChain<'a>),
    /// TLS 1.3 `Certificate` body.
    CertificateTls13(CertificatePayloadTls13<'a>),
    CompressedCertificate(CompressedCertificatePayload<'a>),
    ServerKeyExchange(ServerKeyExchangePayload),
    /// TLS 1.2 `CertificateRequest` body.
    CertificateRequest(CertificateRequestPayload),
    /// TLS 1.3 `CertificateRequest` body.
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    ServerHelloDone,
    EndOfEarlyData,
    /// Kept undecoded: its layout depends on the key exchange algorithm.
    ClientKeyExchange(Payload<'a>),
    /// TLS 1.2 `NewSessionTicket` body.
    NewSessionTicket(NewSessionTicketPayload),
    /// TLS 1.3 `NewSessionTicket` body.
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Box<ServerExtensions<'a>>),
    KeyUpdate(KeyUpdateRequest),
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    /// Synthetic message carrying a hash; never accepted from the wire.
    MessageHash(Payload<'a>),
    /// Raw body of a message not decoded into a dedicated variant, with its type.
    Unknown((HandshakeType, Payload<'a>)),
}
2501
2502impl HandshakePayload<'_> {
2503    fn encode(&self, bytes: &mut Vec<u8>) {
2504        use self::HandshakePayload::*;
2505        match self {
2506            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2507            ClientHello(x) => x.encode(bytes),
2508            ServerHello(x) => x.encode(bytes),
2509            HelloRetryRequest(x) => x.encode(bytes),
2510            Certificate(x) => x.encode(bytes),
2511            CertificateTls13(x) => x.encode(bytes),
2512            CompressedCertificate(x) => x.encode(bytes),
2513            ServerKeyExchange(x) => x.encode(bytes),
2514            ClientKeyExchange(x) => x.encode(bytes),
2515            CertificateRequest(x) => x.encode(bytes),
2516            CertificateRequestTls13(x) => x.encode(bytes),
2517            CertificateVerify(x) => x.encode(bytes),
2518            NewSessionTicket(x) => x.encode(bytes),
2519            NewSessionTicketTls13(x) => x.encode(bytes),
2520            EncryptedExtensions(x) => x.encode(bytes),
2521            KeyUpdate(x) => x.encode(bytes),
2522            Finished(x) => x.encode(bytes),
2523            CertificateStatus(x) => x.encode(bytes),
2524            MessageHash(x) => x.encode(bytes),
2525            Unknown((_, x)) => x.encode(bytes),
2526        }
2527    }
2528
2529    pub(crate) fn handshake_type(&self) -> HandshakeType {
2530        use self::HandshakePayload::*;
2531        match self {
2532            HelloRequest => HandshakeType::HelloRequest,
2533            ClientHello(_) => HandshakeType::ClientHello,
2534            ServerHello(_) => HandshakeType::ServerHello,
2535            HelloRetryRequest(_) => HandshakeType::HelloRetryRequest,
2536            Certificate(_) | CertificateTls13(_) => HandshakeType::Certificate,
2537            CompressedCertificate(_) => HandshakeType::CompressedCertificate,
2538            ServerKeyExchange(_) => HandshakeType::ServerKeyExchange,
2539            CertificateRequest(_) | CertificateRequestTls13(_) => HandshakeType::CertificateRequest,
2540            CertificateVerify(_) => HandshakeType::CertificateVerify,
2541            ServerHelloDone => HandshakeType::ServerHelloDone,
2542            EndOfEarlyData => HandshakeType::EndOfEarlyData,
2543            ClientKeyExchange(_) => HandshakeType::ClientKeyExchange,
2544            NewSessionTicket(_) | NewSessionTicketTls13(_) => HandshakeType::NewSessionTicket,
2545            EncryptedExtensions(_) => HandshakeType::EncryptedExtensions,
2546            KeyUpdate(_) => HandshakeType::KeyUpdate,
2547            Finished(_) => HandshakeType::Finished,
2548            CertificateStatus(_) => HandshakeType::CertificateStatus,
2549            MessageHash(_) => HandshakeType::MessageHash,
2550            Unknown((t, _)) => *t,
2551        }
2552    }
2553
2554    fn wire_handshake_type(&self) -> HandshakeType {
2555        match self.handshake_type() {
2556            // A `HelloRetryRequest` appears on the wire as a `ServerHello` with a magic `random` value.
2557            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2558            other => other,
2559        }
2560    }
2561
2562    fn into_owned(self) -> HandshakePayload<'static> {
2563        use HandshakePayload::*;
2564
2565        match self {
2566            HelloRequest => HelloRequest,
2567            ClientHello(x) => ClientHello(x),
2568            ServerHello(x) => ServerHello(x),
2569            HelloRetryRequest(x) => HelloRetryRequest(x),
2570            Certificate(x) => Certificate(x.into_owned()),
2571            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2572            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2573            ServerKeyExchange(x) => ServerKeyExchange(x),
2574            CertificateRequest(x) => CertificateRequest(x),
2575            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2576            CertificateVerify(x) => CertificateVerify(x),
2577            ServerHelloDone => ServerHelloDone,
2578            EndOfEarlyData => EndOfEarlyData,
2579            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2580            NewSessionTicket(x) => NewSessionTicket(x),
2581            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2582            EncryptedExtensions(x) => EncryptedExtensions(Box::new(x.into_owned())),
2583            KeyUpdate(x) => KeyUpdate(x),
2584            Finished(x) => Finished(x.into_owned()),
2585            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2586            MessageHash(x) => MessageHash(x.into_owned()),
2587            Unknown((t, x)) => Unknown((t, x.into_owned())),
2588        }
2589    }
2590}
2591
/// A complete handshake message. On the wire it is a handshake type and a
/// 24-bit length header, followed by the body held in the inner `HandshakePayload`.
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a>(pub(crate) HandshakePayload<'a>);

impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
    // Uses `Encoding::Standard`; see `payload_encode` for other purposes.
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.payload_encode(bytes, Encoding::Standard);
    }

    // Decodes with the TLS 1.2 body layouts; use `read_version` to select TLS 1.3.
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
        Self::read_version(r, ProtocolVersion::TLSv1_2)
    }
}
2604
impl<'a> HandshakeMessagePayload<'a> {
    /// Decode one handshake message (handshake type, 24-bit length, body) from `r`.
    ///
    /// `vers` selects the body layout for message types whose format differs by
    /// version. When it is TLS 1.3, `Certificate`, `CertificateRequest` and
    /// `NewSessionTicket` are decoded with their TLS 1.3 layouts; otherwise the
    /// TLS 1.2 ones are used. A `ServerHello` whose `random` equals
    /// `HELLO_RETRY_REQUEST_RANDOM` is decoded as a `HelloRetryRequest`.
    ///
    /// # Errors
    ///
    /// Fails on a malformed header or body, and on a body that is not consumed
    /// exactly. It also fails for the `MessageHash` and `HelloRetryRequest` type
    /// codes, neither of which may appear on the wire.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        let typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // All further parsing is confined to the declared body length.
        let mut sub = r.sub(len)?;

        // NB: version-guarded arms must precede the unguarded arm for the same type.
        let payload = match typ {
            // A non-empty HelloRequest body falls through to `Unknown` below.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // `legacy_version` and `random` are read up front to detect the
                // HelloRetryRequest magic value, then copied into the decoded payload.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                // Kept raw until the key exchange algorithm is known
                // (see `ServerKeyExchangePayload::unwrap_given_kxa`).
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                // Stored undecoded; presumably parsed later via `ClientKeyExchangeParams`
                // (a `KxDecode` impl) once the algorithm is known.
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Box::new(ServerExtensions::read(&mut sub)?))
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            _ => HandshakePayload::Unknown((typ, Payload::read(&mut sub))),
        };

        // Reject trailing bytes inside the declared body.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self(payload))
    }

    /// The encoding of this message with the PSK binders removed from its end.
    ///
    /// Strips `total_binder_length()` bytes from the full encoding. This relies
    /// on the encoded binders being the final bytes of the message (RFC 8446
    /// requires `pre_shared_key` to be the last ClientHello extension).
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();
        let ret_len = ret.len() - self.total_binder_length();
        ret.truncate(ret_len);
        ret
    }

    /// Length in bytes of the PSK binders list, encoded on its own, for a
    /// `ClientHello` carrying a pre-shared key offer; `0` otherwise.
    pub(crate) fn total_binder_length(&self) -> usize {
        match &self.0 {
            HandshakePayload::ClientHello(ch) => match &ch.preshared_key_offer {
                Some(offer) => {
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        }
    }

    /// Append the full handshake message (wire type, 24-bit length, body) to `bytes`.
    ///
    /// A `HelloRetryRequest` is written with the `ServerHello` wire type.
    /// `encoding` only affects `ServerHello` and `HelloRetryRequest` bodies;
    /// every other payload is encoded the same way regardless.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
        // output type, length, and encoded payload
        self.0
            .wire_handshake_type()
            .encode(bytes);

        let nested = LengthPrefixedBuffer::new(
            ListLength::U24 {
                max: usize::MAX,
                error: InvalidMessage::MessageTooLarge,
            },
            bytes,
        );

        match &self.0 {
            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
            // differently based on the purpose of the encoding.
            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
            HandshakePayload::HelloRetryRequest(payload) => {
                payload.payload_encode(nested.buf, encoding)
            }

            // All other payload types are encoded the same regardless of purpose.
            _ => self.0.encode(nested.buf),
        }
    }

    /// Build a synthetic `MessageHash` handshake message whose body is `hash`.
    ///
    /// Such messages are never accepted from the wire (`read_version` rejects them).
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
        Self(HandshakePayload::MessageHash(Payload::new(hash.to_vec())))
    }

    /// Convert into a message that owns all of its data.
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
        HandshakeMessagePayload(self.0.into_owned())
    }
}
2761
2762#[allow(clippy::exhaustive_structs)]
2763#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2764pub struct HpkeSymmetricCipherSuite {
2765    pub kdf_id: HpkeKdf,
2766    pub aead_id: HpkeAead,
2767}
2768
2769impl Codec<'_> for HpkeSymmetricCipherSuite {
2770    fn encode(&self, bytes: &mut Vec<u8>) {
2771        self.kdf_id.encode(bytes);
2772        self.aead_id.encode(bytes);
2773    }
2774
2775    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2776        Ok(Self {
2777            kdf_id: HpkeKdf::read(r)?,
2778            aead_id: HpkeAead::read(r)?,
2779        })
2780    }
2781}
2782
2783/// draft-ietf-tls-esni-24: `HpkeSymmetricCipherSuite cipher_suites<4..2^16-4>;`
2784impl TlsListElement for HpkeSymmetricCipherSuite {
2785    const SIZE_LEN: ListLength = ListLength::NonZeroU16 {
2786        empty_error: InvalidMessage::IllegalEmptyList("HpkeSymmetricCipherSuites"),
2787    };
2788}
2789
2790#[allow(clippy::exhaustive_structs)]
2791#[derive(Clone, Debug, PartialEq)]
2792pub(crate) struct HpkeKeyConfig {
2793    pub config_id: u8,
2794    pub kem_id: HpkeKem,
2795    /// draft-ietf-tls-esni-24: `opaque HpkePublicKey<1..2^16-1>;`
2796    pub public_key: PayloadU16<NonEmpty>,
2797    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2798}
2799
2800impl Codec<'_> for HpkeKeyConfig {
2801    fn encode(&self, bytes: &mut Vec<u8>) {
2802        self.config_id.encode(bytes);
2803        self.kem_id.encode(bytes);
2804        self.public_key.encode(bytes);
2805        self.symmetric_cipher_suites
2806            .encode(bytes);
2807    }
2808
2809    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2810        Ok(Self {
2811            config_id: u8::read(r)?,
2812            kem_id: HpkeKem::read(r)?,
2813            public_key: PayloadU16::read(r)?,
2814            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2815        })
2816    }
2817}
2818
2819#[allow(clippy::exhaustive_structs)]
2820#[derive(Clone, Debug, PartialEq)]
2821pub(crate) struct EchConfigContents {
2822    pub key_config: HpkeKeyConfig,
2823    pub maximum_name_length: u8,
2824    pub public_name: DnsName<'static>,
2825    pub extensions: Vec<EchConfigExtension>,
2826}
2827
2828impl EchConfigContents {
2829    /// Returns true if there is more than one extension of a given
2830    /// type.
2831    pub(crate) fn has_duplicate_extension(&self) -> bool {
2832        has_duplicates::<_, _, u16>(
2833            self.extensions
2834                .iter()
2835                .map(|ext| ext.ext_type()),
2836        )
2837    }
2838
2839    /// Returns true if there is at least one mandatory unsupported extension.
2840    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2841        self.extensions
2842            .iter()
2843            // An extension is considered mandatory if the high bit of its type is set.
2844            .any(|ext| {
2845                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2846                    && u16::from(ext.ext_type()) & 0x8000 != 0
2847            })
2848    }
2849}
2850
2851impl Codec<'_> for EchConfigContents {
2852    fn encode(&self, bytes: &mut Vec<u8>) {
2853        self.key_config.encode(bytes);
2854        self.maximum_name_length.encode(bytes);
2855        let dns_name = &self.public_name.borrow();
2856        PayloadU8::<MaybeEmpty>::encode_slice(dns_name.as_ref().as_ref(), bytes);
2857        self.extensions.encode(bytes);
2858    }
2859
2860    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2861        Ok(Self {
2862            key_config: HpkeKeyConfig::read(r)?,
2863            maximum_name_length: u8::read(r)?,
2864            public_name: {
2865                DnsName::try_from(
2866                    PayloadU8::<MaybeEmpty>::read(r)?
2867                        .0
2868                        .as_slice(),
2869                )
2870                .map_err(|_| InvalidMessage::InvalidServerName)?
2871                .to_owned()
2872            },
2873            extensions: Vec::read(r)?,
2874        })
2875    }
2876}
2877
/// An encrypted client hello (ECH) config.
///
/// On the wire this is a version, a u16 length, and that many bytes of contents.
/// Only V18 contents are parsed; other versions are kept as opaque bytes.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum EchConfigPayload {
    /// A recognized V18 ECH configuration.
    V18(EchConfigContents),
    /// An unknown version ECH configuration.
    Unknown {
        /// The version code read from the wire.
        version: EchVersion,
        /// The raw, unparsed configuration body.
        contents: PayloadU16,
    },
}

/// A list of ECH configs is encoded with a u16 length prefix.
impl TlsListElement for EchConfigPayload {
    const SIZE_LEN: ListLength = ListLength::U16;
}
2894
2895impl Codec<'_> for EchConfigPayload {
2896    fn encode(&self, bytes: &mut Vec<u8>) {
2897        match self {
2898            Self::V18(c) => {
2899                // Write the version, the length, and the contents.
2900                EchVersion::V18.encode(bytes);
2901                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2902                c.encode(inner.buf);
2903            }
2904            Self::Unknown { version, contents } => {
2905                // Unknown configuration versions are opaque.
2906                version.encode(bytes);
2907                contents.encode(bytes);
2908            }
2909        }
2910    }
2911
2912    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2913        let version = EchVersion::read(r)?;
2914        let length = u16::read(r)?;
2915        let mut contents = r.sub(length as usize)?;
2916
2917        Ok(match version {
2918            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2919            _ => {
2920                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2921                let data = PayloadU16::new(contents.rest().into());
2922                Self::Unknown {
2923                    version,
2924                    contents: data,
2925                }
2926            }
2927        })
2928    }
2929}
2930
2931#[derive(Clone, Debug, PartialEq)]
2932pub(crate) enum EchConfigExtension {
2933    Unknown(UnknownExtension),
2934}
2935
2936impl EchConfigExtension {
2937    pub(crate) fn ext_type(&self) -> ExtensionType {
2938        match self {
2939            Self::Unknown(r) => r.typ,
2940        }
2941    }
2942}
2943
2944impl Codec<'_> for EchConfigExtension {
2945    fn encode(&self, bytes: &mut Vec<u8>) {
2946        self.ext_type().encode(bytes);
2947
2948        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2949        match self {
2950            Self::Unknown(r) => r.encode(nested.buf),
2951        }
2952    }
2953
2954    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2955        let typ = ExtensionType::read(r)?;
2956        let len = u16::read(r)? as usize;
2957        let mut sub = r.sub(len)?;
2958
2959        #[allow(clippy::match_single_binding)] // Future-proofing.
2960        let ext = match typ {
2961            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2962        };
2963
2964        sub.expect_empty("EchConfigExtension")
2965            .map(|_| ext)
2966    }
2967}
2968
2969impl TlsListElement for EchConfigExtension {
2970    const SIZE_LEN: ListLength = ListLength::U16;
2971}
2972
2973/// Representation of the `ECHClientHello` client extension specified in
2974/// [draft-ietf-tls-esni Section 5].
2975///
2976/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
2977#[derive(Clone, Debug)]
2978pub(crate) enum EncryptedClientHello {
2979    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
2980    Outer(EncryptedClientHelloOuter),
2981    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
2982    ///
2983    /// This variant has no payload.
2984    Inner,
2985}
2986
2987impl Codec<'_> for EncryptedClientHello {
2988    fn encode(&self, bytes: &mut Vec<u8>) {
2989        match self {
2990            Self::Outer(payload) => {
2991                EchClientHelloType::ClientHelloOuter.encode(bytes);
2992                payload.encode(bytes);
2993            }
2994            Self::Inner => {
2995                EchClientHelloType::ClientHelloInner.encode(bytes);
2996                // Empty payload.
2997            }
2998        }
2999    }
3000
3001    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3002        match EchClientHelloType::read(r)? {
3003            EchClientHelloType::ClientHelloOuter => {
3004                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3005            }
3006            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3007            _ => Err(InvalidMessage::InvalidContentType),
3008        }
3009    }
3010}
3011
3012/// Representation of the ECHClientHello extension with type outer specified in
3013/// [draft-ietf-tls-esni Section 5].
3014///
3015/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3016#[derive(Clone, Debug)]
3017pub(crate) struct EncryptedClientHelloOuter {
3018    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3019    /// ECHConfigContents.cipher_suites list.
3020    pub cipher_suite: HpkeSymmetricCipherSuite,
3021    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3022    pub config_id: u8,
3023    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3024    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3025    pub enc: PayloadU16,
3026    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3027    pub payload: PayloadU16<NonEmpty>,
3028}
3029
3030impl Codec<'_> for EncryptedClientHelloOuter {
3031    fn encode(&self, bytes: &mut Vec<u8>) {
3032        self.cipher_suite.encode(bytes);
3033        self.config_id.encode(bytes);
3034        self.enc.encode(bytes);
3035        self.payload.encode(bytes);
3036    }
3037
3038    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3039        Ok(Self {
3040            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3041            config_id: u8::read(r)?,
3042            enc: PayloadU16::read(r)?,
3043            payload: PayloadU16::read(r)?,
3044        })
3045    }
3046}
3047
3048/// Representation of the ECHEncryptedExtensions extension specified in
3049/// [draft-ietf-tls-esni Section 5].
3050///
3051/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3052#[derive(Clone, Debug)]
3053pub(crate) struct ServerEncryptedClientHello {
3054    pub(crate) retry_configs: Vec<EchConfigPayload>,
3055}
3056
3057impl Codec<'_> for ServerEncryptedClientHello {
3058    fn encode(&self, bytes: &mut Vec<u8>) {
3059        self.retry_configs.encode(bytes);
3060    }
3061
3062    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3063        Ok(Self {
3064            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3065        })
3066    }
3067}
3068
/// The method of encoding to use for a handshake message.
///
/// In some cases a handshake message may be encoded differently depending on the purpose
/// the encoded message is being used for. `HandshakeMessagePayload::payload_encode` takes
/// this value and forwards it to the `ServerHello` and `HelloRetryRequest` encoders; other
/// encoders outside this view may also consume it.
pub(crate) enum Encoding {
    /// Standard RFC 8446 encoding.
    Standard,
    /// Encoding for ECH confirmation for HRR.
    EchConfirmation,
    /// Encoding for ECH inner client hello.
    ///
    /// NOTE(review): `to_compress` presumably lists the extension types to be
    /// compressed out of the inner hello — confirm against the ClientHello encoder.
    EchInnerHello { to_compress: Vec<ExtensionType> },
}
3081
/// Returns `true` if two items of `iter` convert (via `Into<T>`) to equal values.
///
/// Iteration stops at the first repeat. Because items are compared after
/// conversion, distinct `E` values that map to the same `T` count as duplicates.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::<T>::new();
    // `insert` returns false for an already-present value; `all` stops there.
    !iter
        .into_iter()
        .all(|item| seen.insert(item.into()))
}
3093
3094struct DuplicateExtensionChecker(BTreeSet<u16>);
3095
3096impl DuplicateExtensionChecker {
3097    fn new() -> Self {
3098        Self(BTreeSet::new())
3099    }
3100
3101    fn check(&mut self, typ: ExtensionType) -> Result<(), InvalidMessage> {
3102        let u = u16::from(typ);
3103        match self.0.insert(u) {
3104            true => Ok(()),
3105            false => Err(InvalidMessage::DuplicateExtension(u)),
3106        }
3107    }
3108}
3109
/// Mixes the bits of a 32-bit integer.
///
/// The shift/add/xor rounds and constants match Robert Jenkins' 32-bit integer
/// hash. As the name says, this is a cheap, low-quality mix and is not suitable
/// where a strong or cryptographic hash is required. All arithmetic wraps.
fn low_quality_integer_hash(x: u32) -> u32 {
    let x = x.wrapping_add(0x7ed55d16).wrapping_add(x << 12);
    let x = (x ^ 0xc761c23c) ^ (x >> 19);
    let x = x.wrapping_add(0x165667b1).wrapping_add(x << 5);
    let x = x.wrapping_add(0xd3a2646c) ^ (x << 9);
    let x = x.wrapping_add(0xfd7046c5).wrapping_add(x << 3);
    (x ^ 0xb55a4f09) ^ (x >> 16)
}
3125
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unknown ECH config extension with type `code` and a one-byte payload.
    fn unknown_ext(code: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(code),
            payload: Payload::new(vec![0x42]),
        })
    }

    #[test]
    fn test_ech_config_dupe_exts() {
        // 0x42 has the high bit clear, so this extension is not mandatory.
        let ext = unknown_ext(0x42);
        let mut config = config_template();
        config
            .extensions
            .extend(vec![ext.clone(), ext]);

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    #[test]
    fn test_ech_config_mandatory_exts() {
        let mut config = config_template();
        // Note: high bit set, which marks the extension as mandatory.
        config
            .extensions
            .push(unknown_ext(0x42 | 0x8000));

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// A minimal config with a single cipher suite and no extensions.
    fn config_template() -> EchConfigContents {
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16::new(b"xxx".into()),
            symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
                kdf_id: HpkeKdf::HKDF_SHA256,
                aead_id: HpkeAead::AES_128_GCM,
            }],
        };

        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}