Skip to main content

rustls/client/
mod.rs

1use alloc::vec::Vec;
2use core::ops::Deref;
3use core::time::Duration;
4
5use pki_types::UnixTime;
6use zeroize::Zeroizing;
7
8use crate::crypto::cipher::Payload;
9use crate::crypto::{CipherSuite, CryptoProvider, Identity, SelectedCredential, SignatureScheme};
10use crate::enums::{ApplicationProtocol, CertificateType};
11use crate::error::{ApiMisuse, Error, InvalidMessage};
12use crate::log::{debug, trace};
13use crate::msgs::{
14    CertificateChain, Codec, ExtensionType, MaybeEmpty, NewSessionTicketPayloadTls13, Reader,
15    ServerExtensions, SessionId, SizedPayload,
16};
17use crate::sync::Arc;
18use crate::verify::DistinguishedName;
19#[cfg(feature = "webpki")]
20pub use crate::webpki::{
21    ServerVerifierBuilder, VerifierBuilderError, WebPkiServerVerifier,
22    verify_identity_signed_by_trust_anchor, verify_server_name,
23};
24use crate::{Tls12CipherSuite, Tls13CipherSuite, compress};
25
26mod config;
27pub use config::{
28    ClientConfig, ClientCredentialResolver, ClientSessionKey, ClientSessionStore,
29    CredentialRequest, Resumption, Tls12Resumption, WantsClientCert,
30};
31
32mod connection;
33pub use connection::{ClientConnection, ClientConnectionBuilder, ClientSide, WriteEarlyData};
34
35mod ech;
36pub use ech::{EchConfig, EchGreaseConfig, EchMode, EchStatus};
37
38mod handy;
39pub use handy::ClientSessionMemoryCache;
40
41mod hs;
42pub(crate) use hs::ClientHandler;
43
44mod tls12;
45pub(crate) use tls12::TLS12_HANDLER;
46
47mod tls13;
48pub(crate) use tls13::TLS13_HANDLER;
49
50/// Dangerous configuration that should be audited and used with extreme care.
51pub mod danger {
52    pub use super::config::danger::{DangerousClientConfig, DangerousClientConfigBuilder};
53    pub use crate::verify::{
54        HandshakeSignatureValid, PeerVerified, ServerIdentity, ServerVerifier,
55        SignatureVerificationInput,
56    };
57}
58
59#[cfg(test)]
60mod test;
61
62pub(crate) struct Retrieved<T> {
63    pub(crate) value: T,
64    retrieved_at: UnixTime,
65}
66
67impl<T> Retrieved<T> {
68    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
69        Self {
70            value,
71            retrieved_at,
72        }
73    }
74
75    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
76        Some(Retrieved {
77            value: f(&self.value)?,
78            retrieved_at: self.retrieved_at,
79        })
80    }
81}
82
83impl Retrieved<&Tls13Session> {
84    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
85        let age_secs = self
86            .retrieved_at
87            .as_secs()
88            .saturating_sub(self.value.common.epoch);
89        let age_millis = age_secs as u32 * 1000;
90        age_millis.wrapping_add(self.value.age_add)
91    }
92}
93
94impl<T: Deref<Target = ClientSessionCommon>> Retrieved<T> {
95    pub(crate) fn has_expired(&self) -> bool {
96        let common = &*self.value;
97        common.lifetime != Duration::ZERO
98            && common
99                .epoch
100                .saturating_add(common.lifetime.as_secs())
101                < self.retrieved_at.as_secs()
102    }
103}
104
105impl<T> Deref for Retrieved<T> {
106    type Target = T;
107
108    fn deref(&self) -> &Self::Target {
109        &self.value
110    }
111}
112
113/// A stored TLS 1.3 client session value.
114#[derive(Debug)]
115pub struct Tls13Session {
116    suite: &'static Tls13CipherSuite,
117    secret: Zeroizing<SizedPayload<'static, u8>>,
118    pub(crate) age_add: u32,
119    max_early_data_size: u32,
120    pub(crate) common: ClientSessionCommon,
121    quic_params: SizedPayload<'static, u16, MaybeEmpty>,
122}
123
124impl Tls13Session {
125    /// Decode a ticket from the given bytes.
126    pub fn from_slice(bytes: &[u8], provider: &CryptoProvider) -> Result<Self, Error> {
127        let mut reader = Reader::new(bytes);
128        let suite = CipherSuite::read(&mut reader)?;
129        let suite = provider
130            .tls13_cipher_suites
131            .iter()
132            .find(|s| s.common.suite == suite)
133            .ok_or(ApiMisuse::ResumingFromUnknownCipherSuite(suite))?;
134
135        Ok(Self {
136            suite: *suite,
137            secret: Zeroizing::new(SizedPayload::<u8>::read(&mut reader)?.into_owned()),
138            age_add: u32::read(&mut reader)?,
139            max_early_data_size: u32::read(&mut reader)?,
140            common: ClientSessionCommon::read(&mut reader)?,
141            quic_params: SizedPayload::<u16, MaybeEmpty>::read(&mut reader)?.into_owned(),
142        })
143    }
144
145    pub(crate) fn new(
146        ticket: &NewSessionTicketPayloadTls13,
147        input: Tls13ClientSessionInput,
148        secret: &[u8],
149        time_now: UnixTime,
150    ) -> Self {
151        Self {
152            suite: input.suite,
153            secret: Zeroizing::new(secret.to_vec().into()),
154            age_add: ticket.age_add,
155            max_early_data_size: ticket
156                .extensions
157                .max_early_data_size
158                .unwrap_or_default(),
159            common: ClientSessionCommon::new(
160                ticket.ticket.clone(),
161                time_now,
162                ticket.lifetime,
163                input.peer_identity,
164            ),
165            quic_params: input
166                .quic_params
167                .unwrap_or_else(|| SizedPayload::from(Payload::new(Vec::new()))),
168        }
169    }
170
171    /// Encode this ticket into `buf` for persistence.
172    pub fn encode(&self, buf: &mut Vec<u8>) {
173        self.suite.common.suite.encode(buf);
174        self.secret.encode(buf);
175        buf.extend_from_slice(&self.age_add.to_be_bytes());
176        buf.extend_from_slice(&self.max_early_data_size.to_be_bytes());
177        self.common.encode(buf);
178        self.quic_params.encode(buf);
179    }
180
181    /// Test only: replace `max_early_data_size` with `new`
182    #[doc(hidden)]
183    pub fn _reset_max_early_data_size(&mut self, expected: u32, desired: u32) {
184        assert_eq!(
185            self.max_early_data_size, expected,
186            "max_early_data_size was not expected value"
187        );
188        self.max_early_data_size = desired;
189    }
190
191    /// Test only: rewind epoch by `delta` seconds.
192    #[doc(hidden)]
193    pub fn rewind_epoch(&mut self, delta: u32) {
194        self.common.epoch -= delta as u64;
195    }
196}
197
198impl Deref for Tls13Session {
199    type Target = ClientSessionCommon;
200
201    fn deref(&self) -> &Self::Target {
202        &self.common
203    }
204}
205
206/// A "template" for future TLS1.3 client session values.
207#[derive(Clone)]
208pub(crate) struct Tls13ClientSessionInput {
209    pub(crate) suite: &'static Tls13CipherSuite,
210    pub(crate) peer_identity: Identity<'static>,
211    pub(crate) quic_params: Option<SizedPayload<'static, u16, MaybeEmpty>>,
212}
213
214/// A stored TLS 1.2 client session value.
215#[derive(Debug, Clone)]
216pub struct Tls12Session {
217    suite: &'static Tls12CipherSuite,
218    pub(crate) session_id: SessionId,
219    master_secret: Zeroizing<[u8; 48]>,
220    extended_ms: bool,
221    #[doc(hidden)]
222    pub(crate) common: ClientSessionCommon,
223}
224
225impl Tls12Session {
226    /// Decode a ticket from the given bytes.
227    pub fn from_slice(bytes: &[u8], provider: &CryptoProvider) -> Result<Self, Error> {
228        let mut reader = Reader::new(bytes);
229        let suite = CipherSuite::read(&mut reader)?;
230        let suite = provider
231            .tls12_cipher_suites
232            .iter()
233            .find(|s| s.common.suite == suite)
234            .ok_or(ApiMisuse::ResumingFromUnknownCipherSuite(suite))?;
235
236        Ok(Self {
237            suite: *suite,
238            session_id: SessionId::read(&mut reader)?,
239            master_secret: Zeroizing::new(
240                reader
241                    .take_array("MasterSecret")
242                    .copied()?,
243            ),
244            extended_ms: matches!(u8::read(&mut reader)?, 1),
245            common: ClientSessionCommon::read(&mut reader)?,
246        })
247    }
248
249    pub(crate) fn new(
250        suite: &'static Tls12CipherSuite,
251        session_id: SessionId,
252        ticket: Arc<SizedPayload<'static, u16, MaybeEmpty>>,
253        master_secret: &[u8; 48],
254        peer_identity: Identity<'static>,
255        time_now: UnixTime,
256        lifetime: Duration,
257        extended_ms: bool,
258    ) -> Self {
259        Self {
260            suite,
261            session_id,
262            master_secret: Zeroizing::new(*master_secret),
263            extended_ms,
264            common: ClientSessionCommon::new(ticket, time_now, lifetime, peer_identity),
265        }
266    }
267
268    /// Encode this ticket into `buf` for persistence.
269    pub fn encode(&self, buf: &mut Vec<u8>) {
270        self.suite.common.suite.encode(buf);
271        self.session_id.encode(buf);
272        buf.extend_from_slice(&*self.master_secret);
273        buf.push(self.extended_ms as u8);
274        self.common.encode(buf);
275    }
276
277    /// Test only: rewind epoch by `delta` seconds.
278    #[doc(hidden)]
279    pub fn rewind_epoch(&mut self, delta: u32) {
280        self.common.epoch -= delta as u64;
281    }
282}
283
284impl Deref for Tls12Session {
285    type Target = ClientSessionCommon;
286
287    fn deref(&self) -> &Self::Target {
288        &self.common
289    }
290}
291
292/// Common data for stored client sessions.
293#[derive(Debug, Clone)]
294pub struct ClientSessionCommon {
295    pub(crate) ticket: Arc<SizedPayload<'static, u16>>,
296    pub(crate) epoch: u64,
297    lifetime: Duration,
298    peer_identity: Arc<Identity<'static>>,
299}
300
301impl ClientSessionCommon {
302    pub(crate) fn new(
303        ticket: Arc<SizedPayload<'static, u16>>,
304        time_now: UnixTime,
305        lifetime: Duration,
306        peer_identity: Identity<'static>,
307    ) -> Self {
308        Self {
309            ticket,
310            epoch: time_now.as_secs(),
311            lifetime: Ord::min(lifetime, MAX_TICKET_LIFETIME),
312            peer_identity: Arc::new(peer_identity),
313        }
314    }
315
316    pub(crate) fn peer_identity(&self) -> &Identity<'static> {
317        &self.peer_identity
318    }
319
320    pub(crate) fn ticket(&self) -> &[u8] {
321        (*self.ticket).bytes()
322    }
323}
324
325impl<'a> Codec<'a> for ClientSessionCommon {
326    fn encode(&self, bytes: &mut Vec<u8>) {
327        self.ticket.encode(bytes);
328        bytes.extend_from_slice(&self.epoch.to_be_bytes());
329        bytes.extend_from_slice(&self.lifetime.as_secs().to_be_bytes());
330        self.peer_identity.encode(bytes);
331    }
332
333    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
334        Ok(Self {
335            ticket: Arc::new(SizedPayload::read(r)?.into_owned()),
336            epoch: u64::read(r)?,
337            lifetime: Duration::from_secs(u64::read(r)?),
338            peer_identity: Arc::new(Identity::read(r)?.into_owned()),
339        })
340    }
341}
342
343#[derive(Debug)]
344struct ServerCertDetails {
345    cert_chain: CertificateChain<'static>,
346    ocsp_response: Vec<u8>,
347}
348
349impl ServerCertDetails {
350    fn new(cert_chain: CertificateChain<'static>, ocsp_response: Vec<u8>) -> Self {
351        Self {
352            cert_chain,
353            ocsp_response,
354        }
355    }
356}
357
358struct ClientHelloDetails {
359    alpn_protocols: Vec<ApplicationProtocol<'static>>,
360    sent_extensions: Vec<ExtensionType>,
361    extension_order_seed: u16,
362    offered_cert_compression: bool,
363}
364
365impl ClientHelloDetails {
366    fn new(alpn_protocols: Vec<ApplicationProtocol<'static>>, extension_order_seed: u16) -> Self {
367        Self {
368            alpn_protocols,
369            sent_extensions: Vec::new(),
370            extension_order_seed,
371            offered_cert_compression: false,
372        }
373    }
374
375    fn server_sent_unsolicited_extensions(
376        &self,
377        received_exts: &ServerExtensions<'_>,
378        allowed_unsolicited: &[ExtensionType],
379    ) -> bool {
380        let mut extensions = received_exts.collect_used();
381        extensions.extend(
382            received_exts
383                .unknown_extensions
384                .iter()
385                .map(|ext| ExtensionType::from(*ext)),
386        );
387        for ext_type in extensions {
388            if !self.sent_extensions.contains(&ext_type) && !allowed_unsolicited.contains(&ext_type)
389            {
390                trace!("Unsolicited extension {ext_type:?}");
391                return true;
392            }
393        }
394
395        false
396    }
397}
398
399enum ClientAuthDetails {
400    /// Send an empty `Certificate` and no `CertificateVerify`.
401    Empty { auth_context_tls13: Option<Vec<u8>> },
402    /// Send a non-empty `Certificate` and a `CertificateVerify`.
403    Verify {
404        credentials: SelectedCredential,
405        auth_context_tls13: Option<Vec<u8>>,
406        compressor: Option<&'static dyn compress::CertCompressor>,
407    },
408}
409
410impl ClientAuthDetails {
411    fn resolve(
412        negotiated_type: CertificateType,
413        resolver: &dyn ClientCredentialResolver,
414        root_hint_subjects: Option<&[DistinguishedName]>,
415        signature_schemes: &[SignatureScheme],
416        auth_context_tls13: Option<Vec<u8>>,
417        compressor: Option<&'static dyn compress::CertCompressor>,
418    ) -> Self {
419        let server_hello = CredentialRequest {
420            negotiated_type,
421            root_hint_subjects: root_hint_subjects.unwrap_or_default(),
422            signature_schemes,
423        };
424
425        if let Some(credentials) = resolver.resolve(&server_hello) {
426            debug!("Attempting client auth");
427            return Self::Verify {
428                credentials,
429                auth_context_tls13,
430                compressor,
431            };
432        }
433
434        debug!("Client auth requested but no cert/sigscheme available");
435        Self::Empty { auth_context_tls13 }
436    }
437}
438
/// Upper bound on how long a stored session is honored: seven days.
///
/// RFC 8446 section 4.6.1 forbids servers from advertising a longer ticket
/// lifetime and forbids clients from caching tickets for longer than this.
const MAX_TICKET_LIFETIME: Duration = Duration::from_secs(7 * 24 * 60 * 60);