rustls/msgs/persist.rs

1use alloc::vec::Vec;
2use core::cmp;
3
4use pki_types::{DnsName, UnixTime};
5use zeroize::Zeroizing;
6
7use crate::client::ResolvesClientCert;
8use crate::enums::{CipherSuite, ProtocolVersion};
9use crate::error::InvalidMessage;
10use crate::msgs::base::{MaybeEmpty, PayloadU8, PayloadU16};
11use crate::msgs::codec::{Codec, Reader};
12use crate::msgs::handshake::{CertificateChain, ProtocolName, SessionId};
13use crate::sync::{Arc, Weak};
14use crate::tls12::Tls12CipherSuite;
15use crate::tls13::Tls13CipherSuite;
16use crate::verify::ServerCertVerifier;
17
18pub(crate) struct Retrieved<T> {
19    pub(crate) value: T,
20    retrieved_at: UnixTime,
21}
22
23impl<T> Retrieved<T> {
24    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
25        Self {
26            value,
27            retrieved_at,
28        }
29    }
30
31    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
32        Some(Retrieved {
33            value: f(&self.value)?,
34            retrieved_at: self.retrieved_at,
35        })
36    }
37}
38
39impl Retrieved<&Tls13ClientSessionValue> {
40    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
41        let age_secs = self
42            .retrieved_at
43            .as_secs()
44            .saturating_sub(self.value.common.epoch);
45        let age_millis = age_secs as u32 * 1000;
46        age_millis.wrapping_add(self.value.age_add)
47    }
48}
49
50impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
51    pub(crate) fn has_expired(&self) -> bool {
52        let common = &*self.value;
53        common.lifetime_secs != 0
54            && common
55                .epoch
56                .saturating_add(u64::from(common.lifetime_secs))
57                < self.retrieved_at.as_secs()
58    }
59}
60
61impl<T> core::ops::Deref for Retrieved<T> {
62    type Target = T;
63
64    fn deref(&self) -> &Self::Target {
65        &self.value
66    }
67}
68
69#[derive(Debug)]
70pub struct Tls13ClientSessionValue {
71    suite: &'static Tls13CipherSuite,
72    secret: Zeroizing<PayloadU8>,
73    age_add: u32,
74    max_early_data_size: u32,
75    pub(crate) common: ClientSessionCommon,
76    quic_params: PayloadU16,
77}
78
79impl Tls13ClientSessionValue {
80    pub(crate) fn new(
81        suite: &'static Tls13CipherSuite,
82        ticket: Arc<PayloadU16>,
83        secret: &[u8],
84        server_cert_chain: CertificateChain<'static>,
85        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
86        client_creds: &Arc<dyn ResolvesClientCert>,
87        time_now: UnixTime,
88        lifetime_secs: u32,
89        age_add: u32,
90        max_early_data_size: u32,
91    ) -> Self {
92        Self {
93            suite,
94            secret: Zeroizing::new(PayloadU8::new(secret.to_vec())),
95            age_add,
96            max_early_data_size,
97            common: ClientSessionCommon::new(
98                ticket,
99                time_now,
100                lifetime_secs,
101                server_cert_chain,
102                server_cert_verifier,
103                client_creds,
104            ),
105            quic_params: PayloadU16::new(Vec::new()),
106        }
107    }
108
109    pub(crate) fn secret(&self) -> &[u8] {
110        self.secret.0.as_ref()
111    }
112
113    pub fn max_early_data_size(&self) -> u32 {
114        self.max_early_data_size
115    }
116
117    pub fn suite(&self) -> &'static Tls13CipherSuite {
118        self.suite
119    }
120
121    /// Test only: rewind epoch by `delta` seconds.
122    #[doc(hidden)]
123    pub fn rewind_epoch(&mut self, delta: u32) {
124        self.common.epoch -= delta as u64;
125    }
126
127    /// Test only: replace `max_early_data_size` with `new`
128    #[doc(hidden)]
129    pub fn _private_set_max_early_data_size(&mut self, new: u32) {
130        self.max_early_data_size = new;
131    }
132
133    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
134        self.quic_params = PayloadU16::new(quic_params.to_vec());
135    }
136
137    pub fn quic_params(&self) -> Vec<u8> {
138        self.quic_params.0.clone()
139    }
140}
141
142impl core::ops::Deref for Tls13ClientSessionValue {
143    type Target = ClientSessionCommon;
144
145    fn deref(&self) -> &Self::Target {
146        &self.common
147    }
148}
149
150#[derive(Debug, Clone)]
151pub struct Tls12ClientSessionValue {
152    suite: &'static Tls12CipherSuite,
153    pub(crate) session_id: SessionId,
154    master_secret: Zeroizing<[u8; 48]>,
155    extended_ms: bool,
156    #[doc(hidden)]
157    pub(crate) common: ClientSessionCommon,
158}
159
160impl Tls12ClientSessionValue {
161    pub(crate) fn new(
162        suite: &'static Tls12CipherSuite,
163        session_id: SessionId,
164        ticket: Arc<PayloadU16>,
165        master_secret: &[u8; 48],
166        server_cert_chain: CertificateChain<'static>,
167        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
168        client_creds: &Arc<dyn ResolvesClientCert>,
169        time_now: UnixTime,
170        lifetime_secs: u32,
171        extended_ms: bool,
172    ) -> Self {
173        Self {
174            suite,
175            session_id,
176            master_secret: Zeroizing::new(*master_secret),
177            extended_ms,
178            common: ClientSessionCommon::new(
179                ticket,
180                time_now,
181                lifetime_secs,
182                server_cert_chain,
183                server_cert_verifier,
184                client_creds,
185            ),
186        }
187    }
188
189    pub(crate) fn master_secret(&self) -> &[u8; 48] {
190        &self.master_secret
191    }
192
193    pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
194        self.common.ticket.clone()
195    }
196
197    pub(crate) fn extended_ms(&self) -> bool {
198        self.extended_ms
199    }
200
201    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
202        self.suite
203    }
204
205    /// Test only: rewind epoch by `delta` seconds.
206    #[doc(hidden)]
207    pub fn rewind_epoch(&mut self, delta: u32) {
208        self.common.epoch -= delta as u64;
209    }
210}
211
212impl core::ops::Deref for Tls12ClientSessionValue {
213    type Target = ClientSessionCommon;
214
215    fn deref(&self) -> &Self::Target {
216        &self.common
217    }
218}
219
220#[derive(Debug, Clone)]
221pub struct ClientSessionCommon {
222    ticket: Arc<PayloadU16>,
223    epoch: u64,
224    lifetime_secs: u32,
225    server_cert_chain: Arc<CertificateChain<'static>>,
226    server_cert_verifier: Weak<dyn ServerCertVerifier>,
227    client_creds: Weak<dyn ResolvesClientCert>,
228}
229
230impl ClientSessionCommon {
231    fn new(
232        ticket: Arc<PayloadU16>,
233        time_now: UnixTime,
234        lifetime_secs: u32,
235        server_cert_chain: CertificateChain<'static>,
236        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
237        client_creds: &Arc<dyn ResolvesClientCert>,
238    ) -> Self {
239        Self {
240            ticket,
241            epoch: time_now.as_secs(),
242            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
243            server_cert_chain: Arc::new(server_cert_chain),
244            server_cert_verifier: Arc::downgrade(server_cert_verifier),
245            client_creds: Arc::downgrade(client_creds),
246        }
247    }
248
249    pub(crate) fn compatible_config(
250        &self,
251        server_cert_verifier: &Arc<dyn ServerCertVerifier>,
252        client_creds: &Arc<dyn ResolvesClientCert>,
253    ) -> bool {
254        let same_verifier = Weak::ptr_eq(
255            &Arc::downgrade(server_cert_verifier),
256            &self.server_cert_verifier,
257        );
258        let same_creds = Weak::ptr_eq(&Arc::downgrade(client_creds), &self.client_creds);
259
260        match (same_verifier, same_creds) {
261            (true, true) => true,
262            (false, _) => {
263                crate::log::trace!("resumption not allowed between different ServerCertVerifiers");
264                false
265            }
266            (_, _) => {
267                crate::log::trace!(
268                    "resumption not allowed between different ResolvesClientCert values"
269                );
270                false
271            }
272        }
273    }
274
275    pub(crate) fn server_cert_chain(&self) -> &CertificateChain<'static> {
276        &self.server_cert_chain
277    }
278
279    pub(crate) fn ticket(&self) -> &[u8] {
280        self.ticket.0.as_ref()
281    }
282}
283
/// Upper bound, in seconds, on how long a client session is kept (seven days).
///
/// RFC 8446 §4.6.1 forbids servers from advertising a `ticket_lifetime`
/// greater than 604800 seconds; longer lifetimes are clamped to this value.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;

