1#[cfg(feature = "std")]
36use alloc::collections::VecDeque;
37use alloc::vec::Vec;
38use core::fmt::Debug;
39#[cfg(feature = "std")]
40use std::sync::Mutex;
41
42use crate::crypto::cipher::Payload;
43use crate::enums::CertificateCompressionAlgorithm;
44use crate::msgs::{CertificatePayloadTls13, Codec, CompressedCertificatePayload, SizedPayload};
45use crate::sync::Arc;
46
47pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
50 &[
51 #[cfg(feature = "brotli")]
52 BROTLI_DECOMPRESSOR,
53 #[cfg(feature = "zlib")]
54 ZLIB_DECOMPRESSOR,
55 ]
56}
57
58pub trait CertDecompressor: Debug + Send + Sync {
60 fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;
67
68 fn algorithm(&self) -> CertificateCompressionAlgorithm;
70}
71
72pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
75 &[
76 #[cfg(feature = "brotli")]
77 BROTLI_COMPRESSOR,
78 #[cfg(feature = "zlib")]
79 ZLIB_COMPRESSOR,
80 ]
81}
82
83pub trait CertCompressor: Debug + Send + Sync {
85 fn compress(
94 &self,
95 input: Vec<u8>,
96 level: CompressionLevel,
97 ) -> Result<Vec<u8>, CompressionFailed>;
98
99 fn algorithm(&self) -> CertificateCompressionAlgorithm;
101}
102
/// How much effort a [`CertCompressor`] is asked to spend on one compression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionLevel {
    /// Requested for compressions whose result is not stored in the cache.
    ///
    /// The built-in compressors map this to their default/faster settings.
    Interactive,

    /// Requested for compressions whose result is stored in the cache.
    ///
    /// The built-in compressors map this to their strongest settings.
    Amortized,
}
117
/// Error returned by [`CertDecompressor::decompress`] when decompression fails.
///
/// Carries no further detail about the cause.
#[expect(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct DecompressionFailed;
122
/// Error returned by [`CertCompressor::compress`] when compression fails.
///
/// Carries no further detail about the cause.
#[expect(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct CompressionFailed;
127
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::{
        DeflateConfig, InflateConfig, ReturnCode, compress_bound, compress_slice, decompress_slice,
    };

    use super::*;

    /// A [`CertDecompressor`] for the Zlib algorithm, implemented with the `zlib_rs` crate.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    /// Stateless zlib decompressor behind [`ZLIB_DECOMPRESSOR`].
    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        /// Inflates `input` into `output`, succeeding only if `zlib_rs`
        /// reports `ReturnCode::Ok` and `output` was filled completely.
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Record the capacity first: `output` is handed over to zlib below.
            let expected = output.len();
            let (filled, rc) = decompress_slice(output, input, InflateConfig::default());

            if rc == ReturnCode::Ok && filled.len() == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    /// A [`CertCompressor`] for the Zlib algorithm, implemented with the `zlib_rs` crate.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    /// Stateless zlib compressor behind [`ZLIB_COMPRESSOR`].
    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        /// Deflates `input`, using the default configuration for
        /// `Interactive` and the best-compression configuration for `Amortized`.
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let config = match level {
                CompressionLevel::Interactive => DeflateConfig::default(),
                CompressionLevel::Amortized => DeflateConfig::best_compression(),
            };

            // Size the buffer with `compress_bound`, then trim it to what was written.
            let mut buffer = alloc::vec![0u8; compress_bound(input.len())];
            let written = match compress_slice(&mut buffer, &input, config) {
                (filled, ReturnCode::Ok) => filled.len(),
                _ => return Err(CompressionFailed),
            };

            buffer.truncate(written);
            Ok(buffer)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
