1#[cfg(feature = "std")]
36use alloc::collections::VecDeque;
37use alloc::vec::Vec;
38use core::fmt::Debug;
39#[cfg(feature = "std")]
40use std::sync::Mutex;
41
42use crate::crypto::cipher::Payload;
43use crate::enums::CertificateCompressionAlgorithm;
44use crate::msgs::base::PayloadU24;
45use crate::msgs::codec::Codec;
46use crate::msgs::handshake::{CertificatePayloadTls13, CompressedCertificatePayload};
47use crate::sync::Arc;
48
49pub fn default_cert_decompressors() -> &'static [&'static dyn CertDecompressor] {
52 &[
53 #[cfg(feature = "brotli")]
54 BROTLI_DECOMPRESSOR,
55 #[cfg(feature = "zlib")]
56 ZLIB_DECOMPRESSOR,
57 ]
58}
59
/// A certificate decompression algorithm, usable as a trait object.
pub trait CertDecompressor: Debug + Send + Sync {
    /// Decompress `input`, writing the result into `output`.
    ///
    /// `output` is a caller-supplied, fixed-size buffer.  The implementations
    /// in this file succeed only when the decompressed data fills `output`
    /// exactly, and return `Err(DecompressionFailed)` in every other case.
    fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed>;

    /// Which compression algorithm this decompressor handles.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;
}
73
74pub fn default_cert_compressors() -> &'static [&'static dyn CertCompressor] {
77 &[
78 #[cfg(feature = "brotli")]
79 BROTLI_COMPRESSOR,
80 #[cfg(feature = "zlib")]
81 ZLIB_COMPRESSOR,
82 ]
83}
84
/// A certificate compression algorithm, usable as a trait object.
pub trait CertCompressor: Debug + Send + Sync {
    /// Compress `input` and return the compressed bytes.
    ///
    /// `input` is taken by value.  `level` selects the speed/size trade-off;
    /// see [`CompressionLevel`].
    ///
    /// Returns `Err(CompressionFailed)` if compression did not succeed.
    fn compress(
        &self,
        input: Vec<u8>,
        level: CompressionLevel,
    ) -> Result<Vec<u8>, CompressionFailed>;

    /// Which compression algorithm this compressor implements.
    fn algorithm(&self) -> CertificateCompressionAlgorithm;
}
104
/// How much effort a [`CertCompressor`] should spend on one compression.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CompressionLevel {
    /// Requested by `CompressionCache` when the result is not cached.
    ///
    /// The compressors in this file map this to their faster setting.
    Interactive,

    /// Requested by `CompressionCache` when the result is stored in the cache.
    ///
    /// The compressors in this file map this to their strongest setting.
    Amortized,
}
119
/// Error returned by [`CertDecompressor::decompress`].
///
/// Carries no further detail about the cause.
#[expect(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct DecompressionFailed;
124
/// Error returned by [`CertCompressor::compress`] and by the compression cache.
///
/// Carries no further detail about the cause.
#[expect(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct CompressionFailed;
129
#[cfg(feature = "zlib")]
mod feat_zlib_rs {
    use zlib_rs::c_api::Z_BEST_COMPRESSION;
    use zlib_rs::{ReturnCode, deflate, inflate};

    use super::*;

    /// A certificate decompressor for the Zlib algorithm, built on `zlib_rs`.
    pub const ZLIB_DECOMPRESSOR: &dyn CertDecompressor = &ZlibRsDecompressor;

    #[derive(Debug)]
    struct ZlibRsDecompressor;

    impl CertDecompressor for ZlibRsDecompressor {
        /// Inflates `input` into `output`, succeeding only when inflation
        /// reports `ReturnCode::Ok` and `output` was filled completely.
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Record the buffer size up front: `output` is mutably borrowed below.
            let expected = output.len();
            let (written, status) =
                inflate::uncompress_slice(output, input, inflate::InflateConfig::default());

