rustls/msgs/
persist.rs

1use alloc::vec::Vec;
2use core::cmp;
3use core::time::Duration;
4
5use pki_types::{DnsName, UnixTime};
6use zeroize::Zeroizing;
7
8use crate::crypto::{CipherSuite, Identity};
9use crate::enums::ProtocolVersion;
10use crate::error::InvalidMessage;
11use crate::msgs::base::{MaybeEmpty, PayloadU8, PayloadU16};
12use crate::msgs::codec::{Codec, Reader};
13use crate::msgs::handshake::{ProtocolName, SessionId};
14use crate::sync::Arc;
15use crate::tls12::Tls12CipherSuite;
16use crate::tls13::Tls13CipherSuite;
17
18pub(crate) struct Retrieved<T> {
19    pub(crate) value: T,
20    retrieved_at: UnixTime,
21}
22
23impl<T> Retrieved<T> {
24    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
25        Self {
26            value,
27            retrieved_at,
28        }
29    }
30
31    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
32        Some(Retrieved {
33            value: f(&self.value)?,
34            retrieved_at: self.retrieved_at,
35        })
36    }
37}
38
39impl Retrieved<&Tls13ClientSessionValue> {
40    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
41        let age_secs = self
42            .retrieved_at
43            .as_secs()
44            .saturating_sub(self.value.common.epoch);
45        let age_millis = age_secs as u32 * 1000;
46        age_millis.wrapping_add(self.value.age_add)
47    }
48}
49
50impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
51    pub(crate) fn has_expired(&self) -> bool {
52        let common = &*self.value;
53        common.lifetime != Duration::ZERO
54            && common
55                .epoch
56                .saturating_add(common.lifetime.as_secs())
57                < self.retrieved_at.as_secs()
58    }
59}
60
61impl<T> core::ops::Deref for Retrieved<T> {
62    type Target = T;
63
64    fn deref(&self) -> &Self::Target {
65        &self.value
66    }
67}
68
69#[derive(Debug)]
70pub struct Tls13ClientSessionValue {
71    suite: &'static Tls13CipherSuite,
72    secret: Zeroizing<PayloadU8>,
73    age_add: u32,
74    max_early_data_size: u32,
75    pub(crate) common: ClientSessionCommon,
76    quic_params: PayloadU16,
77}
78
79impl Tls13ClientSessionValue {
80    pub(crate) fn new(
81        suite: &'static Tls13CipherSuite,
82        ticket: Arc<PayloadU16>,
83        secret: &[u8],
84        peer_identity: Identity<'static>,
85        time_now: UnixTime,
86        lifetime: Duration,
87        age_add: u32,
88        max_early_data_size: u32,
89    ) -> Self {
90        Self {
91            suite,
92            secret: Zeroizing::new(PayloadU8::new(secret.to_vec())),
93            age_add,
94            max_early_data_size,
95            common: ClientSessionCommon::new(ticket, time_now, lifetime, peer_identity),
96            quic_params: PayloadU16::new(Vec::new()),
97        }
98    }
99
100    pub(crate) fn secret(&self) -> &[u8] {
101        self.secret.0.as_ref()
102    }
103
104    pub fn max_early_data_size(&self) -> u32 {
105        self.max_early_data_size
106    }
107
108    pub fn suite(&self) -> &'static Tls13CipherSuite {
109        self.suite
110    }
111
112    /// Test only: rewind epoch by `delta` seconds.
113    #[doc(hidden)]
114    pub fn rewind_epoch(&mut self, delta: u32) {
115        self.common.epoch -= delta as u64;
116    }
117
118    /// Test only: replace `max_early_data_size` with `new`
119    #[doc(hidden)]
120    pub fn _private_set_max_early_data_size(&mut self, new: u32) {
121        self.max_early_data_size = new;
122    }
123
124    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
125        self.quic_params = PayloadU16::new(quic_params.to_vec());
126    }
127
128    pub fn quic_params(&self) -> Vec<u8> {
129        self.quic_params.0.clone()
130    }
131}
132
133impl core::ops::Deref for Tls13ClientSessionValue {
134    type Target = ClientSessionCommon;
135
136    fn deref(&self) -> &Self::Target {
137        &self.common
138    }
139}
140
141#[derive(Debug, Clone)]
142pub struct Tls12ClientSessionValue {
143    suite: &'static Tls12CipherSuite,
144    pub(crate) session_id: SessionId,
145    master_secret: Zeroizing<[u8; 48]>,
146    extended_ms: bool,
147    #[doc(hidden)]
148    pub(crate) common: ClientSessionCommon,
149}
150
151impl Tls12ClientSessionValue {
152    pub(crate) fn new(
153        suite: &'static Tls12CipherSuite,
154        session_id: SessionId,
155        ticket: Arc<PayloadU16>,
156        master_secret: &[u8; 48],
157        peer_identity: Identity<'static>,
158        time_now: UnixTime,
159        lifetime: Duration,
160        extended_ms: bool,
161    ) -> Self {
162        Self {
163            suite,
164            session_id,
165            master_secret: Zeroizing::new(*master_secret),
166            extended_ms,
167            common: ClientSessionCommon::new(ticket, time_now, lifetime, peer_identity),
168        }
169    }
170
171    pub(crate) fn master_secret(&self) -> &[u8; 48] {
172        &self.master_secret
173    }
174
175    pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
176        self.common.ticket.clone()
177    }
178
179    pub(crate) fn extended_ms(&self) -> bool {
180        self.extended_ms
181    }
182
183    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
184        self.suite
185    }
186
187    /// Test only: rewind epoch by `delta` seconds.
188    #[doc(hidden)]
189    pub fn rewind_epoch(&mut self, delta: u32) {
190        self.common.epoch -= delta as u64;
191    }
192}
193
194impl core::ops::Deref for Tls12ClientSessionValue {
195    type Target = ClientSessionCommon;
196
197    fn deref(&self) -> &Self::Target {
198        &self.common
199    }
200}
201
202#[derive(Debug, Clone)]
203pub struct ClientSessionCommon {
204    ticket: Arc<PayloadU16>,
205    epoch: u64,
206    lifetime: Duration,
207    peer_identity: Arc<Identity<'static>>,
208}
209
210impl ClientSessionCommon {
211    fn new(
212        ticket: Arc<PayloadU16>,
213        time_now: UnixTime,
214        lifetime: Duration,
215        peer_identity: Identity<'static>,
216    ) -> Self {
217        Self {
218            ticket,
219            epoch: time_now.as_secs(),
220            lifetime: cmp::min(lifetime, MAX_TICKET_LIFETIME),
221            peer_identity: Arc::new(peer_identity),
222        }
223    }
224
225    pub(crate) fn peer_identity(&self) -> &Identity<'static> {
226        &self.peer_identity
227    }
228
229    pub(crate) fn ticket(&self) -> &[u8] {
230        self.ticket.0.as_ref()
231    }
232}
233
/// Upper bound on the lifetime stored for any client-side ticket: seven days.
///
/// RFC 8446 §4.6.1 forbids clients from caching tickets for longer than this,
/// regardless of the lifetime the server advertised.
const MAX_TICKET_LIFETIME: Duration = Duration::from_secs(7 * 24 * 60 * 60);

