1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4use core::time::Duration;
5use std::sync::{RwLock, RwLockReadGuard};
6
7use pki_types::UnixTime;
8
9use crate::crypto::TicketProducer;
10use crate::error::Error;
11
/// A [`TicketProducer`] that periodically replaces its underlying producer.
///
/// The rotator holds a *current* generation and at most one *previous* generation.
/// Each generation expires `lifetime` after it was created. The first encrypt or
/// decrypt call after expiry asks `generator` for a replacement, and the expired
/// generation is demoted to "previous". A previous generation stays usable for up
/// to `lifetime` past its own expiry.
#[cfg(feature = "std")]
pub struct TicketRotator {
    /// Produces the ticket producer for each new generation.
    pub(crate) generator: fn() -> Result<Box<dyn TicketProducer>, Error>,
    /// How long a generation stays current. It also bounds how long a previous
    /// generation remains usable after it expires.
    lifetime: Duration,
    /// The current and previous generations, guarded for concurrent access.
    state: RwLock<TicketRotatorState>,
}
21
#[cfg(feature = "std")]
impl TicketRotator {
    /// Creates a rotator whose first generation comes from calling `generator` now.
    ///
    /// The first generation expires `lifetime` (whole seconds) from the current
    /// wall-clock time. Later generations are requested from `generator` lazily, on
    /// the first encrypt/decrypt after the current one expires.
    ///
    /// # Errors
    ///
    /// Returns the error from `generator` if the initial producer cannot be created.
    pub fn new(
        lifetime: Duration,
        generator: fn() -> Result<Box<dyn TicketProducer>, Error>,
    ) -> Result<Self, Error> {
        Ok(Self {
            generator,
            lifetime,
            state: RwLock::new(TicketRotatorState {
                current: Some(Generation {
                    producer: generator()?,
                    expires_at: UnixTime::now()
                        .as_secs()
                        .saturating_add(lifetime.as_secs()),
                }),
                previous: None,
            }),
        })
    }

    /// Encrypts `message` as of `now`, rotating first if the current generation
    /// has expired.
    ///
    /// Uses the current generation when there is one. If there is none (the last
    /// attempt to create a replacement failed), this falls back to the previous
    /// generation while it is inside its grace period. Otherwise it returns `None`.
    fn encrypt_at(&self, message: &[u8], now: UnixTime) -> Option<Vec<u8>> {
        let state = self.maybe_roll(now)?;

        // `maybe_roll` only hands back a current generation that is unexpired or
        // freshly created, so no expiry check is needed here.
        if let Some(current) = &state.current {
            return current.producer.encrypt(message);
        }

        self.usable_previous(&state, now)?
            .producer
            .encrypt(message)
    }

    /// Decrypts `ciphertext` as of `now`, rotating first if the current generation
    /// has expired.
    ///
    /// Tries the current generation first. If that fails or there is no current
    /// generation, this tries the previous one while it is inside its grace period.
    fn decrypt_at(&self, ciphertext: &[u8], now: UnixTime) -> Option<Vec<u8>> {
        let state = self.maybe_roll(now)?;

        state
            .current
            .as_ref()
            .and_then(|current| current.producer.decrypt(ciphertext))
            .or_else(|| {
                // Tickets issued before the last rotation can only be opened by
                // the previous generation.
                self.usable_previous(&state, now)
                    .and_then(|prev| prev.producer.decrypt(ciphertext))
            })
    }