            if status == ReturnCode::Ok && written.len() == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }

    /// A certificate compressor for the Zlib algorithm, built on `zlib_rs`.
    pub const ZLIB_COMPRESSOR: &dyn CertCompressor = &ZlibRsCompressor;

    #[derive(Debug)]
    struct ZlibRsCompressor;

    impl CertCompressor for ZlibRsCompressor {
        /// Deflates `input` into a freshly allocated buffer.
        ///
        /// `Interactive` uses the default deflate configuration, `Amortized`
        /// uses `Z_BEST_COMPRESSION`.
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let config = match level {
                CompressionLevel::Interactive => deflate::DeflateConfig::default(),
                CompressionLevel::Amortized => deflate::DeflateConfig::new(Z_BEST_COMPRESSION),
            };

            // Size the buffer with `compress_bound`, then trim it to what was written.
            let mut buffer = alloc::vec![0u8; deflate::compress_bound(input.len())];
            let written = match deflate::compress_slice(&mut buffer, &input, config) {
                (filled, ReturnCode::Ok) => filled.len(),
                (_, _) => return Err(CompressionFailed),
            };

            buffer.truncate(written);
            Ok(buffer)
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Zlib
        }
    }
}
189
190#[cfg(feature = "zlib")]
191pub use feat_zlib_rs::{ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR};
192
#[cfg(feature = "brotli")]
mod feat_brotli {
    use std::io::{Cursor, Write};

    use super::*;

    /// Buffer size handed to `brotli::CompressorWriter`.
    const BUFFER_SIZE: usize = 4096;

    /// Window size parameter (`lgwin`) handed to `brotli::CompressorWriter`.
    const LGWIN: u32 = 22;

    /// Brotli quality used for `CompressionLevel::Interactive`.
    const QUALITY_FAST: u32 = 4;

    /// Brotli quality used for `CompressionLevel::Amortized`.
    const QUALITY_SLOW: u32 = 11;

    /// A certificate decompressor for the brotli algorithm, built on the `brotli` crate.
    pub const BROTLI_DECOMPRESSOR: &dyn CertDecompressor = &BrotliDecompressor;

    #[derive(Debug)]
    struct BrotliDecompressor;

