1use pki_types::ServerName;
2
3use crate::enums::SignatureScheme;
4use crate::msgs::persist;
5use crate::sync::Arc;
6use crate::{NamedGroup, client, sign};
7
8#[derive(Debug)]
10pub(super) struct NoClientSessionStorage;
11
12impl client::ClientSessionStore for NoClientSessionStorage {
13 fn set_kx_hint(&self, _: ServerName<'static>, _: NamedGroup) {}
14
15 fn kx_hint(&self, _: &ServerName<'_>) -> Option<NamedGroup> {
16 None
17 }
18
19 fn set_tls12_session(&self, _: ServerName<'static>, _: persist::Tls12ClientSessionValue) {}
20
21 fn tls12_session(&self, _: &ServerName<'_>) -> Option<persist::Tls12ClientSessionValue> {
22 None
23 }
24
25 fn remove_tls12_session(&self, _: &ServerName<'_>) {}
26
27 fn insert_tls13_ticket(&self, _: ServerName<'static>, _: persist::Tls13ClientSessionValue) {}
28
29 fn take_tls13_ticket(&self, _: &ServerName<'_>) -> Option<persist::Tls13ClientSessionValue> {
30 None
31 }
32}
33
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::collections::VecDeque;
    use core::fmt;

    use pki_types::ServerName;

    use crate::lock::Mutex;
    use crate::msgs::persist;
    use crate::{NamedGroup, limited_cache};

    /// Maximum number of TLS 1.3 tickets retained for any single server.
    const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;

    /// Everything remembered about one server.
    struct ServerData {
        /// The group most recently recorded via `set_kx_hint`.
        kx_hint: Option<NamedGroup>,

        /// At most one TLS 1.2 session value.
        tls12: Option<persist::Tls12ClientSessionValue>,

        /// Up to `MAX_TLS13_TICKETS_PER_SERVER` TLS 1.3 tickets: new tickets are
        /// pushed at the back, the oldest is evicted from the front.
        tls13: VecDeque<persist::Tls13ClientSessionValue>,
    }

    impl Default for ServerData {
        fn default() -> Self {
            Self {
                kx_hint: None,
                tls12: None,
                tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
            }
        }
    }

    /// A `ClientSessionStore` that keeps everything in memory.
    ///
    /// Entries are keyed by server name and held in a `LimitedCache`, so the
    /// number of distinct servers remembered is bounded. Each server entry holds a
    /// key-exchange hint, at most one TLS 1.2 session, and at most
    /// `MAX_TLS13_TICKETS_PER_SERVER` TLS 1.3 tickets.
    ///
    /// # Panics
    ///
    /// Store operations panic if acquiring the internal lock fails.
    pub struct ClientSessionMemoryCache {
        servers: Mutex<limited_cache::LimitedCache<ServerName<'static>, ServerData>>,
    }

    /// Converts an overall item budget into a limit on distinct server entries.
    ///
    /// Each server may hold up to `MAX_TLS13_TICKETS_PER_SERVER` tickets, so the
    /// budget is divided by that, rounding up (saturating rather than overflowing).
    fn max_servers_for(size: usize) -> usize {
        size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER
    }

    impl ClientSessionMemoryCache {
        /// Make a new cache.
        ///
        /// `size` is the overall budget of stored items; it is converted into a
        /// limit on the number of servers by dividing by the per-server ticket cap,
        /// rounding up.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Self {
            Self {
                servers: Mutex::new(limited_cache::LimitedCache::new(max_servers_for(size))),
            }
        }

        /// Make a new cache, using `M` to construct the internal mutex.
        ///
        /// `size` is the overall budget of stored items; it is converted into a
        /// limit on the number of servers by dividing by the per-server ticket cap,
        /// rounding up.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Self {
            Self {
                servers: Mutex::new::<M>(limited_cache::LimitedCache::new(max_servers_for(
                    size,
                ))),
            }
        }
    }

    impl super::client::ClientSessionStore for ClientSessionMemoryCache {
        fn set_kx_hint(&self, server_name: ServerName<'static>, group: NamedGroup) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.kx_hint = Some(group));
        }

        fn kx_hint(&self, server_name: &ServerName<'_>) -> Option<NamedGroup> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.kx_hint)
        }

        fn set_tls12_session(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls12ClientSessionValue,
        ) {
            // `server_name` is already owned, so it can be moved in directly.
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| data.tls12 = Some(value));
        }

        fn tls12_session(
            &self,
            server_name: &ServerName<'_>,
        ) -> Option<persist::Tls12ClientSessionValue> {
            self.servers
                .lock()
                .unwrap()
                .get(server_name)
                .and_then(|sd| sd.tls12.clone())
        }

        fn remove_tls12_session(&self, server_name: &ServerName<'static>) {
            // Only clears the TLS 1.2 slot; the server entry (hint, tickets) remains.
            if let Some(data) = self
                .servers
                .lock()
                .unwrap()
                .get_mut(server_name)
            {
                data.tls12 = None;
            }
        }

        fn insert_tls13_ticket(
            &self,
            server_name: ServerName<'static>,
            value: persist::Tls13ClientSessionValue,
        ) {
            self.servers
                .lock()
                .unwrap()
                .get_or_insert_default_and_edit(server_name, |data| {
                    // Enforce the cap against the constant, not `VecDeque::capacity()`:
                    // `with_capacity` only guarantees *at least* the requested space,
                    // so capacity may exceed `MAX_TLS13_TICKETS_PER_SERVER`.
                    while data.tls13.len() >= MAX_TLS13_TICKETS_PER_SERVER {
                        data.tls13.pop_front();
                    }
                    data.tls13.push_back(value);
                });
        }

        fn take_tls13_ticket(
            &self,
            server_name: &ServerName<'static>,
        ) -> Option<persist::Tls13ClientSessionValue> {
            // Hands out the most recently inserted ticket first.
            self.servers
                .lock()
                .unwrap()
                .get_mut(server_name)
                .and_then(|data| data.tls13.pop_back())
        }
    }

    impl fmt::Debug for ClientSessionMemoryCache {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // `servers` is deliberately omitted: it may contain sensitive session data.
            f.debug_struct("ClientSessionMemoryCache")
                .finish()
        }
    }
}
183
184#[cfg(any(feature = "std", feature = "hashbrown"))]
185pub use cache::ClientSessionMemoryCache;
186
187#[derive(Debug)]
188pub(super) struct FailResolveClientCert {}
189
190impl client::ResolvesClientCert for FailResolveClientCert {
191 fn resolve(
192 &self,
193 _root_hint_subjects: &[&[u8]],
194 _sigschemes: &[SignatureScheme],
195 ) -> Option<Arc<sign::CertifiedKey>> {
196 None
197 }
198
199 fn has_certs(&self) -> bool {
200 false
201 }
202}
203
204#[derive(Clone, Debug)]
209pub struct AlwaysResolvesClientRawPublicKeys(Arc<sign::CertifiedKey>);
210impl AlwaysResolvesClientRawPublicKeys {
211 pub fn new(certified_key: Arc<sign::CertifiedKey>) -> Self {
213 Self(certified_key)
214 }
215}
216
217impl client::ResolvesClientCert for AlwaysResolvesClientRawPublicKeys {
218 fn resolve(
219 &self,
220 _root_hint_subjects: &[&[u8]],
221 _sigschemes: &[SignatureScheme],
222 ) -> Option<Arc<sign::CertifiedKey>> {
223 Some(self.0.clone())
224 }
225
226 fn only_raw_public_keys(&self) -> bool {
227 true
228 }
229
230 fn has_certs(&self) -> bool {
235 true
236 }
237}
238
// Unit tests for the helpers in this file. The `test_for_each_provider` attribute
// macro supplies the `super::provider` path used below (to reach `cipher_suite`).
#[cfg(test)]
#[macro_rules_attribute::apply(test_for_each_provider)]
mod tests {
    use std::prelude::v1::*;

