rustls/msgs/
persist.rs

1use alloc::vec::Vec;
2use core::cmp;
3
4use pki_types::{DnsName, UnixTime};
5use zeroize::Zeroizing;
6
7use crate::client::ClientCredentialResolver;
8use crate::crypto::Identity;
9use crate::enums::{CipherSuite, ProtocolVersion};
10use crate::error::InvalidMessage;
11use crate::msgs::base::{MaybeEmpty, PayloadU8, PayloadU16};
12use crate::msgs::codec::{Codec, Reader};
13use crate::msgs::handshake::{ProtocolName, SessionId};
14use crate::sync::{Arc, Weak};
15use crate::tls12::Tls12CipherSuite;
16use crate::tls13::Tls13CipherSuite;
17use crate::verify::ServerVerifier;
18
/// A value fetched from a session store, paired with the time at which it
/// was fetched.
///
/// Age and expiry calculations in this module (`obfuscated_ticket_age`,
/// `has_expired`) are measured against `retrieved_at`.
pub(crate) struct Retrieved<T> {
    /// The retrieved value.
    pub(crate) value: T,
    /// Wall-clock time at which `value` was retrieved.
    retrieved_at: UnixTime,
}
23
24impl<T> Retrieved<T> {
25    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
26        Self {
27            value,
28            retrieved_at,
29        }
30    }
31
32    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
33        Some(Retrieved {
34            value: f(&self.value)?,
35            retrieved_at: self.retrieved_at,
36        })
37    }
38}
39
40impl Retrieved<&Tls13ClientSessionValue> {
41    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
42        let age_secs = self
43            .retrieved_at
44            .as_secs()
45            .saturating_sub(self.value.common.epoch);
46        let age_millis = age_secs as u32 * 1000;
47        age_millis.wrapping_add(self.value.age_add)
48    }
49}
50
51impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
52    pub(crate) fn has_expired(&self) -> bool {
53        let common = &*self.value;
54        common.lifetime_secs != 0
55            && common
56                .epoch
57                .saturating_add(u64::from(common.lifetime_secs))
58                < self.retrieved_at.as_secs()
59    }
60}
61
62impl<T> core::ops::Deref for Retrieved<T> {
63    type Target = T;
64
65    fn deref(&self) -> &Self::Target {
66        &self.value
67    }
68}
69
/// Client-side state kept for resuming a TLS 1.3 session.
// NOTE(review): the derived `Debug` includes `secret`; confirm that
// `PayloadU8`'s `Debug` output is acceptable for secret material.
#[derive(Debug)]
pub struct Tls13ClientSessionValue {
    /// Cipher suite of the session this value was created from.
    suite: &'static Tls13CipherSuite,
    /// Secret bytes stored with the ticket (exposed via `secret()`).
    secret: Zeroizing<PayloadU8>,
    /// Offset added to the ticket age by `obfuscated_ticket_age`.
    age_add: u32,
    /// Maximum early-data size recorded for this ticket.
    max_early_data_size: u32,
    /// Fields shared with TLS 1.2 session values (ticket, epoch, lifetime, ...).
    pub(crate) common: ClientSessionCommon,
    /// QUIC transport parameters; empty until `set_quic_params` is called.
    quic_params: PayloadU16,
}
79
80impl Tls13ClientSessionValue {
81    pub(crate) fn new(
82        suite: &'static Tls13CipherSuite,
83        ticket: Arc<PayloadU16>,
84        secret: &[u8],
85        peer_identity: Identity<'static>,
86        server_cert_verifier: &Arc<dyn ServerVerifier>,
87        client_creds: &Arc<dyn ClientCredentialResolver>,
88        time_now: UnixTime,
89        lifetime_secs: u32,
90        age_add: u32,
91        max_early_data_size: u32,
92    ) -> Self {
93        Self {
94            suite,
95            secret: Zeroizing::new(PayloadU8::new(secret.to_vec())),
96            age_add,
97            max_early_data_size,
98            common: ClientSessionCommon::new(
99                ticket,
100                time_now,
101                lifetime_secs,
102                peer_identity,
103                server_cert_verifier,
104                client_creds,
105            ),
106            quic_params: PayloadU16::new(Vec::new()),
107        }
108    }
109
110    pub(crate) fn secret(&self) -> &[u8] {
111        self.secret.0.as_ref()
112    }
113
114    pub fn max_early_data_size(&self) -> u32 {
115        self.max_early_data_size
116    }
117
118    pub fn suite(&self) -> &'static Tls13CipherSuite {
119        self.suite
120    }
121
122    /// Test only: rewind epoch by `delta` seconds.
123    #[doc(hidden)]
124    pub fn rewind_epoch(&mut self, delta: u32) {
125        self.common.epoch -= delta as u64;
126    }
127
128    /// Test only: replace `max_early_data_size` with `new`
129    #[doc(hidden)]
130    pub fn _private_set_max_early_data_size(&mut self, new: u32) {
131        self.max_early_data_size = new;
132    }
133
134    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
135        self.quic_params = PayloadU16::new(quic_params.to_vec());
136    }
137
138    pub fn quic_params(&self) -> Vec<u8> {
139        self.quic_params.0.clone()
140    }
141}
142
143impl core::ops::Deref for Tls13ClientSessionValue {
144    type Target = ClientSessionCommon;
145
146    fn deref(&self) -> &Self::Target {
147        &self.common
148    }
149}
150
/// Client-side state kept for resuming a TLS 1.2 session.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    /// Cipher suite of the session this value was created from.
    suite: &'static Tls12CipherSuite,
    /// Session ID of the original session.
    pub(crate) session_id: SessionId,
    /// The 48-byte master secret, wrapped in `Zeroizing`.
    master_secret: Zeroizing<[u8; 48]>,
    /// Extended-master-secret flag recorded for the session.
    extended_ms: bool,
    /// Fields shared with TLS 1.3 session values.
    // NOTE(review): `#[doc(hidden)]` has no effect on a `pub(crate)` field.
    #[doc(hidden)]
    pub(crate) common: ClientSessionCommon,
}
160
161impl Tls12ClientSessionValue {
162    pub(crate) fn new(
163        suite: &'static Tls12CipherSuite,
164        session_id: SessionId,
165        ticket: Arc<PayloadU16>,
166        master_secret: &[u8; 48],
167        peer_identity: Identity<'static>,
168        server_cert_verifier: &Arc<dyn ServerVerifier>,
169        client_creds: &Arc<dyn ClientCredentialResolver>,
170        time_now: UnixTime,
171        lifetime_secs: u32,
172        extended_ms: bool,
173    ) -> Self {
174        Self {
175            suite,
176            session_id,
177            master_secret: Zeroizing::new(*master_secret),
178            extended_ms,
179            common: ClientSessionCommon::new(
180                ticket,
181                time_now,
182                lifetime_secs,
183                peer_identity,
184                server_cert_verifier,
185                client_creds,
186            ),
187        }
188    }
189
190    pub(crate) fn master_secret(&self) -> &[u8; 48] {
191        &self.master_secret
192    }
193
194    pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
195        self.common.ticket.clone()
196    }
197
198    pub(crate) fn extended_ms(&self) -> bool {
199        self.extended_ms
200    }
201
202    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
203        self.suite
204    }
205
206    /// Test only: rewind epoch by `delta` seconds.
207    #[doc(hidden)]
208    pub fn rewind_epoch(&mut self, delta: u32) {
209        self.common.epoch -= delta as u64;
210    }
211}
212
213impl core::ops::Deref for Tls12ClientSessionValue {
214    type Target = ClientSessionCommon;
215
216    fn deref(&self) -> &Self::Target {
217        &self.common
218    }
219}
220
/// Client session fields shared by the TLS 1.2 and TLS 1.3 session values.
#[derive(Debug, Clone)]
pub struct ClientSessionCommon {
    /// Session ticket bytes, shared behind an `Arc`.
    ticket: Arc<PayloadU16>,
    /// Creation time of this value, in seconds since the Unix epoch.
    epoch: u64,
    /// Ticket lifetime in seconds, clamped to `MAX_TICKET_LIFETIME`; zero
    /// disables the `has_expired` check.
    lifetime_secs: u32,
    /// Peer (server) identity recorded for the session.
    peer_identity: Arc<Identity<'static>>,
    /// Verifier in use when the session was created. It is held weakly, so
    /// this value does not keep the verifier alive; it is only compared by
    /// pointer in `compatible_config`.
    server_cert_verifier: Weak<dyn ServerVerifier>,
    /// Credential resolver in use when the session was created (held weakly,
    /// compared by pointer in `compatible_config`).
    client_creds: Weak<dyn ClientCredentialResolver>,
}
230
231impl ClientSessionCommon {
232    fn new(
233        ticket: Arc<PayloadU16>,
234        time_now: UnixTime,
235        lifetime_secs: u32,
236        peer_identity: Identity<'static>,
237        server_cert_verifier: &Arc<dyn ServerVerifier>,
238        client_creds: &Arc<dyn ClientCredentialResolver>,
239    ) -> Self {
240        Self {
241            ticket,
242            epoch: time_now.as_secs(),
243            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
244            peer_identity: Arc::new(peer_identity),
245            server_cert_verifier: Arc::downgrade(server_cert_verifier),
246            client_creds: Arc::downgrade(client_creds),
247        }
248    }
249
250    pub(crate) fn compatible_config(
251        &self,
252        server_cert_verifier: &Arc<dyn ServerVerifier>,
253        client_creds: &Arc<dyn ClientCredentialResolver>,
254    ) -> bool {
255        let same_verifier = Weak::ptr_eq(
256            &Arc::downgrade(server_cert_verifier),
257            &self.server_cert_verifier,
258        );
259        let same_creds = Weak::ptr_eq(&Arc::downgrade(client_creds), &self.client_creds);
260
261        match (same_verifier, same_creds) {
262            (true, true) => true,
263            (false, _) => {
264                crate::log::trace!("resumption not allowed between different ServerVerifiers");
265                false
266            }
267            (_, _) => {
268                crate::log::trace!(
269                    "resumption not allowed between different ClientCredentialResolver values"
270                );
271                false
272            }
273        }
274    }
275
276    pub(crate) fn peer_identity(&self) -> &Identity<'static> {
277        &self.peer_identity
278    }
279
280    pub(crate) fn ticket(&self) -> &[u8] {
281        self.ticket.0.as_ref()
282    }
283}
284
/// Upper bound, in seconds, on a client session's ticket lifetime (seven
/// days). Larger `lifetime_secs` values are clamped to this when a client
/// session value is created.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;

