1use alloc::collections::VecDeque;
36use alloc::vec::Vec;
37use core::fmt::Debug;
38use std::sync::Mutex;
39
40use crate::crypto::cipher::Payload;
41use crate::enums::CertificateCompressionAlgorithm;
42use crate::msgs::{CertificatePayloadTls13, Codec, CompressedCertificatePayload, SizedPayload};
43use crate::sync::Arc;
44
45pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
48 &[
49 #[cfg(feature = "brotli")]
50 BROTLI_DECOMPRESSOR,
51 #[cfg(feature = "zlib")]
52 ZLIB_DECOMPRESSOR,
53 ]
54}
55
56pub trait CertDecompressor: Debug + Send + Sync {
58 fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;
65
66 fn algorithm(&self) -> CertificateCompressionAlgorithm;
68}
69
70pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
73 &[
74 #[cfg(feature = "brotli")]
75 BROTLI_COMPRESSOR,
76 #[cfg(feature = "zlib")]
77 ZLIB_COMPRESSOR,
78 ]
79}
80
81pub trait CertCompressor: Debug + Send + Sync {
83 fn compress(
92 &self,
93 input: Vec<u8>,
94 level: CompressionLevel,
95 ) -> Result<Vec<u8>, CompressionFailed>;
96
97 fn algorithm(&self) -> CertificateCompressionAlgorithm;
99}
100
/// How much effort a [`CertCompressor`] should spend on a compression.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Cheaper compression for output used once.
    ///
    /// This module uses it for certificate messages that bypass the cache.
    Interactive,

    /// Stronger, slower compression for output that is kept and reused.
    ///
    /// This module uses it for certificate messages stored in the cache.
    Amortized,
}
115
/// Error returned by a [`CertDecompressor`] when decompression fails.
#[derive(Debug)]
#[expect(clippy::exhaustive_structs)]
pub struct DecompressionFailed;
120
/// Error returned by a [`CertCompressor`] (or the compression cache) when
/// compression fails.
#[derive(Debug)]
#[expect(clippy::exhaustive_structs)]
pub struct CompressionFailed;
125
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::{
        DeflateConfig, InflateConfig, ReturnCode, compress_bound, compress_slice, decompress_slice,
    };

    use super::*;

    /// Certificate decompressor for the zlib algorithm, backed by `zlib-rs`.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        /// Inflates `input` into `output`.
        ///
        /// Succeeds only when `zlib-rs` reports `ReturnCode::Ok` and exactly
        /// `output.len()` bytes were produced.
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let wanted = output.len();
            let (produced, code) = decompress_slice(output, input, InflateConfig::default());
            let exact = code == ReturnCode::Ok && produced.len() == wanted;
            if exact {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    /// Certificate compressor for the zlib algorithm, backed by `zlib-rs`.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            // `Amortized` output is reused, so it gets the strongest setting.
            // `Interactive` uses the library default.
            let config = match level {
                CompressionLevel::Interactive => DeflateConfig::default(),
                CompressionLevel::Amortized => DeflateConfig::best_compression(),
            };

            // Size the buffer with `compress_bound` (the worst-case output size
            // for this input length), then trim it to what was actually written.
            let mut buffer = alloc::vec![0u8; compress_bound(input.len())];
            let written = match compress_slice(&mut buffer, &input, config) {
                (produced, ReturnCode::Ok) => produced.len(),
                _ => return Err(CompressionFailed),
            };
            buffer.truncate(written);
            Ok(buffer)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
186
187#[cfg(feature = "zlib")]
188pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
189
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Internal buffer size handed to `brotli::CompressorWriter`.
    const BUFFER_SIZE: usize = 4096;

    /// Base-2 logarithm of the brotli sliding-window size.
    const LGWIN: u32 = 22;

    /// Brotli quality used for `CompressionLevel::Interactive`.
    const QUALITY_FAST: u32 = 4;

    /// Brotli quality used for `CompressionLevel::Amortized` (brotli's maximum).
    const QUALITY_SLOW: u32 = 11;

    /// Certificate decompressor for the brotli algorithm, backed by the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        /// Decompresses `input` into `output`.
        ///
        /// Fails if the brotli decoder reports an error, or if fewer than
        /// `output.len()` bytes were written.
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let wanted = output.len();
            let mut source = Cursor::new(input);
            // The sink wraps a fixed-size slice and cannot grow, so overlong
            // output is expected to surface as an error from the decoder. The
            // wrong-length tests rely on this.
            let mut sink = Cursor::new(output);

            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            if sink.position() as usize == wanted {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// Certificate compressor for the brotli algorithm, backed by the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let quality = match level {
                CompressionLevel::Interactive => QUALITY_FAST,
                CompressionLevel::Amortized => QUALITY_SLOW,
            };

            // Pre-size for roughly half the input to cut down on reallocations.
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);
            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }

            // `into_inner` hands back the sink. The round-trip tests rely on the
            // brotli stream being complete at this point.
            let finished = writer.into_inner();
            Ok(finished.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }
}
266
267#[cfg(feature = "brotli")]
268pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
269
270#[expect(clippy::exhaustive_enums)]
276#[derive(Debug)]
277pub enum CompressionCache {
278 Disabled,
281
282 Enabled(CompressionCacheInner),
284}
285
/// Storage behind [`CompressionCache::Enabled`].
///
/// Entries are kept in recency order: the front is evicted first, and entries
/// that are hit or newly inserted go to the back.
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries kept; `CompressionCache::new` never creates
    /// this with zero.
    size: usize,

    /// Cached compressions, least-recently-used first.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