    use pki_types::{ServerName, UnixTime};

    use super::NoClientSessionStorage;
    use super::provider::cipher_suite;
    use crate::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use crate::client::{ClientSessionStore, ResolvesClientCert};
    use crate::msgs::base::PayloadU16;
    use crate::msgs::enums::NamedGroup;
    use crate::msgs::handshake::{CertificateChain, SessionId};
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::pki_types::CertificateDer;
    use crate::suites::SupportedCipherSuite;
    use crate::sync::Arc;
    use crate::{DigitallySignedStruct, Error, SignatureScheme, sign};

    // Stores a kx hint, a TLS 1.2 session and a TLS 1.3 ticket into
    // `NoClientSessionStorage`, and checks that every read returns `None`.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let c = NoClientSessionStorage {};
        let name = ServerName::try_from("example.com").unwrap();
        let now = UnixTime::now();
        // Placeholder verifier/resolver: the session-value constructors need them,
        // but nothing in this test should ever call into them.
        let server_cert_verifier: Arc<dyn ServerCertVerifier> = Arc::new(DummyServerCertVerifier);
        let resolves_client_cert: Arc<dyn ResolvesClientCert> = Arc::new(DummyResolvesClientCert);

        c.set_kx_hint(name.clone(), NamedGroup::X25519);
        assert_eq!(None, c.kx_hint(&name));

        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            // The chosen suite constant is expected to be a TLS 1.2 suite.
            let SupportedCipherSuite::Tls12(tls12_suite) =
                cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
            else {
                unreachable!()
            };