188
189#[cfg(feature = "zlib")]
190pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
191
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Buffer size argument passed to `brotli::CompressorWriter::new`.
    const BUFFER_SIZE: usize = 4096;

    /// `lgwin` argument passed to `brotli::CompressorWriter::new`.
    // NOTE(review): presumably the log2 window size; confirm against the brotli docs.
    const LGWIN: u32 = 22;

    /// Quality used for [`CompressionLevel::Interactive`] compressions.
    const QUALITY_FAST: u32 = 4;

    /// Quality used for [`CompressionLevel::Amortized`] compressions.
    const QUALITY_SLOW: u32 = 11;

    /// A [`CertDecompressor`] for the Brotli algorithm, implemented with the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    /// Stateless Brotli decompressor behind [`BROTLI_DECOMPRESSOR`].
    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        /// Decompresses `input` into `output`, succeeding only if the
        /// decoder reports success and `output` was filled completely.
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Capture the capacity before `output` moves into the cursor.
            let expected = output.len();
            let mut source = Cursor::new(input);
            let mut sink = Cursor::new(output);

            if brotli::BrotliDecompress(&mut source, &mut sink).is_err() {
                return Err(DecompressionFailed);
            }

            // The cursor position is the number of bytes written.
            if sink.position() as usize == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// A [`CertCompressor`] for the Brotli algorithm, implemented with the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    /// Stateless Brotli compressor behind [`BROTLI_COMPRESSOR`].
    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        /// Compresses `input` at `QUALITY_FAST` for `Interactive` and
        /// `QUALITY_SLOW` for `Amortized`.
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let quality = match level {
                CompressionLevel::Interactive => QUALITY_FAST,
                CompressionLevel::Amortized => QUALITY_SLOW,
            };

            // Start with room for half the input; the vector grows if needed.
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);
            if writer.write_all(&input).is_err() {
                return Err(CompressionFailed);
            }

            // Unwrap the writer, then the cursor, to recover the output vector.
            let sink = writer.into_inner();
            Ok(sink.into_inner())
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }
}
268
269#[cfg(feature = "brotli")]
270pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
271
/// A cache of compressed certificate messages.
///
/// When enabled, entries are kept in least-recently-used order and are
/// compressed with [`CompressionLevel::Amortized`].
#[derive(Debug)]
#[expect(clippy::exhaustive_enums)]
pub enum CompressionCache {
    /// Nothing is stored; every request compresses afresh using
    /// [`CompressionLevel::Interactive`].
    Disabled,

    /// Compressions are stored and reused across requests.
    #[cfg(feature = "std")]
    Enabled(CompressionCacheInner),
}
288
/// State of an enabled [`CompressionCache`].
///
/// The fields are private, so values are built through
/// [`CompressionCache::new`].
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries kept in `entries`.
    size: usize,

    /// Cached compressions: the front is the next to be evicted, the back is
    /// the most recently inserted or reused.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