/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period.  This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
292
293// --- Server types ---
/// Server-side session state, tagged by the protocol version it belongs to.
///
/// Its `Codec` encoding writes the protocol version first, followed by the
/// version-specific value.
#[non_exhaustive]
#[derive(Debug)]
pub enum ServerSessionValue {
    /// State for a TLS 1.2 session.
    Tls12(Tls12ServerSessionValue),
    /// State for a TLS 1.3 session.
    Tls13(Tls13ServerSessionValue),
}
300
301impl Codec<'_> for ServerSessionValue {
302    fn encode(&self, bytes: &mut Vec<u8>) {
303        match self {
304            Self::Tls12(value) => {
305                ProtocolVersion::TLSv1_2.encode(bytes);
306                value.encode(bytes);
307            }
308            Self::Tls13(value) => {
309                ProtocolVersion::TLSv1_3.encode(bytes);
310                value.encode(bytes);
311            }
312        }
313    }
314
315    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
316        match ProtocolVersion::read(r)? {
317            ProtocolVersion::TLSv1_2 => Ok(Self::Tls12(Tls12ServerSessionValue::read(r)?)),
318            ProtocolVersion::TLSv1_3 => Ok(Self::Tls13(Tls13ServerSessionValue::read(r)?)),
319            _ => Err(InvalidMessage::UnknownProtocolVersion),
320        }
321    }
322}
323
/// Server-side state for a resumable TLS 1.2 session.
// NOTE(review): the derived `Debug` prints `master_secret` (`Zeroizing`
// forwards `Debug` to its contents); confirm this is acceptable.
#[derive(Debug)]
pub struct Tls12ServerSessionValue {
    /// Fields shared with the TLS 1.3 server session value.
    #[doc(hidden)]
    pub common: CommonServerSessionValue,
    /// The 48-byte master secret, wrapped in `Zeroizing`.
    pub(crate) master_secret: Zeroizing<[u8; 48]>,
    /// Extended-master-secret flag; encoded as a single byte (1 = true).
    pub(crate) extended_ms: bool,
}
331
332impl Tls12ServerSessionValue {
333    pub(crate) fn new(
334        common: CommonServerSessionValue,
335        master_secret: &[u8; 48],
336        extended_ms: bool,
337    ) -> Self {
338        Self {
339            common,
340            master_secret: Zeroizing::new(*master_secret),
341            extended_ms,
342        }
343    }
344}
345
346impl Codec<'_> for Tls12ServerSessionValue {
347    fn encode(&self, bytes: &mut Vec<u8>) {
348        self.common.encode(bytes);
349        bytes.extend_from_slice(self.master_secret.as_ref());
350        (self.extended_ms as u8).encode(bytes);
351    }
352
353    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
354        Ok(Self {
355            common: CommonServerSessionValue::read(r)?,
356            master_secret: Zeroizing::new(
357                match r
358                    .take(48)
359                    .and_then(|slice| slice.try_into().ok())
360                {
361                    Some(array) => array,
362                    None => return Err(InvalidMessage::MessageTooShort),
363                },
364            ),
365            extended_ms: matches!(u8::read(r)?, 1),
366        })
367    }
368}
369
370impl From<Tls12ServerSessionValue> for ServerSessionValue {
371    fn from(value: Tls12ServerSessionValue) -> Self {
372        Self::Tls12(value)
373    }
374}
375
/// Server-side state for a resumable TLS 1.3 session.
#[derive(Debug)]
pub struct Tls13ServerSessionValue {
    /// Fields shared with the TLS 1.2 server session value.
    #[doc(hidden)]
    pub common: CommonServerSessionValue,
    /// Secret bytes stored with the ticket, wrapped in `Zeroizing`.
    pub(crate) secret: Zeroizing<PayloadU8>,
    /// Offset subtracted from the client's obfuscated ticket age in
    /// `set_freshness`.
    pub(crate) age_obfuscation_offset: u32,