299
300impl CompressionCache {
301 pub fn new(size: usize) -> Self {
304 if size == 0 {
305 return Self::Disabled;
306 }
307
308 Self::Enabled(CompressionCacheInner {
309 size,
310 entries: Mutex::new(VecDeque::with_capacity(size)),
311 })
312 }
313
314 pub(crate) fn compression_for(
320 &self,
321 compressor: &dyn CertCompressor,
322 original: &CertificatePayloadTls13<'_>,
323 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
324 match self {
325 Self::Disabled => Self::uncached_compression(compressor, original),
326 Self::Enabled(_) => self.compression_for_impl(compressor, original),
327 }
328 }
329
330 fn compression_for_impl(
331 &self,
332 compressor: &dyn CertCompressor,
333 original: &CertificatePayloadTls13<'_>,
334 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
335 let (max_size, entries) = match self {
336 Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
337 _ => unreachable!(),
338 };
339
340 if !original.context.is_empty() {
343 return Self::uncached_compression(compressor, original);
344 }
345
346 let encoding = original.get_encoding();
348 let algorithm = compressor.algorithm();
349
350 let mut cache = entries
351 .lock()
352 .map_err(|_| CompressionFailed)?;
353 for (i, item) in cache.iter().enumerate() {
354 if item.algorithm == algorithm && item.original == encoding {
355 let item = cache.remove(i).unwrap();
357 cache.push_back(item.clone());
358 return Ok(item);
359 }
360 }
361 drop(cache);
362
363 let uncompressed_len = encoding.len() as u32;
365 let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
366 let new_entry = Arc::new(CompressionCacheEntry {
367 algorithm,
368 original: encoding,
369 compressed: CompressedCertificatePayload {
370 alg: algorithm,
371 uncompressed_len,
372 compressed: SizedPayload::from(Payload::new(compressed)),
373 },
374 });
375
376 let mut cache = entries
378 .lock()
379 .map_err(|_| CompressionFailed)?;
380 if cache.len() == max_size {
381 cache.pop_front();
382 }
383 cache.push_back(new_entry.clone());
384 Ok(new_entry)
385 }
386
387 fn uncached_compression(
389 compressor: &dyn CertCompressor,
390 original: &CertificatePayloadTls13<'_>,
391 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
392 let algorithm = compressor.algorithm();
393 let encoding = original.get_encoding();
394 let uncompressed_len = encoding.len() as u32;
395 let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
396
397 Ok(Arc::new(CompressionCacheEntry {
400 algorithm,
401 original: Vec::new(),
402 compressed: CompressedCertificatePayload {
403 alg: algorithm,
404 uncompressed_len,
405 compressed: SizedPayload::from(Payload::new(compressed)),
406 },
407 }))
408 }
409}
410
411impl Default for CompressionCache {
412 fn default() -> Self {
413 Self::new(4)
415 }
416}
417
/// One compressed certificate message, shared via `Arc` between the cache and
/// its callers.
#[derive(Debug)]
pub(crate) struct CompressionCacheEntry {
    /// Algorithm that produced `compressed`; part of the cache key.
    algorithm: CertificateCompressionAlgorithm,

