1use alloc::collections::VecDeque;
36use alloc::vec::Vec;
37use core::fmt::Debug;
38use std::sync::Mutex;
39
40use crate::crypto::cipher::Payload;
41use crate::enums::CertificateCompressionAlgorithm;
42use crate::msgs::{CertificatePayloadTls13, Codec, CompressedCertificatePayload, SizedPayload};
43use crate::sync::Arc;
44
45pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
48 &[
49 #[cfg(feature = "brotli")]
50 BROTLI_DECOMPRESSOR,
51 #[cfg(feature = "zlib")]
52 ZLIB_DECOMPRESSOR,
53 ]
54}
55
56pub trait CertDecompressor: Debug + Send + Sync {
58 fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;
65
66 fn algorithm(&self) -> CertificateCompressionAlgorithm;
68}
69
70pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
73 &[
74 #[cfg(feature = "brotli")]
75 BROTLI_COMPRESSOR,
76 #[cfg(feature = "zlib")]
77 ZLIB_COMPRESSOR,
78 ]
79}
80
81pub trait CertCompressor: Debug + Send + Sync {
83 fn compress(
92 &self,
93 input: Vec<u8>,
94 level: CompressionLevel,
95 ) -> Result<Vec<u8>, CompressionFailed>;
96
97 fn algorithm(&self) -> CertificateCompressionAlgorithm;
99}
100
/// How much effort a [`CertCompressor`] should put into compression.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompressionLevel {
    /// Used for output that is produced for a single use and not cached.
    Interactive,

    /// Used for output that is cached and reused, so a higher one-time cost
    /// is spread over many uses.
    Amortized,
}
115
/// Error returned by [`CertDecompressor::decompress`].
#[expect(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct DecompressionFailed;

/// Error returned by [`CertCompressor::compress`].
#[expect(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct CompressionFailed;
125
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::{
        DeflateConfig, InflateConfig, ReturnCode, compress_bound, compress_slice, decompress_slice,
    };

    use super::*;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let expected_len = output.len();
            let (filled, rc) = decompress_slice(output, input, InflateConfig::default());

            // Success means the stream decoded cleanly AND filled `output`
            // exactly; a shorter stream leaves `filled` smaller than `output`.
            if rc != ReturnCode::Ok || filled.len() != expected_len {
                return Err(DecompressionFailed);
            }
            Ok(())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let config = match level {
                CompressionLevel::Interactive => DeflateConfig::default(),
                CompressionLevel::Amortized => DeflateConfig::best_compression(),
            };

            // Allocate the worst-case output size, then trim to what was written.
            let mut buffer = alloc::vec![0u8; compress_bound(input.len())];
            let written = match compress_slice(&mut buffer, &input, config) {
                (filled, ReturnCode::Ok) => filled.len(),
                _ => return Err(CompressionFailed),
            };
            buffer.truncate(written);
            Ok(buffer)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    /// A zlib certificate decompressor backed by the `zlib-rs` crate.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    /// A zlib certificate compressor backed by the `zlib-rs` crate.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;
}

#[cfg(feature = "zlib")]
pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
189
// The `brotli` crate streams through `std::io::{Read, Write}`, which have no
// `core` counterparts, so this module has to name `std` items directly.
#[allow(clippy::std_instead_of_core)]
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Internal buffer size handed to `brotli::CompressorWriter`.
    const BUFFER_SIZE: usize = 4096;

    /// Window-size parameter (`lgwin`) handed to `brotli::CompressorWriter`.
    const LGWIN: u32 = 22;

    /// Brotli quality used for [`CompressionLevel::Interactive`].
    const QUALITY_FAST: u32 = 4;

    /// Brotli quality used for [`CompressionLevel::Amortized`].
    const QUALITY_SLOW: u32 = 11;

    /// Map the requested effort level onto a Brotli quality setting.
    fn quality_for(level: CompressionLevel) -> u32 {
        match level {
            CompressionLevel::Interactive => QUALITY_FAST,
            CompressionLevel::Amortized => QUALITY_SLOW,
        }
    }

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            let expected_len = output.len();
            let mut source = Cursor::new(input);
            let mut sink = Cursor::new(output);

            // A stream longer than `output` cannot be written into the fixed
            // slice; the wrong-length tests rely on that surfacing as an error here.
            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            // A stream shorter than `output` leaves its tail unwritten.
            let written = sink.position() as usize;
            if written == expected_len {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer =
                brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality_for(level), LGWIN);

            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }

            // `into_inner` hands back the cursor once the writer is done; the
            // round-trip tests depend on the stream being complete at this point.
            let cursor = writer.into_inner();
            Ok(cursor.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// A Brotli certificate decompressor backed by the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    /// A Brotli certificate compressor backed by the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;
}