    // Fields below this point are not encoded.
    /// Outcome of `set_freshness`; `None` until it is called (and always
    /// `None` after decoding).
    freshness: Option<bool>,
}
386
387impl Tls13ServerSessionValue {
388    pub(crate) fn new(
389        common: CommonServerSessionValue,
390        secret: &[u8],
391        age_obfuscation_offset: u32,
392    ) -> Self {
393        Self {
394            common,
395            secret: Zeroizing::new(PayloadU8::new(secret.to_vec())),
396            age_obfuscation_offset,
397            freshness: None,
398        }
399    }
400
401    pub(crate) fn set_freshness(
402        mut self,
403        obfuscated_client_age_ms: u32,
404        time_now: UnixTime,
405    ) -> Self {
406        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
407        let server_age_ms = (time_now
408            .as_secs()
409            .saturating_sub(self.common.creation_time_sec) as u32)
410            .saturating_mul(1000);
411
412        let age_difference = server_age_ms.abs_diff(client_age_ms);
413
414        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
415        self
416    }
417
418    pub(crate) fn is_fresh(&self) -> bool {
419        self.freshness.unwrap_or_default()
420    }
421}
422
423impl Codec<'_> for Tls13ServerSessionValue {
424    fn encode(&self, bytes: &mut Vec<u8>) {
425        self.common.encode(bytes);
426        self.secret.encode(bytes);
427        self.age_obfuscation_offset
428            .encode(bytes);
429    }
430
431    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
432        Ok(Self {
433            common: CommonServerSessionValue::read(r)?,
434            secret: Zeroizing::new(PayloadU8::read(r)?),
435            age_obfuscation_offset: u32::read(r)?,
436            freshness: None,
437        })
438    }
439}
440
441impl From<Tls13ServerSessionValue> for ServerSessionValue {
442    fn from(value: Tls13ServerSessionValue) -> Self {
443        Self::Tls13(value)
444    }
445}
446
/// Server-side session fields shared by the TLS 1.2 and TLS 1.3 values.
#[derive(Debug)]
pub struct CommonServerSessionValue {
    /// Server name recorded for the session, if any; `can_resume` requires
    /// an exact match.
    pub(crate) sni: Option<DnsName<'static>>,
    /// Cipher suite of the session; `can_resume` requires an exact match.
    pub(crate) cipher_suite: CipherSuite,
    /// Peer (client) identity recorded for the session, if any.
    pub(crate) peer_identity: Option<Identity<'static>>,
    /// ALPN protocol associated with the session, if any.
    pub(crate) alpn: Option<ProtocolName>,
    /// Opaque application data stored alongside the session.
    pub(crate) application_data: PayloadU16,
    /// Creation time, in seconds since the Unix epoch.
    #[doc(hidden)]
    pub creation_time_sec: u64,
}
457
458impl CommonServerSessionValue {
459    pub(crate) fn new(
460        sni: Option<&DnsName<'_>>,
461        cipher_suite: CipherSuite,
462        peer_identity: Option<Identity<'static>>,
463        alpn: Option<ProtocolName>,
464        application_data: Vec<u8>,
465        creation_time: UnixTime,
466    ) -> Self {
467        Self {
468            sni: sni.map(|s| s.to_owned()),
469            cipher_suite,
470            peer_identity,
471            alpn,
472            application_data: PayloadU16::new(application_data),
473            creation_time_sec: creation_time.as_secs(),
474        }
475    }
476
477    pub(crate) fn can_resume(&self, suite: CipherSuite, sni: &Option<DnsName<'_>>) -> bool {
478        // The RFCs underspecify what happens if we try to resume to
479        // an unoffered/varying suite.  We merely don't resume in weird cases.
480        //
481        // RFC 6066 says "A server that implements this extension MUST NOT accept
482        // the request to resume the session if the server_name extension contains
483        // a different name. Instead, it proceeds with a full handshake to
484        // establish a new session."
485        //
486        // RFC 8446: "The server MUST ensure that it selects
487        // a compatible PSK (if any) and cipher suite."
488        self.cipher_suite == suite && &self.sni == sni
489    }
490}
491
492impl Codec<'_> for CommonServerSessionValue {
493    fn encode(&self, bytes: &mut Vec<u8>) {
494        if let Some(sni) = &self.sni {
495            1u8.encode(bytes);
496            let sni_bytes: &str = sni.as_ref();
497            PayloadU8::<MaybeEmpty>::encode_slice(sni_bytes.as_bytes(), bytes);
498        } else {
499            0u8.encode(bytes);
500        }
501        self.cipher_suite.encode(bytes);
502        if let Some(identity) = &self.peer_identity {
503            1u8.encode(bytes);
504            identity.encode(bytes);
505        } else {
506            0u8.encode(bytes);
507        }
508        if let Some(alpn) = &self.alpn {
509            1u8.encode(bytes);
510            alpn.encode(bytes);
511        } else {
512            0u8.encode(bytes);
513        }
514        self.application_data.encode(bytes);
515        self.creation_time_sec.encode(bytes);
516    }
517
518    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
519        let sni = match u8::read(r)? {
520            1 => {
521                let dns_name = PayloadU8::<MaybeEmpty>::read(r)?;
522                let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
523                    Ok(dns_name) => dns_name.to_owned(),
524                    Err(_) => return Err(InvalidMessage::InvalidServerName),
525                };
526
527                Some(dns_name)
528            }
529            _ => None,
530        };
531
532        Ok(Self {
533            sni,
534            cipher_suite: CipherSuite::read(r)?,
535            peer_identity: match u8::read(r)? {
536                1 => Some(Identity::read(r)?.into_owned()),
537                _ => None,
538            },
539            alpn: match u8::read(r)? {
540                1 => Some(ProtocolName::read(r)?),
541                _ => None,
542            },
543            application_data: PayloadU16::read(r)?,
544            creation_time_sec: u64::read(r)?,
545        })
546    }
547}
548
#[cfg(test)]
mod tests {
    use core::time::Duration;

