1#[cfg(feature = "std")]
36use alloc::collections::VecDeque;
37use alloc::vec::Vec;
38use core::fmt::Debug;
39#[cfg(feature = "std")]
40use std::sync::Mutex;
41
42use crate::enums::CertificateCompressionAlgorithm;
43use crate::msgs::base::{Payload, PayloadU24};
44use crate::msgs::codec::Codec;
45use crate::msgs::handshake::{CertificatePayloadTls13, CompressedCertificatePayload};
46use crate::sync::Arc;
47
48pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
51 &[
52 #[cfg(feature = "brotli")]
53 BROTLI_DECOMPRESSOR,
54 #[cfg(feature = "zlib")]
55 ZLIB_DECOMPRESSOR,
56 ]
57}
58
/// A certificate decompression algorithm.
///
/// Implementations must be `Debug + Send + Sync` so they can be shared as
/// `&'static dyn CertDecompressor` (see [`default_cert_decompressors`]).
pub trait CertDecompressor: Debug + Send + Sync {
    /// Decompress `input` into `output`.
    ///
    /// `output` is pre-sized by the caller. The implementations in this file
    /// return [`DecompressionFailed`] unless the decompressed data fills `output`
    /// exactly; a result that is too long or too short is an error. Malformed input
    /// is also reported as [`DecompressionFailed`].
    fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;

    /// The algorithm this decompressor handles.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;
}
72
73pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
76 &[
77 #[cfg(feature = "brotli")]
78 BROTLI_COMPRESSOR,
79 #[cfg(feature = "zlib")]
80 ZLIB_COMPRESSOR,
81 ]
82}
83
/// A certificate compression algorithm.
///
/// Implementations must be `Debug + Send + Sync` so they can be shared as
/// `&'static dyn CertCompressor` (see [`default_cert_compressors`]).
pub trait CertCompressor: Debug + Send + Sync {
    /// Compress `input`, returning the compressed bytes.
    ///
    /// `level` selects the speed/ratio trade-off; see [`CompressionLevel`].
    /// Backend failures are reported as [`CompressionFailed`].
    fn compress(
        &self,
        input: Vec<u8>,
        level: CompressionLevel,
    ) -> Result<Vec<u8>, CompressionFailed>;

    /// The algorithm this compressor implements.
    ///
    /// The compression cache uses this value as part of its lookup key.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;
}
103
/// How much effort a [`CertCompressor`] should spend.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Used for one-off (uncached) compressions.
    ///
    /// The zlib backend uses its default config; the brotli backend uses `QUALITY_FAST`.
    Interactive,

    /// Used when the result is stored in the compression cache and reused.
    ///
    /// The zlib backend uses `Z_BEST_COMPRESSION`; the brotli backend uses `QUALITY_SLOW`.
    Amortized,
}
118
/// Error returned by [`CertDecompressor::decompress`].
// The lint is allowed deliberately: this is a fieldless marker type.
#[allow(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct DecompressionFailed;
123
/// Error returned by [`CertCompressor::compress`] and by the compression cache.
// The lint is allowed deliberately: this is a fieldless marker type.
#[allow(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct CompressionFailed;
128
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::c_api::Z_BEST_COMPRESSION;
    use zlib_rs::{ReturnCode, deflate, inflate};

    use super::*;

    /// Zlib certificate decompressor backed by `zlib-rs`.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Capture the expected size before `output` is lent to the inflater.
            let expected = output.len();
            let (filled, status) =
                inflate::uncompress_slice(output, input, inflate::InflateConfig::default());

            // Success needs both a clean status and an exactly-filled buffer.
            if status == ReturnCode::Ok && filled.len() == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    /// Zlib certificate compressor backed by `zlib-rs`.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let config = match level {
                CompressionLevel::Interactive => deflate::DeflateConfig::default(),
                CompressionLevel::Amortized => deflate::DeflateConfig::new(Z_BEST_COMPRESSION),
            };

            // Size the buffer with `compress_bound` so one call can emit the whole stream.
            let mut buffer = alloc::vec![0u8; deflate::compress_bound(input.len())];
            let written = match deflate::compress_slice(&mut buffer, &input, config) {
                (filled, ReturnCode::Ok) => filled.len(),
                _ => return Err(CompressionFailed),
            };

            // Drop the unused tail of the worst-case buffer.
            buffer.truncate(written);
            Ok(buffer)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