    impl CertDecompressor for BrotliDecompressor {
        /// Decompresses `input` into `output`, succeeding only when the
        /// decoder reports success and `output` was filled completely.
        fn decompress(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecompressionFailed> {
            // Record the buffer size before the cursor takes over `output`.
            let expected = output.len();
            let mut reader = Cursor::new(input);
            let mut writer = Cursor::new(output);

            if brotli::BrotliDecompress(&mut reader, &mut writer).is_err() {
                return Err(DecompressionFailed);
            }

            // The cursor position is the number of bytes written.
            let written = writer.position() as usize;
            if written == expected {
                Ok(())
            } else {
                Err(DecompressionFailed)
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }

    /// A certificate compressor for the brotli algorithm, built on the `brotli` crate.
    pub const BROTLI_COMPRESSOR: &dyn CertCompressor = &BrotliCompressor;

    #[derive(Debug)]
    struct BrotliCompressor;

    impl CertCompressor for BrotliCompressor {
        /// Compresses `input` with a quality chosen from `level`:
        /// `QUALITY_FAST` for `Interactive`, `QUALITY_SLOW` for `Amortized`.
        fn compress(
            &self,
            input: Vec<u8>,
            level: CompressionLevel,
        ) -> Result<Vec<u8>, CompressionFailed> {
            let quality = match level {
                CompressionLevel::Interactive => QUALITY_FAST,
                CompressionLevel::Amortized => QUALITY_SLOW,
            };

            // Start the output at half the input size as an initial capacity guess.
            let sink = Cursor::new(Vec::with_capacity(input.len() / 2));
            let mut writer = brotli::CompressorWriter::new(sink, BUFFER_SIZE, quality, LGWIN);

            match writer.write_all(&input) {
                // Unwrap the writer, then the cursor, to recover the byte vector.
                Ok(()) => Ok(writer.into_inner().into_inner()),
                Err(_) => Err(CompressionFailed),
            }
        }

        fn algorithm(&self) -> CertificateCompressionAlgorithm {
            CertificateCompressionAlgorithm::Brotli
        }
    }
}
269
270#[cfg(feature = "brotli")]
271pub use feat_brotli::{BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR};
272
/// A cache of compressed certificate messages.
///
/// `compression_for` consults this to avoid compressing the same certificate
/// payload with the same algorithm more than once.
#[expect(clippy::exhaustive_enums)]
#[derive(Debug)]
pub enum CompressionCache {
    /// Nothing is cached: every request is compressed afresh.
    Disabled,

    /// Compressions are stored in the contained bounded cache.
    ///
    /// Only available with the `std` crate feature.
    #[cfg(feature = "std")]
    Enabled(CompressionCacheInner),
}
289
/// Storage behind [`CompressionCache::Enabled`].
///
/// Only available with the `std` crate feature.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct CompressionCacheInner {
    /// Maximum number of entries kept in `entries`.
    size: usize,

    /// Cached compressions, ordered by recency of use: a cache hit moves the
    /// entry to the back, and eviction removes from the front.
    entries: Mutex<VecDeque<Arc<CompressionCacheEntry>>>,
}
304
305impl CompressionCache {
306 #[cfg(feature = "std")]
309 pub fn new(size: usize) -> Self {
310 if size == 0 {
311 return Self::Disabled;
312 }
313
314 Self::Enabled(CompressionCacheInner {
315 size,
316 entries: Mutex::new(VecDeque::with_capacity(size)),
317 })
318 }
319
320 pub(crate) fn compression_for(
326 &self,
327 compressor: &dyn CertCompressor,
328 original: &CertificatePayloadTls13<'_>,
329 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
330 match self {
331 Self::Disabled => Self::uncached_compression(compressor, original),
332
333 #[cfg(feature = "std")]
334 Self::Enabled(_) => self.compression_for_impl(compressor, original),
335 }
336 }
337
338 #[cfg(feature = "std")]
339 fn compression_for_impl(
340 &self,
341 compressor: &dyn CertCompressor,
342 original: &CertificatePayloadTls13<'_>,
343 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
344 let (max_size, entries) = match self {
345 Self::Enabled(CompressionCacheInner { size, entries }) => (*size, entries),
346 _ => unreachable!(),
347 };
348
349 if !original.context.0.is_empty() {
352 return Self::uncached_compression(compressor, original);
353 }
354
355 let encoding = original.get_encoding();
357 let algorithm = compressor.algorithm();
358
359 let mut cache = entries
360 .lock()
361 .map_err(|_| CompressionFailed)?;
362 for (i, item) in cache.iter().enumerate() {
363 if item.algorithm == algorithm && item.original == encoding {
364 let item = cache.remove(i).unwrap();
366 cache.push_back(item.clone());
367 return Ok(item);
368 }
369 }
370 drop(cache);
371
372 let uncompressed_len = encoding.len() as u32;
374 let compressed = compressor.compress(encoding.clone(), CompressionLevel::Amortized)?;
375 let new_entry = Arc::new(CompressionCacheEntry {
376 algorithm,
377 original: encoding,
378 compressed: CompressedCertificatePayload {
379 alg: algorithm,
380 uncompressed_len,
381 compressed: PayloadU24::from(Payload::new(compressed)),
382 },
383 });
384
385 let mut cache = entries
387 .lock()
388 .map_err(|_| CompressionFailed)?;
389 if cache.len() == max_size {
390 cache.pop_front();
391 }
392 cache.push_back(new_entry.clone());
393 Ok(new_entry)
394 }
395
396 fn uncached_compression(
398 compressor: &dyn CertCompressor,
399 original: &CertificatePayloadTls13<'_>,
400 ) -> Result<Arc<CompressionCacheEntry>, CompressionFailed> {
401 let algorithm = compressor.algorithm();
402 let encoding = original.get_encoding();
403 let uncompressed_len = encoding.len() as u32;
404 let compressed = compressor.compress(encoding, CompressionLevel::Interactive)?;
405
406 Ok(Arc::new(CompressionCacheEntry {
409 algorithm,
410 original: Vec::new(),
411 compressed: CompressedCertificatePayload {
412 alg: algorithm,
413 uncompressed_len,
414 compressed: PayloadU24::from(Payload::new(compressed)),
415 },
416 }))
417 }
418}
419
420impl Default for CompressionCache {
421 fn default() -> Self {
422 #[cfg(feature = "std")]
423 {
424 Self::new(4)
426 }
427
428 #[cfg(not(feature = "std"))]
429 {
430 Self::Disabled
431 }
432 }
433}
434
/// One compressed certificate message, plus the data used to look it up in
/// a [`CompressionCache`].
#[cfg_attr(not(feature = "std"), expect(dead_code))]
#[derive(Debug)]
pub(crate) struct CompressionCacheEntry {
    /// Cache key, part 1: the algorithm that produced `compressed`.
    algorithm: CertificateCompressionAlgorithm,
    /// Cache key, part 2: the uncompressed encoding of the certificate
    /// payload.  Left empty for entries produced outside the cache.
    original: Vec<u8>,

