1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4use core::time::Duration;
5use std::sync::{RwLock, RwLockReadGuard};
6
7use pki_types::UnixTime;
8
9use crate::crypto::TicketProducer;
10use crate::error::Error;
11
/// A [`TicketProducer`] that periodically replaces the producer it delegates to.
///
/// It keeps a current generation and the most recently retired one. A fresh
/// generation is obtained from `generator` whenever the current one has expired.
#[cfg(feature = "std")]
pub struct TicketRotator {
    /// Factory called once at construction and again on each rotation.
    pub(crate) generator: fn() -> Result<Box<dyn TicketProducer>, Error>,
    /// How long each generation stays current; also used as the grace window
    /// during which the retired generation is still consulted.
    lifetime: Duration,
    /// Current and previous generations, shared between readers and the
    /// thread that performs a rotation.
    state: RwLock<TicketRotatorState>,
}
21
#[cfg(feature = "std")]
impl TicketRotator {
    /// Builds a rotator and eagerly creates its first generation.
    ///
    /// `generator` is invoked once here and again whenever a rotation is due.
    /// `lifetime` (counted in whole seconds) is both how long a generation stays
    /// current and how long a retired generation is still honoured afterwards.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the initial call to `generator`.
    pub fn new(
        lifetime: Duration,
        generator: fn() -> Result<Box<dyn TicketProducer>, Error>,
    ) -> Result<Self, Error> {
        let producer = generator()?;
        let expires_at = UnixTime::now()
            .as_secs()
            .saturating_add(lifetime.as_secs());

        let initial_state = TicketRotatorState {
            current: Some(Generation {
                producer,
                expires_at,
            }),
            previous: None,
        };

        Ok(Self {
            generator,
            lifetime,
            state: RwLock::new(initial_state),
        })
    }

    /// Encrypts `message` as of time `now`, rotating first if needed.
    ///
    /// Uses the current generation when there is one. If the latest rotation
    /// failed to produce a generation, the retired one is used instead, but
    /// only while it is still within its grace period.
    fn encrypt_at(&self, message: &[u8], now: UnixTime) -> Option<Vec<u8>> {
        let state = self.maybe_roll(now)?;
        match &state.current {
            Some(current) => current.producer.encrypt(message),
            None => self
                .previous_in_grace(&state, now)?
                .producer
                .encrypt(message),
        }
    }

    /// Decrypts `ciphertext` as of time `now`, rotating first if needed.
    ///
    /// Tries the current generation first. If that is absent or rejects the
    /// ciphertext, falls back to the retired generation while it is within its
    /// grace period.
    fn decrypt_at(&self, ciphertext: &[u8], now: UnixTime) -> Option<Vec<u8>> {
        let state = self.maybe_roll(now)?;
        state
            .current
            .as_ref()
            .and_then(|current| current.producer.decrypt(ciphertext))
            .or_else(|| {
                self.previous_in_grace(&state, now)
                    .and_then(|prev| prev.producer.decrypt(ciphertext))
            })
    }

    /// Returns a read guard over the rotator state. Before returning, it rotates
    /// generations if the current one is missing or has expired as of `now`.
    ///
    /// Common path: only the read lock is taken.
    ///
    /// Rotation path: the write lock is acquired and freshness is checked again,
    /// because another thread may have rotated in the meantime. `generator` is
    /// then invoked while the write lock is held.
    ///
    /// If `generator` fails, `current` is left empty so a later call retries.
    /// Any generation that was current moves into `previous`.
    ///
    /// Returns `None` if acquiring either lock reports poisoning.
    pub(crate) fn maybe_roll(
        &self,
        now: UnixTime,
    ) -> Option<RwLockReadGuard<'_, TicketRotatorState>> {
        let now_secs = now.as_secs();

        // Fast path: the current generation is still valid, so keep the read lock.
        let read = self.state.read().ok()?;
        if Self::has_fresh_current(&read, now_secs) {
            return Some(read);
        }
        // Release the read lock before asking for the write lock.
        drop(read);

        let mut write = self.state.write().ok()?;
        if Self::has_fresh_current(&write, now_secs) {
            // Another thread finished the rotation while we waited for the lock.
            drop(write);
            return self.state.read().ok();
        }