            c.set_tls12_session(
                name.clone(),
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Arc::new(PayloadU16::empty()),
                    &[0u8; 48],
                    CertificateChain::default(),
                    &server_cert_verifier,
                    &resolves_client_cert,
                    now,
                    0,
                    true,
                ),
            );
            assert!(c.tls12_session(&name).is_none());
            // Removing from the empty store must simply be a no-op.
            c.remove_tls12_session(&name);
        }

        // The chosen suite constant is expected to be a TLS 1.3 suite.
        let SupportedCipherSuite::Tls13(tls13_suite) = cipher_suite::TLS13_AES_256_GCM_SHA384
        else {
            unreachable!();
        };
        c.insert_tls13_ticket(
            name.clone(),
            Tls13ClientSessionValue::new(
                tls13_suite,
                Arc::new(PayloadU16::empty()),
                &[],
                CertificateChain::default(),
                &server_cert_verifier,
                &resolves_client_cert,
                now,
                0,
                0,
                0,
            ),
        );
        assert!(c.take_tls13_ticket(&name).is_none());
    }

    // A `ServerCertVerifier` whose every method is `unreachable!()`: it exists only
    // to satisfy constructor arguments. Methods are excluded from coverage.
    #[derive(Debug)]
    struct DummyServerCertVerifier;

    impl ServerCertVerifier for DummyServerCertVerifier {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, Error> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn request_ocsp_response(&self) -> bool {
            unreachable!()
        }
    }

    // A `ResolvesClientCert` whose every method is `unreachable!()`: likewise only
    // a placeholder argument. Methods are excluded from coverage.
    #[derive(Debug)]
    struct DummyResolvesClientCert;

    impl ResolvesClientCert for DummyResolvesClientCert {
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn resolve(
            &self,
            _root_hint_subjects: &[&[u8]],
            _sigschemes: &[SignatureScheme],
        ) -> Option<Arc<sign::CertifiedKey>> {
            unreachable!()
        }

        #[cfg_attr(coverage_nightly, coverage(off))]
        fn has_certs(&self) -> bool {
            unreachable!()
        }
    }
}