303
304impl CompressionCache {
305 #[cfg(feature = "std")]
308 pub fn new(size: usize) -> Self {
309 if size == 0 {
310 return Self::Disabled;
311 }
312
313 Self::Enabled(CompressionCacheInner {
314 size,
315 entries: Mutex::new(VecDeque::with_capacity(size)),
316 })
317 }
318
319 pub(crate) fn compression_for(
325 &self,
326 compressor: &dyn CertCompressor,
327 original: &CertificatePayloadTls13<'_>,
328 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
329 match self {
330 Self::Disabled => Self::uncached_compression(compressor, original),
331
332 #[cfg(feature = "std")]
333 Self::Enabled(_) => self.compression_for_impl(compressor, original),
334 }
335 }
336
337 #[cfg(feature = "std")]
338 fn compression_for_impl(
339 &self,
340 compressor: &dyn CertCompressor,
341 original: &CertificatePayloadTls13<'_>,
342 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
343 let (max_size, entries) = match self {
344 Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
345 _ => unreachable!(),
346 };
347
348 if !original.context.is_empty() {
351 return Self::uncached_compression(compressor, original);
352 }
353
354 let encoding = original.get_encoding();
356 let algorithm = compressor.algorithm();
357
358 let mut cache = entries
359 .lock()
360 .map_err(|_| CompressionFailed)?;
361 for (i, item) in cache.iter().enumerate() {
362 if item.algorithm == algorithm && item.original == encoding {
363 let item = cache.remove(i).unwrap();
365 cache.push_back(item.clone());
366 return Ok(item);
367 }
368 }
369 drop(cache);
370
371 let uncompressed_len = encoding.len() as u32;
373 let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
374 let new_entry = Arc::new(CompressionCacheEntry {
375 algorithm,
376 original: encoding,
377 compressed: CompressedCertificatePayload {
378 alg: algorithm,
379 uncompressed_len,
380 compressed: SizedPayload::from(Payload::new(compressed)),
381 },
382 });
383
384 let mut cache = entries
386 .lock()
387 .map_err(|_| CompressionFailed)?;
388 if cache.len() == max_size {
389 cache.pop_front();
390 }
391 cache.push_back(new_entry.clone());
392 Ok(new_entry)
393 }
394
395 fn uncached_compression(
397 compressor: &dyn CertCompressor,
398 original: &CertificatePayloadTls13<'_>,
399 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
400 let algorithm = compressor.algorithm();
401 let encoding = original.get_encoding();
402 let uncompressed_len = encoding.len() as u32;
403 let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
404
405 Ok(Arc::new(CompressionCacheEntry {
408 algorithm,
409 original: Vec::new(),
410 compressed: CompressedCertificatePayload {
411 alg: algorithm,
412 uncompressed_len,
413 compressed: SizedPayload::from(Payload::new(compressed)),
414 },
415 }))
416 }
417}
418
419#[cfg_attr(not(feature = "std"), expect(clippy::derivable_impls))]
420impl Default for CompressionCache {
421 fn default() -> Self {
422 #[cfg(feature = "std")]
423 {
424 Self::new(4)
426 }
427
428 #[cfg(not(feature = "std"))]
429 {
430 Self::Disabled
431 }
432 }
433}
434
435#[cfg_attr(not(feature = "std"), expect(dead_code))]
436#[derive(Debug)]
437pub(crate) struct CompressionCacheEntry {
438 algorithm: CertificateCompressionAlgorithm,
440 original: Vec<u8>,
441
442 compressed: CompressedCertificatePayload<'static>,
444}
445
446impl CompressionCacheEntry {
447 pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
448 self.compressed.as_borrowed()
449 }
450}
451
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    /// Runs the shared round-trip and failure checks against a matching
    /// compressor/decompressor pair.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        const PLAIN_LENGTHS: [usize; 6] = [16, 64, 512, 2048, 8192, 16384];

        assert_eq!(comp.algorithm(), decomp.algorithm());
        for &plain_len in &PLAIN_LENGTHS {
            test_trivial_pairwise(comp, decomp, plain_len);
        }
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    /// Compresses `plain_len` zero bytes at each level and checks that
    /// decompression recovers them exactly.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let original = vec![0u8; plain_len];
        let levels = [CompressionLevel::Interactive, CompressionLevel::Amortized];

        for level in levels {
            let compressed = comp
                .compress(original.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                original.len(),
                compressed.len(),
                level
            );

            // Pre-fill with 0xff so untouched output bytes would be detected.
            let mut recovered = vec![0xffu8; plain_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap();
            assert_eq!(original, recovered);
        }
    }

    /// Decompression must fail when the output buffer is one byte too long
    /// or one byte too short.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let original = vec![0u8; 2048];
        let compressed = comp
            .compress(original.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        // Too long first, then too short.
        for wrong_len in [original.len() + 1, original.len() - 1] {
            let mut recovered = vec![0xffu8; wrong_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap_err();
        }
    }

    /// Decompressing 1024 zero bytes into a 512-byte buffer must fail.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut recovered = vec![0u8; 512];
        decomp
            .decompress(&junk, &mut recovered)
            .unwrap_err();
    }

    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        /// Wraps a compressor and, when dropped, asserts that `compress` was
        /// (field 2 is `true`) or was not (`false`) called.
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }

        // A miss must invoke the compressor; a hit must not.
        const MISS: bool = true;
        const HIT: bool = false;

        let cache = CompressionCache::default();

        let certs = [CertificateDer::from(vec![1])].into_iter();

        let cert1 = CertificatePayloadTls13::new(certs.clone(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new(certs.clone(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new(certs.clone(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new(certs.clone(), Some(b"4"));

        let zlib = ZLIB_COMPRESSOR;
        let brotli = BROTLI_COMPRESSOR;

        let script = [
            // Four misses fill the cache.
            (zlib, &cert1, MISS),
            (zlib, &cert2, MISS),
            (zlib, &cert3, MISS),
            (zlib, &cert4, MISS),
            // Fifth miss evicts cert1/zlib.
            (brotli, &cert4, MISS),
            // Everything still resident is a hit.
            (zlib, &cert2, HIT),
            (zlib, &cert3, HIT),
            (zlib, &cert4, HIT),
            (brotli, &cert4, HIT),
            // cert1/zlib was evicted, so it misses (and evicts cert2/zlib).
            (zlib, &cert1, MISS),
            // These hits leave cert4/brotli as the least recently used.
            (zlib, &cert4, HIT),
            (zlib, &cert3, HIT),
            (zlib, &cert1, HIT),
            // A miss, which evicts cert4/brotli.
            (brotli, &cert1, MISS),
            // Confirms that cert4/brotli is gone.
            (brotli, &cert4, MISS),
        ];

        for (compressor, cert, must_compress) in script {
            // `probe` is dropped at the end of each iteration, which runs its assertion.
            let probe = RequireCompress(compressor, AtomicBool::default(), must_compress);
            cache
                .compression_for(&probe, cert)
                .unwrap();
        }
    }
}