188
189#[cfg(feature = "zlib")]
190pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
191
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Brotli certificate decompressor backed by the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Capture the expected size before `output` is moved into the cursor.
            let expected = output.len();
            let mut source = Cursor::new(input);
            let mut sink = Cursor::new(output);

            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            // A stream that stopped short of filling `output` is also a failure.
            if sink.position() as usize == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// Brotli certificate compressor backed by the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            // Start with a half-size allocation as a guess at the compressed size.
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let quality = match level {
                CompressionLevel::Interactive => QUALITY_FAST,
                CompressionLevel::Amortized => QUALITY_SLOW,
            };

            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);
            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }

            // Unwrap the writer, then the cursor, to recover the compressed bytes.
            let cursor = writer.into_inner();
            Ok(cursor.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// Internal buffer size handed to `brotli::CompressorWriter::new`.
    const BUFFER_SIZE: usize = 4096;

    /// `lgwin` (window) parameter handed to `brotli::CompressorWriter::new`.
    const LGWIN: u32 = 22;

    /// Brotli quality used for [`CompressionLevel::Interactive`].
    const QUALITY_FAST: u32 = 4;

    /// Brotli quality used for [`CompressionLevel::Amortized`].
    // NOTE(review): believed to be brotli's maximum quality (0-11 scale) — confirm.
    const QUALITY_SLOW: u32 = 11;
}
268
269#[cfg(feature = "brotli")]
270pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
271
/// Optional cache of compressed certificate payloads.
///
/// When `Disabled`, every request is compressed afresh at
/// [`CompressionLevel::Interactive`]. When `Enabled` (requires `std`), results
/// are compressed at [`CompressionLevel::Amortized`] and kept in a bounded
/// least-recently-used queue.
// The lint is allowed deliberately; callers are expected to match on both variants.
#[allow(clippy::exhaustive_enums)]
#[derive(Debug)]
pub enum CompressionCache {
    /// No caching.
    Disabled,

    /// Caching enabled; see [`CompressionCacheInner`].
    #[cfg(feature = "std")]
    Enabled(CompressionCacheInner),
}
288
/// State of an enabled [`CompressionCache`].
///
/// Construct via [`CompressionCache::new`]; the fields are private.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries retained (always non-zero when built via `new`).
    size: usize,

    /// Cached entries: least-recently used at the front (evicted first),
    /// most-recently used at the back.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
303
304impl CompressionCache {
305 #[cfg(feature = "std")]
308 pub fn new(size: usize) -> Self {
309 if size == 0 {
310 return Self::Disabled;
311 }
312
313 Self::Enabled(CompressionCacheInner {
314 size,
315 entries: Mutex::new(VecDeque::with_capacity(size)),
316 })
317 }
318
319 pub(crate) fn compression_for(
325 &self,
326 compressor: &dyn CertCompressor,
327 original: &CertificatePayloadTls13<'_>,
328 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
329 match self {
330 Self::Disabled => Self::uncached_compression(compressor, original),
331
332 #[cfg(feature = "std")]
333 Self::Enabled(_) => self.compression_for_impl(compressor, original),
334 }
335 }
336
337 #[cfg(feature = "std")]
338 fn compression_for_impl(
339 &self,
340 compressor: &dyn CertCompressor,
341 original: &CertificatePayloadTls13<'_>,
342 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
343 let (max_size, entries) = match self {
344 Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
345 _ => unreachable!(),
346 };
347
348 if !original.context.0.is_empty() {
351 return Self::uncached_compression(compressor, original);
352 }
353
354 let encoding = original.get_encoding();
356 let algorithm = compressor.algorithm();
357
358 let mut cache = entries
359 .lock()
360 .map_err(|_| CompressionFailed)?;
361 for (i, item) in cache.iter().enumerate() {
362 if item.algorithm == algorithm && item.original == encoding {
363 let item = cache.remove(i).unwrap();
365 cache.push_back(item.clone());
366 return Ok(item);
367 }
368 }
369 drop(cache);
370
371 let uncompressed_len = encoding.len() as u32;
373 let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
374 let new_entry = Arc::new(CompressionCacheEntry {
375 algorithm,
376 original: encoding,
377 compressed: CompressedCertificatePayload {
378 alg: algorithm,
379 uncompressed_len,
380 compressed: PayloadU24(Payload::new(compressed)),
381 },
382 });
383
384 let mut cache = entries
386 .lock()
387 .map_err(|_| CompressionFailed)?;
388 if cache.len() == max_size {
389 cache.pop_front();
390 }
391 cache.push_back(new_entry.clone());
392 Ok(new_entry)
393 }
394
395 fn uncached_compression(
397 compressor: &dyn CertCompressor,
398 original: &CertificatePayloadTls13<'_>,
399 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
400 let algorithm = compressor.algorithm();
401 let encoding = original.get_encoding();
402 let uncompressed_len = encoding.len() as u32;
403 let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
404
405 Ok(Arc::new(CompressionCacheEntry {
408 algorithm,
409 original: Vec::new(),
410 compressed: CompressedCertificatePayload {
411 alg: algorithm,
412 uncompressed_len,
413 compressed: PayloadU24(Payload::new(compressed)),
414 },
415 }))
416 }
417}
418
419impl Default for CompressionCache {
420 fn default() -> Self {
421 #[cfg(feature = "std")]
422 {
423 Self::new(4)
425 }
426
427 #[cfg(not(feature = "std"))]
428 {
429 Self::Disabled
430 }
431 }
432}
433
/// One compressed certificate payload, as produced (and possibly cached) by [`CompressionCache`].
// Without `std` the cache lookup that reads `algorithm` and `original` is
// compiled out, so those fields would otherwise trigger dead-code warnings.
#[cfg_attr(not(feature = "std"), allow(dead_code))]
#[derive(Debug)]
pub(crate) struct CompressionCacheEntry {
    /// Algorithm used; half of the cache key.
    algorithm: CertificateCompressionAlgorithm,