        // Rotation is due and this thread is performing it.
        let replacement = match (self.generator)() {
            Ok(producer) => Some(Generation {
                producer,
                expires_at: now_secs.saturating_add(self.lifetime.as_secs()),
            }),
            Err(_) => None,
        };
        let retired = mem::replace(&mut write.current, replacement);
        // Only overwrite `previous` when something was actually retired; after
        // a failed rotation `current` is empty and `previous` must be kept.
        if retired.is_some() {
            write.previous = retired;
        }
        drop(write);

        self.state.read().ok()
    }

    /// Whether `state` holds a current generation that has not expired at `now_secs`.
    fn has_fresh_current(state: &TicketRotatorState, now_secs: u64) -> bool {
        matches!(&state.current, Some(current) if now_secs <= current.expires_at)
    }

    /// The retired generation, if one exists and it is still within its grace
    /// period at `now`.
    fn previous_in_grace<'s>(
        &self,
        state: &'s TicketRotatorState,
        now: UnixTime,
    ) -> Option<&'s Generation> {
        state
            .previous
            .as_ref()
            .filter(|prev| prev.in_grace_period(now, self.lifetime))
    }
}
156
// Gated like the `TicketRotator` struct and its inherent impl. Without the
// gate, this impl names a type that does not exist when `std` is disabled.
#[cfg(feature = "std")]
impl TicketProducer for TicketRotator {
    /// Encrypts `message`, using the current wall-clock time for rotation decisions.
    fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
        self.encrypt_at(message, UnixTime::now())
    }

    /// Decrypts `ciphertext`, using the current wall-clock time for rotation decisions.
    fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
        self.decrypt_at(ciphertext, UnixTime::now())
    }

    /// The configured generation lifetime. The same value bounds how long a
    /// retired generation is still consulted.
    fn lifetime(&self) -> Duration {
        self.lifetime
    }
}
170
// Gated like the `TicketRotator` struct itself, which only exists with `std`.
#[cfg(feature = "std")]
impl core::fmt::Debug for TicketRotator {
    /// Renders as `TicketRotator { .. }`; no fields are printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("TicketRotator")
            .finish_non_exhaustive()
    }
}
177
178#[derive(Debug)]
179pub(crate) struct TicketRotatorState {
180 current: Option<Generation>,
181 previous: Option<Generation>,
182}
183
184#[derive(Debug)]
185struct Generation {
186 producer: Box<dyn TicketProducer>,
187 expires_at: u64,
188}
189
190impl Generation {
191 fn in_grace_period(&self, now: UnixTime, lifetime: Duration) -> bool {
192 now.as_secs()
193 .saturating_sub(self.expires_at)
194 <= lifetime.as_secs()
195 }
196}
197
#[cfg(test)]
mod tests {
    use core::sync::atomic::{AtomicU8, Ordering};
    use core::time::Duration;

    use pki_types::UnixTime;

    use super::*;

    #[test]
    fn ticketrotator_switching_test() {
        let t = TicketRotator::new(Duration::from_secs(1), FakeTicketer::new).unwrap();
        let now = UnixTime::now();
        let cipher1 = t.encrypt(b"ticket 1").unwrap();
        assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");

        // First forced rotation: the initial generation becomes `previous`, so
        // tickets it issued still decrypt. The guard is dropped immediately.
        drop(t.maybe_roll(UnixTime::since_unix_epoch(Duration::from_secs(
            now.as_secs() + 10,
        ))));
        let cipher2 = t.encrypt(b"ticket 2").unwrap();
        assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
        assert_eq!(t.decrypt(&cipher2).unwrap(), b"ticket 2");

        // Second forced rotation: the initial generation is discarded entirely.
        drop(t.maybe_roll(UnixTime::since_unix_epoch(Duration::from_secs(
            now.as_secs() + 20,
        ))));
        let cipher3 = t.encrypt(b"ticket 3").unwrap();
        assert!(t.decrypt(&cipher1).is_none());
        assert_eq!(t.decrypt(&cipher2).unwrap(), b"ticket 2");
        assert_eq!(t.decrypt(&cipher3).unwrap(), b"ticket 3");
    }