    use pki_types::CertificateDer;

    use super::*;
    use crate::crypto::CertificateIdentity;

    #[cfg(feature = "std")] // for UnixTime::now
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::Tls13(Tls13ServerSessionValue::new(
            CommonServerSessionValue::new(
                None,
                CipherSuite::TLS13_AES_128_GCM_SHA256,
                None,
                None,
                vec![4, 5, 6],
                UnixTime::now(),
            ),
            &[1, 2, 3],
            0x12345678,
        ));
        println!("{ssv:?}");
        println!("{:#04x?}", ssv.get_encoding());
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x03, 0x04, 0x00, 0x13, 0x01, 0x00, 0x00, 0x00, 0x03, 0x04, 0x05, 0x06, 0x00, 0x00,
            0x00, 0x00, 0x68, 0x6e, 0x94, 0x32, 0x03, 0x01, 0x02, 0x03, 0x12, 0x34, 0x56, 0x78,
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        let bytes = [
            0x03, 0x04, 0x00, 0x13, 0x01, 0x01, 0x00, 0x00, 0x00, 0x03, 0x0a, 0x0b, 0x0c, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x03, 0x04, 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x68, 0xc1,
            0x99, 0xac, 0x03, 0x01, 0x02, 0x03, 0x12, 0x34, 0x56, 0x78,
        ];