/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period.  This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
241
242// --- Server types ---
243#[non_exhaustive]
244#[derive(Debug)]
245pub enum ServerSessionValue {
246    Tls12(Tls12ServerSessionValue),
247    Tls13(Tls13ServerSessionValue),
248}
249
250impl Codec<'_> for ServerSessionValue {
251    fn encode(&self, bytes: &mut Vec<u8>) {
252        match self {
253            Self::Tls12(value) => {
254                ProtocolVersion::TLSv1_2.encode(bytes);
255                value.encode(bytes);
256            }
257            Self::Tls13(value) => {
258                ProtocolVersion::TLSv1_3.encode(bytes);
259                value.encode(bytes);
260            }
261        }
262    }
263
264    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
265        match ProtocolVersion::read(r)? {
266            ProtocolVersion::TLSv1_2 => Ok(Self::Tls12(Tls12ServerSessionValue::read(r)?)),
267            ProtocolVersion::TLSv1_3 => Ok(Self::Tls13(Tls13ServerSessionValue::read(r)?)),
268            _ => Err(InvalidMessage::UnknownProtocolVersion),
269        }
270    }
271}
272
273#[derive(Debug)]
274pub struct Tls12ServerSessionValue {
275    #[doc(hidden)]
276    pub common: CommonServerSessionValue,
277    pub(crate) master_secret: Zeroizing<[u8; 48]>,
278    pub(crate) extended_ms: bool,
279}
280
281impl Tls12ServerSessionValue {
282    pub(crate) fn new(
283        common: CommonServerSessionValue,
284        master_secret: &[u8; 48],
285        extended_ms: bool,
286    ) -> Self {
287        Self {
288            common,
289            master_secret: Zeroizing::new(*master_secret),
290            extended_ms,
291        }
292    }
293}
294
295impl Codec<'_> for Tls12ServerSessionValue {
296    fn encode(&self, bytes: &mut Vec<u8>) {
297        self.common.encode(bytes);
298        bytes.extend_from_slice(self.master_secret.as_ref());
299        (self.extended_ms as u8).encode(bytes);
300    }
301
302    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
303        Ok(Self {
304            common: CommonServerSessionValue::read(r)?,
305            master_secret: Zeroizing::new(
306                match r
307                    .take(48)
308                    .and_then(|slice| slice.try_into().ok())
309                {
310                    Some(array) => array,
311                    None => return Err(InvalidMessage::MessageTooShort),
312                },
313            ),
314            extended_ms: matches!(u8::read(r)?, 1),
315        })
316    }
317}
318
319impl From<Tls12ServerSessionValue> for ServerSessionValue {
320    fn from(value: Tls12ServerSessionValue) -> Self {
321        Self::Tls12(value)
322    }
323}
324
325#[derive(Debug)]
326pub struct Tls13ServerSessionValue {
327    #[doc(hidden)]
328    pub common: CommonServerSessionValue,
329    pub(crate) secret: Zeroizing<PayloadU8>,
330    pub(crate) age_obfuscation_offset: u32,
331
332    // not encoded vv
333    freshness: Option<bool>,
334}
335
336impl Tls13ServerSessionValue {
337    pub(crate) fn new(
338        common: CommonServerSessionValue,
339        secret: &[u8],
340        age_obfuscation_offset: u32,
341    ) -> Self {
342        Self {
343            common,
344            secret: Zeroizing::new(PayloadU8::new(secret.to_vec())),
345            age_obfuscation_offset,
346            freshness: None,
347        }
348    }
349
350    pub(crate) fn set_freshness(
351        mut self,
352        obfuscated_client_age_ms: u32,
353        time_now: UnixTime,
354    ) -> Self {
355        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
356        let server_age_ms = (time_now
357            .as_secs()
358            .saturating_sub(self.common.creation_time_sec) as u32)
359            .saturating_mul(1000);
360
361        let age_difference = server_age_ms.abs_diff(client_age_ms);
362
363        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
364        self
365    }
366
367    pub(crate) fn is_fresh(&self) -> bool {
368        self.freshness.unwrap_or_default()
369    }
370}
371
372impl Codec<'_> for Tls13ServerSessionValue {
373    fn encode(&self, bytes: &mut Vec<u8>) {
374        self.common.encode(bytes);
375        self.secret.encode(bytes);
376        self.age_obfuscation_offset
377            .encode(bytes);
378    }
379
380    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
381        Ok(Self {
382            common: CommonServerSessionValue::read(r)?,
383            secret: Zeroizing::new(PayloadU8::read(r)?),
384            age_obfuscation_offset: u32::read(r)?,
385            freshness: None,
386        })
387    }
388}
389
390impl From<Tls13ServerSessionValue> for ServerSessionValue {
391    fn from(value: Tls13ServerSessionValue) -> Self {
392        Self::Tls13(value)
393    }
394}
395
396#[derive(Debug)]
397pub struct CommonServerSessionValue {
398    pub(crate) sni: Option<DnsName<'static>>,
399    pub(crate) cipher_suite: CipherSuite,
400    pub(crate) peer_identity: Option<Identity<'static>>,
401    pub(crate) alpn: Option<ProtocolName>,
402    pub(crate) application_data: PayloadU16,
403    #[doc(hidden)]
404    pub creation_time_sec: u64,
405}
406
407impl CommonServerSessionValue {
408    pub(crate) fn new(
409        sni: Option<&DnsName<'_>>,
410        cipher_suite: CipherSuite,
411        peer_identity: Option<Identity<'static>>,
412        alpn: Option<ProtocolName>,
413        application_data: Vec<u8>,
414        creation_time: UnixTime,
415    ) -> Self {
416        Self {
417            sni: sni.map(|s| s.to_owned()),
418            cipher_suite,
419            peer_identity,
420            alpn,
421            application_data: PayloadU16::new(application_data),
422            creation_time_sec: creation_time.as_secs(),
423        }
424    }
425
426    pub(crate) fn can_resume(&self, suite: CipherSuite, sni: &Option<DnsName<'_>>) -> bool {
427        // The RFCs underspecify what happens if we try to resume to
428        // an unoffered/varying suite.  We merely don't resume in weird cases.
429        //
430        // RFC 6066 says "A server that implements this extension MUST NOT accept
431        // the request to resume the session if the server_name extension contains
432        // a different name. Instead, it proceeds with a full handshake to
433        // establish a new session."
434        //
435        // RFC 8446: "The server MUST ensure that it selects
436        // a compatible PSK (if any) and cipher suite."
437        self.cipher_suite == suite && &self.sni == sni
438    }
439}
440
441impl Codec<'_> for CommonServerSessionValue {
442    fn encode(&self, bytes: &mut Vec<u8>) {
443        if let Some(sni) = &self.sni {
444            1u8.encode(bytes);
445            let sni_bytes: &str = sni.as_ref();
446            PayloadU8::<MaybeEmpty>::encode_slice(sni_bytes.as_bytes(), bytes);
447        } else {
448            0u8.encode(bytes);
449        }
450        self.cipher_suite.encode(bytes);
451        if let Some(identity) = &self.peer_identity {
452            1u8.encode(bytes);
453            identity.encode(bytes);
454        } else {
455            0u8.encode(bytes);
456        }
457        if let Some(alpn) = &self.alpn {
458            1u8.encode(bytes);
459            alpn.encode(bytes);
460        } else {
461            0u8.encode(bytes);
462        }
463        self.application_data.encode(bytes);
464        self.creation_time_sec.encode(bytes);
465    }
466
467    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
468        let sni = match u8::read(r)? {
469            1 => {
470                let dns_name = PayloadU8::<MaybeEmpty>::read(r)?;
471                let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
472                    Ok(dns_name) => dns_name.to_owned(),
473                    Err(_) => return Err(InvalidMessage::InvalidServerName),
474                };
475
476                Some(dns_name)
477            }
478            _ => None,
479        };
480
481        Ok(Self {
482            sni,
483            cipher_suite: CipherSuite::read(r)?,
484            peer_identity: match u8::read(r)? {
485                1 => Some(Identity::read(r)?.into_owned()),
486                _ => None,
487            },
488            alpn: match u8::read(r)? {
489                1 => Some(ProtocolName::read(r)?),
490                _ => None,
491            },
492            application_data: PayloadU16::read(r)?,
493            creation_time_sec: u64::read(r)?,
494        })
495    }
496}
497
#[cfg(test)]
mod tests {
    use pki_types::CertificateDer;