    /// Uncompressed encoding of the certificate message; part of the cache key.
    /// Left empty for entries created on the uncached path.
    original: Vec<u8>,

    /// The compressed payload returned by `compressed_cert_payload`.
    compressed: CompressedCertificatePayload<'static>,
}
427
428impl CompressionCacheEntry {
429 pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
430 self.compressed.as_borrowed()
431 }
432}
433
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Runs the round-trip and failure-mode checks for one matching
    /// compressor/decompressor pair.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        assert_eq!(comp.algorithm(), decomp.algorithm());
        [16, 64, 512, 2048, 8192, 16384]
            .into_iter()
            .for_each(|sz| test_trivial_pairwise(comp, decomp, sz));
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Compresses `plain_len` zero bytes at every level and checks they
    /// decompress back unchanged.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let plain = vec![0u8; plain_len];

        for level in [CompressionLevel::Interactive, CompressionLevel::Amortized] {
            let packed = comp
                .compress(plain.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                plain.len(),
                packed.len(),
                level
            );
            // Pre-fill with 0xff so untouched output cannot pass as zeroes.
            let mut unpacked = vec![0xffu8; plain_len];
            decomp
                .decompress(&packed, &mut unpacked)
                .unwrap();
            assert_eq!(plain, unpacked);
        }
    }

    /// Decompressing into a buffer one byte too long or too short must fail.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let plain = vec![0u8; 2048];
        let packed = comp
            .compress(plain.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{packed:?}");

        for wrong_len in [plain.len() + 1, plain.len() - 1] {
            let mut unpacked = vec![0xffu8; wrong_len];
            decomp
                .decompress(&packed, &mut unpacked)
                .unwrap_err();
        }
    }

    /// Input that is not a valid compressed stream must be rejected.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut unpacked = [0u8; 512];
        decomp
            .decompress(&junk, &mut unpacked)
            .unwrap_err();
    }

    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        /// Wraps a compressor and records whether `compress` was called.
        /// When dropped, it asserts that the call happened exactly when expected
        /// (field 2).
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }

        let cache = CompressionCache::default();

        let certs = [CertificateDer::from(vec![1])].into_iter();

        let cert1 = CertificatePayloadTls13::new(certs.clone(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new(certs.clone(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new(certs.clone(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new(certs.clone(), Some(b"4"));

        // Each step is (compressor, payload, whether a fresh compression is
        // expected). Cache state is listed least- to most-recently used, using
        // z = zlib and b = brotli, with a capacity of 4.
        let steps = [
            // Fill the cache: [z1, z2, z3, z4]
            (ZLIB_COMPRESSOR, &cert1, true),
            (ZLIB_COMPRESSOR, &cert2, true),
            (ZLIB_COMPRESSOR, &cert3, true),
            (ZLIB_COMPRESSOR, &cert4, true),
            // Same payload with another algorithm misses; z1 is evicted: [z2, z3, z4, b4]
            (BROTLI_COMPRESSOR, &cert4, true),
            // Hits only reorder: ends as [z2, z3, z4, b4]
            (ZLIB_COMPRESSOR, &cert2, false),
            (ZLIB_COMPRESSOR, &cert3, false),
            (ZLIB_COMPRESSOR, &cert4, false),
            (BROTLI_COMPRESSOR, &cert4, false),
            // z1 was evicted, so it is recompressed and evicts z2: [z3, z4, b4, z1]
            (ZLIB_COMPRESSOR, &cert1, true),
            // More hits leave b4 least recent: [b4, z4, z3, z1]
            (ZLIB_COMPRESSOR, &cert4, false),
            (ZLIB_COMPRESSOR, &cert3, false),
            (ZLIB_COMPRESSOR, &cert1, false),
            // b1 is new and evicts b4: [z4, z3, z1, b1]
            (BROTLI_COMPRESSOR, &cert1, true),
            // ...so b4 must be compressed again.
            (BROTLI_COMPRESSOR, &cert4, true),
        ];

        for (compressor, cert, must_compress) in steps {
            // The probe is a temporary, dropped (and checked) at the end of this statement.
            cache
                .compression_for(
                    &RequireCompress(compressor, AtomicBool::default(), must_compress),
                    cert,
                )
                .unwrap();
        }
    }
}