#[cfg(feature = "brotli")]
pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
270
271#[expect(clippy::exhaustive_enums)]
277#[derive(Debug)]
278pub enum CompressionCache {
279 Disabled,
282
283 Enabled(CompressionCacheInner),
285}
286
287#[derive(Debug)]
291pub struct CompressionCacheInner {
292 size: usize,
294
295 entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
299}
300
301impl CompressionCache {
302 pub fn new(size: usize) -> Self {
305 if size == 0 {
306 return Self::Disabled;
307 }
308
309 Self::Enabled(CompressionCacheInner {
310 size,
311 entries: Mutex::new(VecDeque::with_capacity(size)),
312 })
313 }
314
315 pub(crate) fn compression_for(
321 &self,
322 compressor: &dyn CertCompressor,
323 original: &CertificatePayloadTls13<'_>,
324 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
325 match self {
326 Self::Disabled => Self::uncached_compression(compressor, original),
327 Self::Enabled(_) => self.compression_for_impl(compressor, original),
328 }
329 }
330
331 fn compression_for_impl(
332 &self,
333 compressor: &dyn CertCompressor,
334 original: &CertificatePayloadTls13<'_>,
335 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
336 let (max_size, entries) = match self {
337 Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
338 _ => unreachable!(),
339 };
340
341 if !original.context.is_empty() {
344 return Self::uncached_compression(compressor, original);
345 }
346
347 let encoding = original.get_encoding();
349 let algorithm = compressor.algorithm();
350
351 let mut cache = entries
352 .lock()
353 .map_err(|_| CompressionFailed)?;
354 for (i, item) in cache.iter().enumerate() {
355 if item.algorithm == algorithm && item.original == encoding {
356 let item = cache.remove(i).unwrap();
358 cache.push_back(item.clone());
359 return Ok(item);
360 }
361 }
362 drop(cache);
363
364 let uncompressed_len = encoding.len() as u32;
366 let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
367 let new_entry = Arc::new(CompressionCacheEntry {
368 algorithm,
369 original: encoding,
370 compressed: CompressedCertificatePayload {
371 alg: algorithm,
372 uncompressed_len,
373 compressed: SizedPayload::from(Payload::new(compressed)),
374 },
375 });
376
377 let mut cache = entries
379 .lock()
380 .map_err(|_| CompressionFailed)?;
381 if cache.len() == max_size {
382 cache.pop_front();
383 }
384 cache.push_back(new_entry.clone());
385 Ok(new_entry)
386 }
387
388 fn uncached_compression(
390 compressor: &dyn CertCompressor,
391 original: &CertificatePayloadTls13<'_>,
392 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
393 let algorithm = compressor.algorithm();
394 let encoding = original.get_encoding();
395 let uncompressed_len = encoding.len() as u32;
396 let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
397
398 Ok(Arc::new(CompressionCacheEntry {
401 algorithm,
402 original: Vec::new(),
403 compressed: CompressedCertificatePayload {
404 alg: algorithm,
405 uncompressed_len,
406 compressed: SizedPayload::from(Payload::new(compressed)),
407 },
408 }))
409 }
410}
411
412impl Default for CompressionCache {
413 fn default() -> Self {
414 Self::new(4)
416 }
417}
418
419#[derive(Debug)]
420pub(crate) struct CompressionCacheEntry {
421 algorithm: CertificateCompressionAlgorithm,
423 original: Vec<u8>,
424
425 compressed: CompressedCertificatePayload<'static>,
427}
428
429impl CompressionCacheEntry {
430 pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
431 self.compressed.as_borrowed()
432 }
433}
434
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Run every compressor/decompressor check against one algorithm pair.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        assert_eq!(comp.algorithm(), decomp.algorithm());
        [16, 64, 512, 2048, 8192, 16384]
            .into_iter()
            .for_each(|sz| test_trivial_pairwise(comp, decomp, sz));
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Round-trip `plain_len` zero bytes at both compression levels.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let original = vec![0u8; plain_len];

        for level in [CompressionLevel::Interactive, CompressionLevel::Amortized] {
            let compressed = comp
                .compress(original.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                original.len(),
                compressed.len(),
                level
            );

            // Pre-fill with 0xff so output the decompressor failed to write is noticed.
            let mut recovered = vec![0xffu8; plain_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap();
            assert_eq!(original, recovered);
        }
    }

    /// Output buffers one byte too long or too short must both be rejected.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let original = vec![0u8; 2048];
        let compressed = comp
            .compress(original.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        for wrong_len in [original.len() + 1, original.len() - 1] {
            let mut recovered = vec![0xffu8; wrong_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap_err();
        }
    }

    /// Input that is not a valid compressed stream must be rejected.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut recovered = vec![0u8; 512];
        decomp
            .decompress(&junk, &mut recovered)
            .unwrap_err();
    }

    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        /// Ask `cache` for `payload` compressed by `inner`, asserting (when the
        /// probe is dropped) whether the compressor actually ran.
        fn check(
            cache: &CompressionCache,
            inner: &'static dyn CertCompressor,
            payload: &CertificatePayloadTls13<'_>,
            expect_compress: bool,
        ) {
            let probe = RequireCompress(inner, AtomicBool::default(), expect_compress);
            cache
                .compression_for(&probe, payload)
                .unwrap();
        }

        let cache = CompressionCache::default();

        let certs = [CertificateDer::from(vec![1])].into_iter();

        let cert1 = CertificatePayloadTls13::new(certs.clone(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new(certs.clone(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new(certs.clone(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new(certs.clone(), Some(b"4"));

        // Fill the four-entry cache with zlib compressions.
        check(&cache, ZLIB_COMPRESSOR, &cert1, true);
        check(&cache, ZLIB_COMPRESSOR, &cert2, true);
        check(&cache, ZLIB_COMPRESSOR, &cert3, true);
        check(&cache, ZLIB_COMPRESSOR, &cert4, true);

        // brotli/cert4 is a new entry and evicts zlib/cert1.
        check(&cache, BROTLI_COMPRESSOR, &cert4, true);

        // The remaining entries are all hits.
        check(&cache, ZLIB_COMPRESSOR, &cert2, false);
        check(&cache, ZLIB_COMPRESSOR, &cert3, false);
        check(&cache, ZLIB_COMPRESSOR, &cert4, false);
        check(&cache, BROTLI_COMPRESSOR, &cert4, false);

        // zlib/cert1 was evicted, so it is compressed again (evicting zlib/cert2).
        check(&cache, ZLIB_COMPRESSOR, &cert1, true);

        // Touch the zlib entries so brotli/cert4 becomes least-recently used.
        check(&cache, ZLIB_COMPRESSOR, &cert4, false);
        check(&cache, ZLIB_COMPRESSOR, &cert3, false);
        check(&cache, ZLIB_COMPRESSOR, &cert1, false);

        // brotli/cert1 is new and evicts brotli/cert4 ...
        check(&cache, BROTLI_COMPRESSOR, &cert1, true);

        // ... so brotli/cert4 has to be compressed again.
        check(&cache, BROTLI_COMPRESSOR, &cert4, true);

        /// Wraps a compressor, records whether `compress` was called, and
        /// asserts on drop that this matches the expectation in field `2`.
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }
    }
}