1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4use std::sync::{RwLock, RwLockReadGuard};
5
6use pki_types::UnixTime;
7
8use crate::Error;
9use crate::server::ProducesTickets;
10
/// Mutable state of a `TicketRotator`, stored behind its `RwLock`.
#[derive(Debug)]
pub(crate) struct TicketRotatorState {
    /// Ticketer used for every new encryption; also the first one tried when
    /// decrypting.
    current: Box<dyn ProducesTickets>,
    /// The ticketer that was `current` before the most recent rotation. It is
    /// kept only as a decryption fallback. `None` until the first rotation.
    previous: Option<Box<dyn ProducesTickets>>,
    /// Unix time in seconds. A roll happens on the first `maybe_roll` call
    /// whose `now` is strictly greater than this value.
    next_switch_time: u64,
}
17
/// A [`ProducesTickets`] implementation that periodically replaces its
/// underlying ticketer with a freshly generated one.
///
/// New tickets are always encrypted with the current ticketer. The ticketer
/// it replaced is retained until the following rotation so that tickets it
/// issued can still be decrypted. Rotation is lazy: it is checked on each
/// encrypt/decrypt call rather than by a background timer.
#[cfg(feature = "std")]
pub struct TicketRotator {
    /// Produces each ticketer. It is called once by `new`, and again every
    /// time a rotation is attempted.
    pub(crate) generator: fn() -> Result<Box<dyn ProducesTickets>, Error>,
    /// Rotation interval in seconds. The same value is reported by
    /// `ProducesTickets::lifetime`.
    lifetime: u32,
    /// Current and previous ticketers plus the next rotation deadline.
    state: RwLock<TicketRotatorState>,
}
27
#[cfg(feature = "std")]
impl TicketRotator {
    /// Builds a rotator whose first ticketer is produced immediately by
    /// `generator`.
    ///
    /// `lifetime` is the rotation interval in seconds. The first rotation
    /// becomes due `lifetime` seconds after construction. Each later rotation
    /// schedules the next one `lifetime` seconds after the time it happened.
    /// `generator` is invoked again whenever a rotation is attempted.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the initial `generator()` call.
    pub fn new(
        lifetime: u32,
        generator: fn() -> Result<Box<dyn ProducesTickets>, Error>,
    ) -> Result<Self, Error> {
        let first_ticketer = generator()?;
        let first_deadline = UnixTime::now()
            .as_secs()
            .saturating_add(u64::from(lifetime));

        let initial_state = TicketRotatorState {
            current: first_ticketer,
            previous: None,
            next_switch_time: first_deadline,
        };

        Ok(Self {
            generator,
            lifetime,
            state: RwLock::new(initial_state),
        })
    }

    /// Rotates ticketers if the switch time has passed, then returns a read
    /// guard over the (possibly updated) state.
    ///
    /// A rotation works as follows:
    /// - `current` is demoted to `previous`, dropping the old `previous`.
    /// - A freshly generated ticketer becomes `current`.
    /// - The next switch is scheduled `lifetime` seconds after `now`.
    ///
    /// Rotation only happens when this method is called, so key replacement is
    /// driven by encrypt/decrypt traffic.
    ///
    /// Returns `None` in two cases:
    /// - A lock acquisition fails (the lock is poisoned).
    /// - A rotation is due and the generator returns an error.
    pub(crate) fn maybe_roll(
        &self,
        now: UnixTime,
    ) -> Option<RwLockReadGuard<'_, TicketRotatorState>> {
        let now_secs = now.as_secs();

        // Common case: no rotation is due, so the shared lock is sufficient.
        let guard = self.state.read().ok()?;
        if now_secs <= guard.next_switch_time {
            return Some(guard);
        }
        drop(guard);

        // Build the replacement before taking the exclusive lock, so readers
        // are not blocked while the generator runs.
        let fresh = (self.generator)().ok()?;

        let mut exclusive = self.state.write().ok()?;
        // Re-check under the write lock: another thread may have rotated while
        // this one was generating. In that case `fresh` is simply discarded.
        if now_secs > exclusive.next_switch_time {
            let retired = mem::replace(&mut exclusive.current, fresh);
            exclusive.previous = Some(retired);
            exclusive.next_switch_time = now_secs.saturating_add(u64::from(self.lifetime));
        }
        drop(exclusive);

        // Release the exclusive lock and hand the caller a shared one.
        self.state.read().ok()
    }

    /// Six hours, expressed in seconds.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    pub(crate) const SIX_HOURS: u32 = 6 * 60 * 60;
}
109
// Gated like the `TicketRotator` struct itself: without `std` the type does not
// exist and an ungated impl would fail to compile.
#[cfg(feature = "std")]
impl ProducesTickets for TicketRotator {
    /// Reports the configured rotation interval, in seconds.
    fn lifetime(&self) -> u32 {
        self.lifetime
    }

    /// Unconditionally reports that ticketing is enabled.
    fn enabled(&self) -> bool {
        true
    }

    /// Encrypts `message` with the current ticketer, rotating first if due.
    ///
    /// Returns `None` in two cases:
    /// - `maybe_roll` fails (poisoned lock, or generator failure when a
    ///   rotation is due).
    /// - The current ticketer fails to encrypt.
    fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
        let state = self.maybe_roll(UnixTime::now())?;
        state.current.encrypt(message)
    }

    /// Decrypts `ciphertext`, rotating first if due.
    ///
    /// Tries the current ticketer, then falls back to the previous one (if any).
    /// Returns `None` if `maybe_roll` fails or neither ticketer can decrypt.
    fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
        let state = self.maybe_roll(UnixTime::now())?;

        // Tickets issued before the most recent rotation were encrypted by what
        // is now `previous`, so it is consulted when `current` cannot decrypt.
        state.current.decrypt(ciphertext).or_else(|| {
            state
                .previous
                .as_ref()
                .and_then(|previous| previous.decrypt(ciphertext))
        })
    }
}
140
// Gated like the `TicketRotator` struct itself so the impl never refers to a
// type that was compiled out.
#[cfg(feature = "std")]
impl core::fmt::Debug for TicketRotator {
    /// Prints only the type name followed by `..`.
    ///
    /// All fields are omitted. Presumably this avoids taking the state lock and
    /// exposing ticketer internals; confirm before adding fields here.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("TicketRotator")
            .finish_non_exhaustive()
    }
}