    /// Returns the previous generation if there is one and it is still within its
    /// grace period at `now`.
    fn usable_previous<'a>(
        &self,
        state: &'a TicketRotatorState,
        now: UnixTime,
    ) -> Option<&'a Generation> {
        state
            .previous
            .as_ref()
            .filter(|prev| prev.in_grace_period(now, self.lifetime))
    }

    /// Returns a read guard over the rotation state. If the current generation is
    /// missing or expired as of `now`, it is first replaced with a new generation
    /// from `self.generator`.
    ///
    /// If the generator fails, `current` becomes `None`. The outgoing generation,
    /// if any, is kept as `previous`. The next call retries the generator.
    ///
    /// Returns `None` only if the state lock is poisoned.
    pub(crate) fn maybe_roll(
        &self,
        now: UnixTime,
    ) -> Option<RwLockReadGuard<'_, TicketRotatorState>> {
        let now = now.as_secs();

        // Fast, read-only path: the current generation is still valid.
        {
            let read = self.state.read().ok()?;
            if matches!(&read.current, Some(current) if now <= current.expires_at) {
                return Some(read);
            }
        }

        // Build the candidate replacement *before* taking the write lock. The
        // generator is opaque and may be slow. Running it under the lock would
        // stall every concurrent encrypt/decrypt. A panic inside it would also
        // poison `state`, and every later call would then return `None`.
        let next = (self.generator)()
            .ok()
            .map(|producer| Generation {
                producer,
                expires_at: now.saturating_add(self.lifetime.as_secs()),
            });

        let mut write = self.state.write().ok()?;

        // Another thread may have rotated while we were generating. If so, its
        // generation wins and ours is simply dropped.
        if matches!(&write.current, Some(current) if now <= current.expires_at) {
            drop(write);
            return self.state.read().ok();
        }

        // Install the new generation (or `None` if the generator failed). Only
        // demote the outgoing generation when there was one. After an earlier
        // failure `current` is already `None`, and overwriting `previous` then
        // would discard the generation that may still be in its grace period.
        if let Some(outgoing) = mem::replace(&mut write.current, next) {
            write.previous = Some(outgoing);
        }
        drop(write);

        self.state.read().ok()
    }

    /// Six hours, expressed as a [`Duration`].
    #[cfg(feature = "aws-lc-rs")]
    pub(crate) const SIX_HOURS: Duration = Duration::from_secs(6 * 60 * 60);
}
159
160impl TicketProducer for TicketRotator {
161 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
162 self.encrypt_at(message, UnixTime::now())
163 }
164
165 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
166 self.decrypt_at(ciphertext, UnixTime::now())
167 }
168
169 fn lifetime(&self) -> Duration {
170 self.lifetime
171 }
172}
173
174impl core::fmt::Debug for TicketRotator {
175 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
176 f.debug_struct("TicketRotator")
177 .finish_non_exhaustive()
178 }
179}
180
181#[derive(Debug)]
182pub(crate) struct TicketRotatorState {
183 current: Option<Generation>,
184 previous: Option<Generation>,
185}
186
187#[derive(Debug)]
188struct Generation {
189 producer: Box<dyn TicketProducer>,
190 expires_at: u64,
191}
192
193impl Generation {
194 fn in_grace_period(&self, now: UnixTime, lifetime: Duration) -> bool {
195 now.as_secs()
196 .saturating_sub(self.expires_at)
197 <= lifetime.as_secs()
198 }
199}
200
#[cfg(test)]
mod tests {
    use core::sync::atomic::{AtomicU8, Ordering};
    use core::time::Duration;

    use pki_types::UnixTime;

    use super::*;

    #[test]
    fn ticketrotator_switching_test() {
        let t = TicketRotator::new(Duration::from_secs(1), FakeTicketer::new).unwrap();
        let now = UnixTime::now();
        let cipher1 = t.encrypt(b"ticket 1").unwrap();
        assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
        {
            // Well past expiry: the first generation is demoted to `previous`.
            t.maybe_roll(UnixTime::since_unix_epoch(Duration::from_secs(
                now.as_secs() + 10,
            )));
        }
        let cipher2 = t.encrypt(b"ticket 2").unwrap();
        assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
        assert_eq!(t.decrypt(&cipher2).unwrap(), b"ticket 2");
        {
            // Roll again: the first generation is discarded entirely.
            t.maybe_roll(UnixTime::since_unix_epoch(Duration::from_secs(
                now.as_secs() + 20,
            )));
        }
        let cipher3 = t.encrypt(b"ticket 3").unwrap();
        assert!(t.decrypt(&cipher1).is_none());
        assert_eq!(t.decrypt(&cipher2).unwrap(), b"ticket 2");
        assert_eq!(t.decrypt(&cipher3).unwrap(), b"ticket 3");
    }