    /// Cache value: the compressed certificate message.
    compressed: CompressedCertificatePayload<'static>,
}
445
446impl CompressionCacheEntry {
447 pub(crate) fn compressed_cert_payload(&self) -> CompressedCertificatePayload<'_> {
448 self.compressed.as_borrowed()
449 }
450}
451
// Round-trip and failure tests for the built-in compressors, plus an
// eviction-order test for `CompressionCache`.
#[cfg(all(test, any(feature = "brotli", feature = "zlib")))]
mod tests {
    use std::{println, vec};

    use super::*;

    #[test]
    #[cfg(feature = "zlib")]
    fn test_zlib() {
        test_compressor(ZLIB_COMPRESSOR, ZLIB_DECOMPRESSOR);
    }

    #[test]
    #[cfg(feature = "brotli")]
    fn test_brotli() {
        test_compressor(BROTLI_COMPRESSOR, BROTLI_DECOMPRESSOR);
    }

    // Runs the shared checks against one compressor/decompressor pair.
    fn test_compressor(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        assert_eq!(comp.algorithm(), decomp.algorithm());
        for sz in [16, 64, 512, 2048, 8192, 16384] {
            test_trivial_pairwise(comp, decomp, sz);
        }
        test_decompress_wrong_len(comp, decomp);
        test_decompress_garbage(decomp);
    }

    // Compresses `plain_len` zero bytes at each level and checks that
    // decompression recovers them.
    fn test_trivial_pairwise(
        comp: &dyn CertCompressor,
        decomp: &dyn CertDecompressor,
        plain_len: usize,
    ) {
        let original = vec![0u8; plain_len];

        for level in [CompressionLevel::Interactive, CompressionLevel::Amortized] {
            let compressed = comp
                .compress(original.clone(), level)
                .unwrap();
            println!(
                "{:?} compressed trivial {} -> {} using {:?} level",
                comp.algorithm(),
                original.len(),
                compressed.len(),
                level
            );
            // Pre-fill with 0xff so an untouched buffer cannot pass the comparison.
            let mut recovered = vec![0xffu8; plain_len];
            decomp
                .decompress(&compressed, &mut recovered)
                .unwrap();
            assert_eq!(original, recovered);
        }
    }