        // Build the value at a fixed creation time so its encoding is
        // deterministic. This avoids depending on `UnixTime::now` (and thus
        // the `std` feature).
        let built = ServerSessionValue::Tls13(Tls13ServerSessionValue::new(
            CommonServerSessionValue::new(
                None,
                CipherSuite::TLS13_AES_128_GCM_SHA256,
                Some(Identity::X509(CertificateIdentity {
                    end_entity: CertificateDer::from(&[10, 11, 12][..]),
                    intermediates: alloc::vec![],
                })),
                None,
                alloc::vec![4, 5, 6],
                UnixTime::since_unix_epoch(Duration::from_secs(0x68c1_99ac)),
            ),
            &[1, 2, 3],
            0x12345678,
        ));
        assert_eq!(built.get_encoding(), bytes);

        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn tls12_serversessionvalue_round_trip() {
        let original = ServerSessionValue::from(Tls12ServerSessionValue::new(
            CommonServerSessionValue::new(
                None,
                CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
                None,
                None,
                alloc::vec![7, 8],
                UnixTime::since_unix_epoch(Duration::from_secs(1_700_000_000)),
            ),
            &[0x42; 48],
            true,
        ));
        let encoded = original.get_encoding();

        let mut rd = Reader::init(&encoded);
        let decoded = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(decoded.get_encoding(), encoded);

        match &decoded {
            ServerSessionValue::Tls12(tls12) => {
                assert_eq!(*tls12.master_secret, [0x42; 48]);
                assert!(tls12.extended_ms);
                assert_eq!(tls12.common.creation_time_sec, 1_700_000_000);
            }
            ServerSessionValue::Tls13(_) => panic!("expected a TLS 1.2 session value"),
        }
    }
}