1use alloc::boxed::Box;
2use alloc::vec::Vec;
3use core::mem;
4#[cfg(feature = "std")]
5use std::sync::{RwLock, RwLockReadGuard};
6
7use pki_types::UnixTime;
8
9use crate::Error;
10use crate::server::ProducesTickets;
11#[cfg(not(feature = "std"))]
12use crate::time_provider::TimeProvider;
13
/// Mutable state of a [`TicketRotator`], kept behind its `RwLock`.
#[cfg(feature = "std")]
#[derive(Debug)]
pub(crate) struct TicketRotatorState {
    /// Ticketer used to encrypt new tickets; it is also tried first when decrypting.
    current: Box<dyn ProducesTickets>,
    /// Ticketer displaced by an earlier rotation, if one is retained; used only
    /// as a decryption fallback.
    previous: Option<Box<dyn ProducesTickets>>,
    /// Unix time in seconds. Once the clock passes this value, the next
    /// `maybe_roll` call rotates.
    next_switch_time: u64,
}
21
/// A [`ProducesTickets`] implementation that periodically swaps its ticketer for
/// a freshly generated one. The displaced ticketer may be kept around as a
/// decryption fallback.
#[cfg(feature = "std")]
pub struct TicketRotator {
    /// Factory for ticketers; invoked once at construction and again on each rotation.
    pub(crate) generator: fn() -> Result<Box<dyn ProducesTickets>, Error>,
    /// Rotation period, in seconds.
    lifetime: u32,
    /// Current/previous ticketers plus the next switch deadline.
    state: RwLock<TicketRotatorState>,
}
31
#[cfg(feature = "std")]
impl TicketRotator {
    /// Creates a rotator that issues tickets from a ticketer built by `generator`.
    ///
    /// It replaces that ticketer with a freshly generated one once `lifetime`
    /// seconds have elapsed. Rotation is checked lazily, on each encrypt/decrypt.
    /// After a rotation the displaced ticketer may be kept for decryption until
    /// the following rotation, so tickets issued shortly before a switch still work.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the initial call to `generator`.
    pub fn new(
        lifetime: u32,
        generator: fn() -> Result<Box<dyn ProducesTickets>, Error>,
    ) -> Result<Self, Error> {
        Ok(Self {
            generator,
            lifetime,
            state: RwLock::new(TicketRotatorState {
                current: generator()?,
                previous: None,
                next_switch_time: UnixTime::now()
                    .as_secs()
                    .saturating_add(u64::from(lifetime)),
            }),
        })
    }

    /// Rotates to a new ticketer if the switch time has passed at `now`. Then
    /// returns a read guard over the (possibly updated) state.
    ///
    /// Returns `None` in either of these cases:
    /// - the lock is poisoned;
    /// - a rotation is due but `generator` fails.
    ///
    /// Because rotation only happens when this is called, the rotator may sit
    /// idle well past `next_switch_time`. The current ticketer stopped issuing
    /// tickets at that deadline. If more than one further `lifetime` has also
    /// elapsed, every ticket it produced is older than the lifetime this rotator
    /// reports via `ProducesTickets::lifetime`. In that case the ticketer is
    /// dropped instead of being demoted to `previous`. Otherwise an idle period
    /// would let a stale ticketer keep decrypting tickets, and stay alive, for an
    /// unbounded time.
    pub(crate) fn maybe_roll(
        &self,
        now: UnixTime,
    ) -> Option<RwLockReadGuard<'_, TicketRotatorState>> {
        let now = now.as_secs();

        // Fast, read-only path for the common case where no switch is due yet.
        {
            let read = self.state.read().ok()?;
            if now <= read.next_switch_time {
                return Some(read);
            }
        }

        // A switch is due. Build the replacement outside the lock so other
        // threads are not blocked while it is generated.
        let next = (self.generator)().ok()?;

        let mut write = self.state.write().ok()?;

        // Another thread may have rotated while we waited for the write lock.
        // If so, discard `next` and use the state that thread installed.
        if now <= write.next_switch_time {
            drop(write);
            return self.state.read().ok();
        }

        // Only keep the outgoing ticketer for decryption if some of its tickets
        // may still be within their advertised lifetime.
        let retire_deadline = write
            .next_switch_time
            .saturating_add(u64::from(self.lifetime));
        let keep_as_previous = now <= retire_deadline;

        let retired = mem::replace(&mut write.current, next);
        write.previous = keep_as_previous.then_some(retired);
        write.next_switch_time = now.saturating_add(u64::from(self.lifetime));
        drop(write);

        self.state.read().ok()
    }

    /// Six hours, in seconds; a convenient default rotation period.
    #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
    pub(crate) const SIX_HOURS: u32 = 6 * 60 * 60;
}
113
#[cfg(feature = "std")]
impl ProducesTickets for TicketRotator {
    /// The rotation period, in seconds, that this rotator was built with.
    fn lifetime(&self) -> u32 {
        self.lifetime
    }

    /// A rotator always issues tickets.
    fn enabled(&self) -> bool {
        true
    }

    /// Encrypts `message` with the current ticketer, rotating first if due.
    ///
    /// Returns `None` if the state is unavailable (poisoned lock or failed
    /// rotation) or if the current ticketer itself returns `None`.
    fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
        let guard = self.maybe_roll(UnixTime::now())?;
        guard.current.encrypt(message)
    }

    /// Decrypts `ciphertext`, rotating first if due.
    ///
    /// The current ticketer is tried first. If it fails, the retained previous
    /// ticketer (when present) is tried next.
    fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
        let guard = self.maybe_roll(UnixTime::now())?;

        if let Some(plaintext) = guard.current.decrypt(ciphertext) {
            return Some(plaintext);
        }

        // Fall back to the ticketer displaced by an earlier rotation, if one is
        // retained.
        guard.previous.as_deref()?.decrypt(ciphertext)
    }
}
145
#[cfg(feature = "std")]
impl core::fmt::Debug for TicketRotator {
    /// Prints the type name only. All fields are deliberately elided, and the
    /// output is marked non-exhaustive.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("TicketRotator");
        builder.finish_non_exhaustive()
    }
}
152}