    /// Uncompressed encoding; the other half of the cache key (empty for uncached entries).
    original: Vec<u8>,

    /// The compressed payload itself.
    compressed: CompressedCertificatePayload<'static>,
}
444
445impl CompressionCacheEntry {
446 pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
447 self.compressed.as_borrowed()
448 }
449}
450
// Tests for the compressor backends and the compression cache. Only built when
// at least one backend feature is enabled.
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Runs the shared battery of checks against one compressor/decompressor pair.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        // The pair must agree on which algorithm they implement.
        assert_eq!(comp.algorithm(), decomp.algorithm());
        for sz in [16, 64, 512, 2048, 8192, 16384] {
            test_trivial_pairwise(comp, decomp, sz);
        }
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Round-trips `plain_len` zero bytes at both compression levels.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let original = vec![0u8; plain_len];

        for level in [CompressionLevel::Interactive, CompressionLevel::Amortized] {
            let compressed = comp
                .compress(original.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                original.len(),
                compressed.len(),
                level
            );
            // Pre-fill with 0xff so bytes the decompressor fails to write show up as a mismatch.
            let mut recovered = vec![0xffu8; plain_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap();
            assert_eq!(original, recovered);
        }
    }

    /// Decompressing into an output buffer of the wrong size must fail.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let original = vec![0u8; 2048];
        let compressed = comp
            .compress(original.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        // One byte longer than the real plaintext: the output is not filled.
        let mut recovered = vec![0xffu8; original.len() + 1];
        decomp
            .decompress(&compressed, &mut recovered)
            .unwrap_err();

        // One byte shorter than the real plaintext: the output overflows.
        let mut recovered = vec![0xffu8; original.len() - 1];
        decomp
            .decompress(&compressed, &mut recovered)
            .unwrap_err();
    }

    /// 1 KiB of zero bytes is expected to be rejected as input.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut recovered = vec![0u8; 512];
        decomp
            .decompress(&junk, &mut recovered)
            .unwrap_err();
    }

    /// Walks the default cache (capacity 4) through a sequence of hits and misses.
    ///
    /// Each `RequireCompress` asserts on drop whether `compress` was called, so
    /// `true` marks an expected miss and `false` an expected hit. Comments show
    /// the queue from LRU (left) to MRU (right); zN/bN = zlib/brotli for certN.
    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        let cache = CompressionCache::default();

        let cert = CertificateDer::from(vec![1]);

        // Four distinct payloads (differing only in the second constructor argument),
        // all with an empty context so they are cacheable.
        let cert1 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new([&cert].into_iter(), Some(b"4"));

        // Fill the cache: z1 z2 z3 z4.
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert2,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // Same payload, different algorithm: a miss that evicts z1 -> z2 z3 z4 b4.
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // All four are hits; each moves to MRU, ending as z2 z3 z4 b4.
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert2,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();

        // z1 was evicted earlier: a miss that evicts z2 -> z3 z4 b4 z1.
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();

        // Hits that reorder the queue to b4 z4 z3 z1, leaving b4 as LRU.
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert1,
            )
            .unwrap();

        // Miss that evicts the LRU b4 -> z4 z3 z1 b1.
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();

        // b4 was just evicted, so it must be recompressed.
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        /// Wraps a real compressor.
        ///
        /// Field 1 records whether `compress` was called; field 2 is the expected
        /// value, checked on drop.
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }
    }
}