    #[test]
    fn ticketrotator_remains_usable_over_temporary_ticketer_creation_failure() {
        let mut t = TicketRotator::new(Duration::from_secs(1), FakeTicketer::new).unwrap();

        // Anchor every timestamp on the first generation's actual expiry, not on a
        // separately sampled `UnixTime::now()`. If the wall clock ticked over a
        // second between `new()` and that sample, offsets from it would be one
        // second late, and the grace-period assertions below would fail.
        let expiry = t
            .state
            .read()
            .unwrap()
            .current
            .as_ref()
            .unwrap()
            .expires_at;
        let at = move |offset: u64| UnixTime::since_unix_epoch(Duration::from_secs(expiry + offset));

        let cipher1 = t.encrypt(b"ticket 1").unwrap();
        assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
        t.generator = fail_generator;

        // Exactly at expiry the first generation is still current: no rotation.
        let t1 = at(0);
        drop(t.maybe_roll(t1));
        assert!(t.encrypt_at(b"ticket 2", t1).is_some());

        // One second past expiry: rotation fails, so the first generation is
        // demoted and keeps serving from its grace period.
        let t2 = at(1);
        let cipher3 = t.encrypt_at(b"ticket 3", t2).unwrap();
        assert_eq!(t.decrypt_at(&cipher1, t2).unwrap(), b"ticket 1");
        assert_eq!(t.decrypt_at(&cipher3, t2).unwrap(), b"ticket 3");

        // Grace period over and still no replacement: nothing works.
        let t3 = at(2);
        assert_eq!(t.encrypt_at(b"ticket 4", t3), None);
        assert_eq!(t.decrypt_at(&cipher3, t3), None);

        // The generator recovers and a fresh generation is installed.
        t.generator = FakeTicketer::new;
        let t4 = at(3);
        drop(t.maybe_roll(t4));

        let t5 = at(4);
        let cipher5 = t.encrypt_at(b"ticket 5", t5).unwrap();
        assert!(t.decrypt_at(&cipher1, t5).is_none());
        assert!(t.decrypt_at(&cipher3, t5).is_none());
        assert_eq!(t.decrypt_at(&cipher5, t5).unwrap(), b"ticket 5");

        // With no generations at all and a failing generator, encryption fails.
        t.generator = fail_generator;
        let mut write = t.state.write().unwrap();
        write.current = None;
        write.previous = None;
        drop(write);
        assert!(t.encrypt(b"ticket 6").is_none());
    }

    /// Toy producer: XORs with a per-instance id and prefixes that id so only the
    /// same instance can decrypt.
    #[derive(Debug)]
    struct FakeTicketer {
        // Not named `gen`: that is a reserved keyword in the 2024 edition.
        generation: u8,
    }

    impl FakeTicketer {
        #[expect(clippy::new_ret_no_self)]
        fn new() -> Result<Box<dyn TicketProducer>, Error> {
            Ok(Box::new(Self {
                generation: FAKE_GEN.fetch_add(1, Ordering::SeqCst),
            }))
        }
    }

    impl TicketProducer for FakeTicketer {
        fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
            let mut v = Vec::with_capacity(1 + message.len());
            v.push(self.generation);
            v.extend(
                message
                    .iter()
                    .copied()
                    .map(|b| b ^ self.generation),
            );
            Some(v)
        }

        fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
            if ciphertext.first()? != &self.generation {
                return None;
            }

            Some(
                ciphertext[1..]
                    .iter()
                    .copied()
                    .map(|b| b ^ self.generation)
                    .collect(),
            )
        }

        fn lifetime(&self) -> Duration {
            Duration::ZERO
        }
    }

    /// Shared across tests so every fake producer gets a distinct id.
    static FAKE_GEN: AtomicU8 = AtomicU8::new(0);

    fn fail_generator() -> Result<Box<dyn TicketProducer>, Error> {
        Err(Error::FailedToGetRandomBytes)
    }
}