    #[test]
    fn ticketrotator_remains_usable_over_temporary_ticketer_creation_failure() {
        let mut t = TicketRotator::new(Duration::from_secs(1), FakeTicketer::new).unwrap();
        let expiry = t
            .state
            .read()
            .unwrap()
            .current
            .as_ref()
            .unwrap()
            .expires_at;
        let cipher1 = t.encrypt(b"ticket 1").unwrap();
        assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
        t.generator = fail_generator;

        // Exactly at expiry the current generation is still valid, so no
        // rotation (and no failing generator call) happens.
        let t1 = UnixTime::since_unix_epoch(Duration::from_secs(expiry));
        drop(t.maybe_roll(t1));
        assert!(t.encrypt_at(b"ticket 2", t1).is_some());

        // Rotation is due but the generator fails. `current` is emptied and the
        // old generation keeps serving from `previous` during its grace period.
        let t2 = UnixTime::since_unix_epoch(Duration::from_secs(expiry + 1));
        let cipher3 = t.encrypt_at(b"ticket 3", t2).unwrap();
        assert_eq!(t.decrypt_at(&cipher1, t2).unwrap(), b"ticket 1");
        assert_eq!(t.decrypt_at(&cipher3, t2).unwrap(), b"ticket 3");

        // Grace period (one lifetime) has elapsed with no usable generation.
        let t3 = UnixTime::since_unix_epoch(Duration::from_secs(expiry + 2));
        assert_eq!(t.encrypt_at(b"ticket 4", t3), None);
        assert_eq!(t.decrypt_at(&cipher3, t3), None);

        // The generator recovers and the next rotation installs a fresh generation.
        t.generator = FakeTicketer::new;
        let t4 = UnixTime::since_unix_epoch(Duration::from_secs(expiry + 3));
        drop(t.maybe_roll(t4));

        // Tickets from the old generation are past its grace period.
        let t5 = UnixTime::since_unix_epoch(Duration::from_secs(expiry + 4));
        let cipher5 = t.encrypt_at(b"ticket 5", t5).unwrap();
        assert!(t.decrypt_at(&cipher1, t5).is_none());
        assert!(t.decrypt_at(&cipher3, t5).is_none());
        assert_eq!(t.decrypt_at(&cipher5, t5).unwrap(), b"ticket 5");

        // With no generations at all and a failing generator, encryption fails.
        t.generator = fail_generator;
        let mut write = t.state.write().unwrap();
        write.current = None;
        write.previous = None;
        drop(write);
        assert!(t.encrypt(b"ticket 6").is_none());
    }

    /// Toy producer: output is its id byte followed by the message XOR-ed with
    /// that id. It only decrypts ciphertexts carrying its own id.
    #[derive(Debug)]
    struct FakeTicketer {
        // Called `id` rather than `gen`, which is a reserved keyword in Rust 2024.
        id: u8,
    }

    impl FakeTicketer {
        #[expect(clippy::new_ret_no_self)]
        fn new() -> Result<Box<dyn TicketProducer>, Error> {
            // A shared counter gives every generation a distinct id (until it
            // wraps at 256), even across concurrently running tests.
            Ok(Box::new(Self {
                id: FAKE_GEN.fetch_add(1, Ordering::SeqCst),
            }))
        }
    }

    impl TicketProducer for FakeTicketer {
        fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
            let mut v = Vec::with_capacity(1 + message.len());
            v.push(self.id);
            v.extend(
                message
                    .iter()
                    .copied()
                    .map(|b| b ^ self.id),
            );
            Some(v)
        }

        fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
            if ciphertext.first()? != &self.id {
                return None;
            }

            // `first()` succeeded above, so slicing from 1 cannot panic.
            Some(
                ciphertext[1..]
                    .iter()
                    .copied()
                    .map(|b| b ^ self.id)
                    .collect(),
            )
        }

        fn lifetime(&self) -> Duration {
            Duration::ZERO
        }
    }

    static FAKE_GEN: AtomicU8 = AtomicU8::new(0);

    fn fail_generator() -> Result<Box<dyn TicketProducer>, Error> {
        Err(Error::FailedToGetRandomBytes)
    }
}