    use super::*;
    use crate::crypto::CertificateIdentity;

    #[cfg(feature = "std")] // for UnixTime::now
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::Tls13(Tls13ServerSessionValue::new(
            CommonServerSessionValue::new(
                None,
                CipherSuite::TLS13_AES_128_GCM_SHA256,
                None,
                None,
                vec![4, 5, 6],
                UnixTime::now(),
            ),
            &[1, 2, 3],
            0x12345678,
        ));
        println!("{ssv:?}");
        println!("{:#04x?}", ssv.get_encoding());
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x03, 0x04, 0x00, 0x13, 0x01, 0x00, 0x00, 0x00, 0x03, 0x04, 0x05, 0x06, 0x00, 0x00,
            0x00, 0x00, 0x68, 0x6e, 0x94, 0x32, 0x03, 0x01, 0x02, 0x03, 0x12, 0x34, 0x56, 0x78,
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        // Use a fixed creation time (the one embedded in `bytes`) rather than
        // `UnixTime::now()`. `now` needs the `std` feature, which this test was
        // not gated on. A fixed time also lets the freshly-built value be checked
        // byte-for-byte against the expected encoding.
        let creation_time = UnixTime::since_unix_epoch(Duration::from_secs(0x68c1_99ac));
        let built = ServerSessionValue::Tls13(Tls13ServerSessionValue::new(
            CommonServerSessionValue::new(
                None,
                CipherSuite::TLS13_AES_128_GCM_SHA256,
                Some(Identity::X509(CertificateIdentity {
                    end_entity: CertificateDer::from(&[10, 11, 12][..]),
                    intermediates: alloc::vec![],
                })),
                None,
                alloc::vec![4, 5, 6],
                creation_time,
            ),
            &[1, 2, 3],
            0x12345678,
        ));

        let bytes = [
            0x03, 0x04, 0x00, 0x13, 0x01, 0x01, 0x00, 0x00, 0x00, 0x03, 0x0a, 0x0b, 0x0c, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x03, 0x04, 0x05, 0x06, 0x00, 0x00, 0x00, 0x00, 0x68, 0xc1,
            0x99, 0xac, 0x03, 0x01, 0x02, 0x03, 0x12, 0x34, 0x56, 0x78,
        ];
        assert_eq!(built.get_encoding(), bytes);

        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(ssv.get_encoding(), bytes);
    }
}