/// The maximum allowed skew, in milliseconds, between the server's and the
/// client's view of a ticket's age over the maximum ticket lifetime.
///
/// This allowance covers TCP retransmission delays (should packets be lost
/// while the client sends the ClientHello or receives the NewSessionTicket)
/// _and_ genuine clock skew over that period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
291
292// --- Server types ---
293#[non_exhaustive]
294#[derive(Debug)]
295pub enum ServerSessionValue {
296    Tls12(Tls12ServerSessionValue),
297    Tls13(Tls13ServerSessionValue),
298}
299
300impl Codec<'_> for ServerSessionValue {
301    fn encode(&self, bytes: &mut Vec<u8>) {
302        match self {
303            Self::Tls12(value) => {
304                ProtocolVersion::TLSv1_2.encode(bytes);
305                value.encode(bytes);
306            }
307            Self::Tls13(value) => {
308                ProtocolVersion::TLSv1_3.encode(bytes);
309                value.encode(bytes);
310            }
311        }
312    }
313
314    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
315        match ProtocolVersion::read(r)? {
316            ProtocolVersion::TLSv1_2 => Ok(Self::Tls12(Tls12ServerSessionValue::read(r)?)),
317            ProtocolVersion::TLSv1_3 => Ok(Self::Tls13(Tls13ServerSessionValue::read(r)?)),
318            _ => Err(InvalidMessage::UnknownProtocolVersion),
319        }
320    }
321}
322
323#[derive(Debug)]
324pub struct Tls12ServerSessionValue {
325    #[doc(hidden)]
326    pub common: CommonServerSessionValue,
327    pub(crate) master_secret: Zeroizing<[u8; 48]>,
328    pub(crate) extended_ms: bool,
329}
330
331impl Tls12ServerSessionValue {
332    pub(crate) fn new(
333        common: CommonServerSessionValue,
334        master_secret: &[u8; 48],
335        extended_ms: bool,
336    ) -> Self {
337        Self {
338            common,
339            master_secret: Zeroizing::new(*master_secret),
340            extended_ms,
341        }
342    }
343}
344
345impl Codec<'_> for Tls12ServerSessionValue {
346    fn encode(&self, bytes: &mut Vec<u8>) {
347        self.common.encode(bytes);
348        bytes.extend_from_slice(self.master_secret.as_ref());
349        (self.extended_ms as u8).encode(bytes);
350    }
351
352    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
353        Ok(Self {
354            common: CommonServerSessionValue::read(r)?,
355            master_secret: Zeroizing::new(
356                match r
357                    .take(48)
358                    .and_then(|slice| slice.try_into().ok())
359                {
360                    Some(array) => array,
361                    None => return Err(InvalidMessage::MessageTooShort),
362                },
363            ),
364            extended_ms: matches!(u8::read(r)?, 1),
365        })
366    }
367}
368
369impl From<Tls12ServerSessionValue> for ServerSessionValue {
370    fn from(value: Tls12ServerSessionValue) -> Self {
371        Self::Tls12(value)
372    }
373}
374
375#[derive(Debug)]
376pub struct Tls13ServerSessionValue {
377    #[doc(hidden)]
378    pub common: CommonServerSessionValue,
379    pub(crate) secret: Zeroizing<PayloadU8>,
380    pub(crate) age_obfuscation_offset: u32,
381
382    // not encoded vv
383    freshness: Option<bool>,
384}
385
386impl Tls13ServerSessionValue {
387    pub(crate) fn new(
388        common: CommonServerSessionValue,
389        secret: &[u8],
390        age_obfuscation_offset: u32,
391    ) -> Self {
392        Self {
393            common,
394            secret: Zeroizing::new(PayloadU8::new(secret.to_vec())),
395            age_obfuscation_offset,
396            freshness: None,
397        }
398    }
399
400    pub(crate) fn set_freshness(
401        mut self,
402        obfuscated_client_age_ms: u32,
403        time_now: UnixTime,
404    ) -> Self {
405        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
406        let server_age_ms = (time_now
407            .as_secs()
408            .saturating_sub(self.common.creation_time_sec) as u32)
409            .saturating_mul(1000);
410
411        let age_difference = server_age_ms.abs_diff(client_age_ms);
412
413        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
414        self
415    }
416
417    pub(crate) fn is_fresh(&self) -> bool {
418        self.freshness.unwrap_or_default()
419    }
420}
421
422impl Codec<'_> for Tls13ServerSessionValue {
423    fn encode(&self, bytes: &mut Vec<u8>) {
424        self.common.encode(bytes);
425        self.secret.encode(bytes);
426        self.age_obfuscation_offset
427            .encode(bytes);
428    }
429
430    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
431        Ok(Self {
432            common: CommonServerSessionValue::read(r)?,
433            secret: Zeroizing::new(PayloadU8::read(r)?),
434            age_obfuscation_offset: u32::read(r)?,
435            freshness: None,
436        })
437    }
438}
439
440impl From<Tls13ServerSessionValue> for ServerSessionValue {
441    fn from(value: Tls13ServerSessionValue) -> Self {
442        Self::Tls13(value)
443    }
444}
445
446#[derive(Debug)]
447pub struct CommonServerSessionValue {
448    pub(crate) sni: Option<DnsName<'static>>,
449    pub(crate) cipher_suite: CipherSuite,
450    pub(crate) client_cert_chain: Option<CertificateChain<'static>>,
451    pub(crate) alpn: Option<ProtocolName>,
452    pub(crate) application_data: PayloadU16,
453    #[doc(hidden)]
454    pub creation_time_sec: u64,
455}
456
457impl CommonServerSessionValue {
458    pub(crate) fn new(
459        sni: Option<&DnsName<'_>>,
460        cipher_suite: CipherSuite,
461        client_cert_chain: Option<CertificateChain<'static>>,
462        alpn: Option<ProtocolName>,
463        application_data: Vec<u8>,
464        creation_time: UnixTime,
465    ) -> Self {
466        Self {
467            sni: sni.map(|s| s.to_owned()),
468            cipher_suite,
469            client_cert_chain,
470            alpn,
471            application_data: PayloadU16::new(application_data),
472            creation_time_sec: creation_time.as_secs(),
473        }
474    }
475}
476
477impl Codec<'_> for CommonServerSessionValue {
478    fn encode(&self, bytes: &mut Vec<u8>) {
479        if let Some(sni) = &self.sni {
480            1u8.encode(bytes);
481            let sni_bytes: &str = sni.as_ref();
482            PayloadU8::<MaybeEmpty>::encode_slice(sni_bytes.as_bytes(), bytes);
483        } else {
484            0u8.encode(bytes);
485        }
486        self.cipher_suite.encode(bytes);
487        if let Some(chain) = &self.client_cert_chain {
488            1u8.encode(bytes);
489            chain.encode(bytes);
490        } else {
491            0u8.encode(bytes);
492        }
493        if let Some(alpn) = &self.alpn {
494            1u8.encode(bytes);
495            alpn.encode(bytes);
496        } else {
497            0u8.encode(bytes);
498        }
499        self.application_data.encode(bytes);
500        self.creation_time_sec.encode(bytes);
501    }
502
503    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
504        let sni = match u8::read(r)? {
505            1 => {
506                let dns_name = PayloadU8::<MaybeEmpty>::read(r)?;
507                let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
508                    Ok(dns_name) => dns_name.to_owned(),
509                    Err(_) => return Err(InvalidMessage::InvalidServerName),
510                };
511
512                Some(dns_name)
513            }
514            _ => None,
515        };
516
517        Ok(Self {
518            sni,
519            cipher_suite: CipherSuite::read(r)?,
520            client_cert_chain: match u8::read(r)? {
521                1 => Some(CertificateChain::read(r)?.into_owned()),
522                _ => None,
523            },
524            alpn: match u8::read(r)? {
525                1 => Some(ProtocolName::read(r)?),
526                _ => None,
527            },
528            application_data: PayloadU16::read(r)?,
529            creation_time_sec: u64::read(r)?,
530        })
531    }
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `bytes` as a `ServerSessionValue` and checks that re-encoding
    /// reproduces them exactly.
    fn assert_round_trips(bytes: &[u8]) {
        let mut reader = Reader::init(bytes);
        let decoded = ServerSessionValue::read(&mut reader).unwrap();
        assert_eq!(decoded.get_encoding(), bytes);
    }

    #[cfg(feature = "std")] // for UnixTime::now
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let common = CommonServerSessionValue::new(
            None,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            None,
            None,
            vec![4, 5, 6],
            UnixTime::now(),
        );
        let tls13 = Tls13ServerSessionValue::new(common, &[1, 2, 3], 0x12345678);
        let ssv = ServerSessionValue::Tls13(tls13);
        println!("{ssv:?}");
        println!("{:#04x?}", ssv.get_encoding());
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x03, 0x04, // ProtocolVersion::TLSv1_3
            0x00, // no SNI
            0x13, 0x01, // CipherSuite::TLS13_AES_128_GCM_SHA256
            0x00, // no client certificate chain
            0x00, // no ALPN protocol
            0x00, 0x03, 0x04, 0x05, 0x06, // application_data: u16 length + bytes
            0x00, 0x00, 0x00, 0x00, 0x68, 0x6e, 0x94, 0x32, // creation_time_sec: u64
            0x03, 0x01, 0x02, 0x03, // secret: u8 length + bytes
            0x12, 0x34, 0x56, 0x78, // age_obfuscation_offset: u32
        ];
        assert_round_trips(&bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        let bytes = [
            0x03, 0x04, // ProtocolVersion::TLSv1_3
            0x00, // no SNI
            0x13, 0x01, // CipherSuite::TLS13_AES_128_GCM_SHA256
            0x01, // client certificate chain present...
            0x00, 0x00, 0x00, // ...but empty (24-bit length of zero)
            0x00, // no ALPN protocol
            0x00, 0x03, 0x04, 0x05, 0x06, // application_data: u16 length + bytes
            0x00, 0x00, 0x00, 0x00, 0x68, 0x6e, 0x94, 0x32, // creation_time_sec: u64
            0x03, 0x01, 0x02, 0x03, // secret: u8 length + bytes
            0x12, 0x34, 0x56, 0x78, // age_obfuscation_offset: u32
        ];
        assert_round_trips(&bytes);
    }
}