    // Decompression must fail when the output buffer is not exactly the
    // size of the original data.
    fn test_decompress_wrong_len(comp: &dyn CertCompressor, decomp: &dyn CertDecompressor) {
        let original = vec![0u8; 2048];
        let compressed = comp
            .compress(original.clone(), CompressionLevel::Interactive)
            .unwrap();
        println!("{compressed:?}");

        // output buffer one byte too large
        let mut recovered = vec![0xffu8; original.len() + 1];
        decomp
            .decompress(&compressed, &mut recovered)
            .unwrap_err();

        // output buffer one byte too small
        let mut recovered = vec![0xffu8; original.len() - 1];
        decomp
            .decompress(&compressed, &mut recovered)
            .unwrap_err();
    }

    // Decompressing 1024 zero bytes into a 512-byte buffer must fail.
    fn test_decompress_garbage(decomp: &dyn CertDecompressor) {
        let junk = [0u8; 1024];
        let mut recovered = vec![0u8; 512];
        decomp
            .decompress(&junk, &mut recovered)
            .unwrap_err();
    }

    // Checks that the cache evicts the least-recently-used entry.
    //
    // Each lookup goes through a temporary `RequireCompress` wrapper
    // (defined at the end of this function).  Its third field states whether
    // the lookup is expected to invoke the compressor (`true`, a cache miss)
    // or not (`false`, a cache hit); the wrapper's `Drop` impl asserts this
    // when the temporary is dropped at the end of each statement.
    #[test]
    #[cfg(all(feature = "brotli", feature = "zlib"))]
    fn test_cache_evicts_lru() {
        use core::sync::atomic::{AtomicBool, Ordering};

        use pki_types::CertificateDer;

        let cache = CompressionCache::default();

        let certs = [CertificateDer::from(vec![1])].into_iter();

        // Four payloads that differ only in the second argument to `new`.
        let cert1 = CertificatePayloadTls13::new(certs.clone(), Some(b"1"));
        let cert2 = CertificatePayloadTls13::new(certs.clone(), Some(b"2"));
        let cert3 = CertificatePayloadTls13::new(certs.clone(), Some(b"3"));
        let cert4 = CertificatePayloadTls13::new(certs.clone(), Some(b"4"));

        // insert zlib (1), (2), (3), (4): all misses
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert2,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // the cache is now full
        // insert brotli (4): a miss, which evicts zlib (1)
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // zlib (2), (3), (4) and brotli (4) are all hits
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert2,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();

        // zlib (1) was evicted, so this is a miss; inserting it evicts zlib (2)
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();

        // cache now holds zlib (3), (4), brotli (4), zlib (1)
        // hit zlib (4), (3), (1) so that brotli (4) becomes least-recently-used
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert4,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert3,
            )
            .unwrap();
        cache
            .compression_for(
                &RequireCompress(ZLIB_COMPRESSOR, AtomicBool::default(), false),
                &cert1,
            )
            .unwrap();

        // order is now brotli (4), zlib (4), (3), (1)
        // insert brotli (1): a miss, which evicts brotli (4)
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert1,
            )
            .unwrap();

        // brotli (4) must now be a miss, showing it was evicted
        cache
            .compression_for(
                &RequireCompress(BROTLI_COMPRESSOR, AtomicBool::default(), true),
                &cert4,
            )
            .unwrap();

        // Wraps a compressor (field 0), records whether `compress` was called
        // (field 1), and holds the expected value of that flag (field 2).
        #[derive(Debug)]
        struct RequireCompress(&'static dyn CertCompressor, AtomicBool, bool);

        impl CertCompressor for RequireCompress {
            fn compress(
                &self,
                input: Vec<u8>,
                level: CompressionLevel,
            ) -> Result<Vec<u8>, CompressionFailed> {
                self.1.store(true, Ordering::SeqCst);
                self.0.compress(input, level)
            }

            fn algorithm(&self) -> CertificateCompressionAlgorithm {
                self.0.algorithm()
            }
        }

        impl Drop for RequireCompress {
            // Panics if the recorded "was called" flag differs from the expectation.
            fn drop(&mut self) {
                assert_eq!(self.1.load(Ordering::SeqCst), self.2);
            }
        }
    }
}