1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4use core::time::Duration;
5use std::sync::{RwLock, RwLockReadGuard};
6use std::time::Instant;
7
8use crate::crypto::TicketProducer;
9use crate::error::Error;
10
11pub struct TicketRotator {
15 pub(crate) generator: fn() -> Result<Box<dyn TicketProducer>, Error>,
16 lifetime: Duration,
17 state: RwLock<TicketRotatorState>,
18}
19
20impl TicketRotator {
21 pub fn new(
32 lifetime: Duration,
33 generator: fn() -> Result<Box<dyn TicketProducer>, Error>,
34 ) -> Result<Self, Error> {
35 Ok(Self {
36 generator,
37 lifetime,
38 state: RwLock::new(TicketRotatorState {
39 current: Some(Generation {
40 producer: generator()?,
41 expires_at: Instant::now() + lifetime,
42 }),
43 previous: None,
44 }),
45 })
46 }
47
48 fn encrypt_at(&self, message: &[u8], now: Instant) -> Option<Vec<u8>> {
49 let state = self.maybe_roll(now)?;
50
51 if let Some(current) = &state.current {
54 return current.producer.encrypt(message);
55 }
56
57 let Some(prev) = &state.previous else {
59 return None;
60 };
61
62 if !prev.in_grace_period(now, self.lifetime) {
64 return None;
65 }
66
67 prev.producer.encrypt(message)
68 }
69
70 fn decrypt_at(&self, ciphertext: &[u8], now: Instant) -> Option<Vec<u8>> {
71 let state = self.maybe_roll(now)?;
72
73 if let Some(current) = &state.current {
76 if let Some(plain) = current.producer.decrypt(ciphertext) {
78 return Some(plain);
79 }
80 }
81
82 let Some(prev) = &state.previous else {
84 return None;
85 };
86
87 if !prev.in_grace_period(now, self.lifetime) {
89 return None;
90 }
91
92 prev.producer.decrypt(ciphertext)
93 }
94
95 pub(crate) fn maybe_roll(
105 &self,
106 now: Instant,
107 ) -> Option<RwLockReadGuard<'_, TicketRotatorState>> {
108 {
111 let read = self.state.read().ok()?;
112 match &read.current {
113 Some(current) if now <= current.expires_at => return Some(read),
114 _ => {}
115 }
116 }
117
118 let mut write = self.state.write().ok()?;
119 if let Some(current) = &write.current {
120 if now <= current.expires_at {
121 drop(write);
123 return self.state.read().ok();
124 }
125 }
126
127 let next = (self.generator)()
130 .ok()
131 .map(|producer| Generation {
132 producer,
133 expires_at: now + self.lifetime,
134 });
135
136 let prev = mem::replace(&mut write.current, next);
141 if prev.is_some() {
142 write.previous = prev;
143 }
144 drop(write);
145
146 self.state.read().ok()
147 }
148}
149
150impl TicketProducer for TicketRotator {
151 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
152 self.encrypt_at(message, Instant::now())
153 }
154
155 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
156 self.decrypt_at(ciphertext, Instant::now())
157 }
158
159 fn lifetime(&self) -> Duration {
160 self.lifetime
161 }
162}
163
164impl core::fmt::Debug for TicketRotator {
165 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
166 f.debug_struct("TicketRotator")
167 .finish_non_exhaustive()
168 }
169}
170
171#[derive(Debug)]
172pub(crate) struct TicketRotatorState {
173 current: Option<Generation>,
174 previous: Option<Generation>,
175}
176
177#[derive(Debug)]
178struct Generation {
179 producer: Box<dyn TicketProducer>,
180 expires_at: Instant,
181}
182
183impl Generation {
184 fn in_grace_period(&self, now: Instant, lifetime: Duration) -> bool {
185 now <= self.expires_at + lifetime
186 }
187}
188
189#[cfg(test)]
190mod tests {
191 use core::sync::atomic::{AtomicU8, Ordering};
192 use core::time::Duration;
193
194 use super::*;
195
196 #[test]
197 fn ticketrotator_switching_test() {
198 let t = TicketRotator::new(Duration::from_secs(1), FakeTicketer::new).unwrap();
199 let now = Instant::now();
200 let cipher1 = t.encrypt(b"ticket 1").unwrap();
201 assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
202 {
203 t.maybe_roll(now + Duration::from_secs(10));
205 }
206 let cipher2 = t.encrypt(b"ticket 2").unwrap();
207 assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
208 assert_eq!(t.decrypt(&cipher2).unwrap(), b"ticket 2");
209 {
210 t.maybe_roll(now + Duration::from_secs(20));
212 }
213 let cipher3 = t.encrypt(b"ticket 3").unwrap();
214 assert!(t.decrypt(&cipher1).is_none());
215 assert_eq!(t.decrypt(&cipher2).unwrap(), b"ticket 2");
216 assert_eq!(t.decrypt(&cipher3).unwrap(), b"ticket 3");
217 }
218
219 #[test]
220 fn ticketrotator_remains_usable_over_temporary_ticketer_creation_failure() {
221 let mut t = TicketRotator::new(Duration::from_secs(1), FakeTicketer::new).unwrap();
222 let expiry = t
223 .state
224 .read()
225 .unwrap()
226 .current
227 .as_ref()
228 .unwrap()
229 .expires_at;
230 let cipher1 = t.encrypt(b"ticket 1").unwrap();
231 assert_eq!(t.decrypt(&cipher1).unwrap(), b"ticket 1");
232 t.generator = fail_generator;
233
234 let t1 = expiry;
236 drop(t.maybe_roll(t1));
237 assert!(t.encrypt_at(b"ticket 2", t1).is_some());
238
239 let t2 = expiry + Duration::from_secs(1);
241 let cipher3 = t.encrypt_at(b"ticket 3", t2).unwrap();
242 assert_eq!(t.decrypt_at(&cipher1, t2).unwrap(), b"ticket 1");
243 assert_eq!(t.decrypt_at(&cipher3, t2).unwrap(), b"ticket 3");
244
245 let t3 = expiry + Duration::from_secs(2);
246 assert_eq!(t.encrypt_at(b"ticket 4", t3), None);
247 assert_eq!(t.decrypt_at(&cipher3, t3), None);
248
249 t.generator = FakeTicketer::new;
251 let t4 = expiry + Duration::from_secs(3);
252 drop(t.maybe_roll(t4));
253
254 let t5 = expiry + Duration::from_secs(4);
255 let cipher5 = t.encrypt_at(b"ticket 5", t5).unwrap();
256 assert!(t.decrypt_at(&cipher1, t5).is_none());
257 assert!(t.decrypt_at(&cipher3, t5).is_none());
258 assert_eq!(t.decrypt_at(&cipher5, t5).unwrap(), b"ticket 5");
259
260 t.generator = fail_generator;
262 let mut write = t.state.write().unwrap();
263 write.current = None;
264 write.previous = None;
265 drop(write);
266 assert!(t.encrypt(b"ticket 6").is_none());
267 }
268
269 #[derive(Debug)]
270 struct FakeTicketer {
271 gen: u8,
272 }
273
274 impl FakeTicketer {
275 #[expect(clippy::new_ret_no_self)]
276 fn new() -> Result<Box<dyn TicketProducer>, Error> {
277 Ok(Box::new(Self {
278 gen: std::dbg!(FAKE_GEN.fetch_add(1, Ordering::SeqCst)),
279 }))
280 }
281 }
282
283 impl TicketProducer for FakeTicketer {
284 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
285 let mut v = Vec::with_capacity(1 + message.len());
286 v.push(self.gen);
287 v.extend(
288 message
289 .iter()
290 .copied()
291 .map(|b| b ^ self.gen),
292 );
293 Some(v)
294 }
295
296 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
297 if ciphertext.first()? != &self.gen {
298 return None;
299 }
300
301 Some(
302 ciphertext[1..]
303 .iter()
304 .copied()
305 .map(|b| b ^ self.gen)
306 .collect(),
307 )
308 }
309
310 fn lifetime(&self) -> Duration {
311 Duration::ZERO }
313 }
314
315 static FAKE_GEN: AtomicU8 = AtomicU8::new(0);
316
317 fn fail_generator() -> Result<Box<dyn TicketProducer>, Error> {
318 Err(Error::FailedToGetRandomBytes)
319 }
320}