1use alloc::vec::Vec;
2use core::ops::{Deref, DerefMut};
3use core::{fmt, mem};
4
5use pki_types::ServerName;
6
7use super::config::ClientConfig;
8use super::hs::{self, ClientHelloInput};
9use crate::client::EchStatus;
10use crate::common_state::{CommonState, Protocol, Side};
11use crate::conn::{ConnectionCore, UnbufferedConnectionCommon};
12#[cfg(doc)]
13use crate::crypto;
14use crate::error::Error;
15use crate::kernel::KernelConnection;
16use crate::log::trace;
17use crate::msgs::deframer::Locator;
18use crate::msgs::handshake::ClientExtensionsInput;
19use crate::suites::ExtractedSecrets;
20use crate::sync::Arc;
21use crate::unbuffered::{EncryptError, TransmitTlsData};
22
#[cfg(feature = "std")]
mod buffered {
    use alloc::vec::Vec;
    use core::fmt;
    use core::ops::{Deref, DerefMut};
    use std::io;

    use pki_types::ServerName;

    use super::{ClientConnectionData, ClientExtensionsInput};
    use crate::KeyingMaterialExporter;
    use crate::client::EchStatus;
    use crate::client::config::ClientConfig;
    use crate::common_state::Protocol;
    use crate::conn::{ConnectionCommon, ConnectionCore};
    use crate::error::Error;
    use crate::suites::ExtractedSecrets;
    use crate::sync::Arc;

    /// A buffered client-side TLS connection.
    ///
    /// Most of its API is reached through `Deref`/`DerefMut` to
    /// `ConnectionCommon<ClientConnectionData>`.
    pub struct ClientConnection {
        inner: ConnectionCommon<ClientConnectionData>,
    }

    impl fmt::Debug for ClientConnection {
        // Prints only the type name; no connection state is shown.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("ClientConnection").finish()
        }
    }

    impl ClientConnection {
        /// Makes a new client connection to `name` over TCP, using the ALPN protocols
        /// configured in `config.alpn_protocols`.
        ///
        /// # Errors
        ///
        /// Propagates any error from building the connection core.
        pub fn new(config: Arc<ClientConfig>, name: ServerName<'static>) -> Result<Self, Error> {
            // Take the ALPN list out first so `config` itself can be moved, not cloned.
            let alpn_protocols = config.alpn_protocols.clone();
            Self::new_with_alpn(config, name, alpn_protocols)
        }

        /// Makes a new client connection to `name` over TCP, building the ALPN
        /// extension input from `alpn_protocols`.
        ///
        /// # Errors
        ///
        /// Propagates any error from building the connection core.
        pub fn new_with_alpn(
            config: Arc<ClientConfig>,
            name: ServerName<'static>,
            alpn_protocols: Vec<Vec<u8>>,
        ) -> Result<Self, Error> {
            let extensions = ClientExtensionsInput::from_alpn(alpn_protocols);
            let core = ConnectionCore::for_client(config, name, extensions, Protocol::Tcp)?;
            Ok(Self {
                inner: ConnectionCommon::from(core),
            })
        }

        /// Returns a [`WriteEarlyData`] writer while early data may be sent, and
        /// `None` otherwise.
        pub fn early_data(&mut self) -> Option<WriteEarlyData<'_>> {
            if !self.inner.core.side.early_data.is_enabled() {
                return None;
            }
            Some(WriteEarlyData::new(self))
        }

        /// Whether the server accepted early data. This stays true after the
        /// accepted early-data phase has finished.
        pub fn is_early_data_accepted(&self) -> bool {
            self.inner.core.is_early_data_accepted()
        }

        /// Consumes the connection and returns its extracted secrets.
        ///
        /// The result contains secret key material; handle it with care.
        pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
            self.inner.dangerous_extract_secrets()
        }

        /// Returns the connection's current ECH status.
        pub fn ech_status(&self) -> EchStatus {
            self.inner.core.side.ech_status
        }

        /// Returns the connection's `tls13_tickets_received` counter.
        pub fn tls13_tickets_received(&self) -> u32 {
            self.inner.tls13_tickets_received
        }

        /// Returns the `fips` flag of the connection's common state.
        pub fn fips(&self) -> bool {
            self.inner.core.common_state.fips
        }

        /// Sends as much of `data` as the early-data budget allows.
        ///
        /// Fails with `InvalidInput` once early data can no longer be written.
        fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
            let allowed = self
                .inner
                .core
                .side
                .early_data
                .check_write(data.len())?;
            Ok(self
                .inner
                .send_early_plaintext(&data[..allowed]))
        }
    }

    impl Deref for ClientConnection {
        type Target = ConnectionCommon<ClientConnectionData>;

        fn deref(&self) -> &Self::Target {
            &self.inner
        }
    }

    impl DerefMut for ClientConnection {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.inner
        }
    }

    #[doc(hidden)]
    impl<'a> TryFrom<&'a mut crate::Connection> for &'a mut ClientConnection {
        type Error = ();

        fn try_from(value: &'a mut crate::Connection) -> Result<Self, Self::Error> {
            match value {
                crate::Connection::Client(client) => Ok(client),
                crate::Connection::Server(_) => Err(()),
            }
        }
    }

    impl From<ClientConnection> for crate::Connection {
        fn from(client: ClientConnection) -> Self {
            Self::Client(client)
        }
    }

    /// Writes early (0-RTT) data through a borrowed [`ClientConnection`].
    ///
    /// Obtained from [`ClientConnection::early_data`].
    pub struct WriteEarlyData<'a> {
        sess: &'a mut ClientConnection,
    }

    impl<'a> WriteEarlyData<'a> {
        fn new(sess: &'a mut ClientConnection) -> Self {
            Self { sess }
        }

        /// How many more bytes of early data the budget allows.
        pub fn bytes_left(&self) -> usize {
            self.sess
                .inner
                .core
                .side
                .early_data
                .bytes_left()
        }

        /// Returns the keying material exporter from the connection core's
        /// `early_exporter()`.
        pub fn exporter(&mut self) -> Result<KeyingMaterialExporter, Error> {
            self.sess.inner.core.early_exporter()
        }
    }

    impl io::Write for WriteEarlyData<'_> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.sess.write_early_data(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            // `write` hands data straight to the connection; this writer holds nothing back.
            Ok(())
        }
    }

    impl super::EarlyData {
        /// `check_write_opt`, reporting "early data can no longer be written" as an
        /// `InvalidInput` I/O error.
        fn check_write(&mut self, sz: usize) -> io::Result<usize> {
            match self.check_write_opt(sz) {
                Some(allowed) => Ok(allowed),
                None => Err(io::Error::from(io::ErrorKind::InvalidInput)),
            }
        }

        /// Remaining early-data budget in bytes.
        fn bytes_left(&self) -> usize {
            self.left
        }
    }
}
260
261#[cfg(feature = "std")]
262pub use buffered::{ClientConnection, WriteEarlyData};
263
264impl ConnectionCore<ClientConnectionData> {
265 pub(crate) fn for_client(
266 config: Arc<ClientConfig>,
267 name: ServerName<'static>,
268 extra_exts: ClientExtensionsInput<'static>,
269 proto: Protocol,
270 ) -> Result<Self, Error> {
271 let mut common_state = CommonState::new(Side::Client);
272 common_state.set_max_fragment_size(config.max_fragment_size)?;
273 common_state.protocol = proto;
274 common_state.enable_secret_extraction = config.enable_secret_extraction;
275 common_state.fips = config.fips();
276 let mut data = ClientConnectionData::new();
277
278 let mut cx = hs::ClientContext {
279 common: &mut common_state,
280 data: &mut data,
281 plaintext_locator: &Locator::new(&[]),
283 received_plaintext: &mut None,
284 sendable_plaintext: None,
286 };
287
288 let input = ClientHelloInput::new(name, &extra_exts, &mut cx, config)?;
289 let state = input.start_handshake(extra_exts, &mut cx)?;
290 debug_assert!(cx.received_plaintext.is_none(), "read plaintext");
291
292 Ok(Self::new(state, data, common_state))
293 }
294
295 #[cfg(feature = "std")]
296 pub(crate) fn is_early_data_accepted(&self) -> bool {
297 self.side.early_data.is_accepted()
298 }
299}
300
301pub struct UnbufferedClientConnection {
305 inner: UnbufferedConnectionCommon<ClientConnectionData>,
306}
307
308impl UnbufferedClientConnection {
309 pub fn new(config: Arc<ClientConfig>, name: ServerName<'static>) -> Result<Self, Error> {
312 Self::new_with_extensions(
313 config.clone(),
314 name,
315 ClientExtensionsInput::from_alpn(config.alpn_protocols.clone()),
316 )
317 }
318
319 pub fn new_with_alpn(
321 config: Arc<ClientConfig>,
322 name: ServerName<'static>,
323 alpn_protocols: Vec<Vec<u8>>,
324 ) -> Result<Self, Error> {
325 Self::new_with_extensions(
326 config,
327 name,
328 ClientExtensionsInput::from_alpn(alpn_protocols),
329 )
330 }
331
332 fn new_with_extensions(
333 config: Arc<ClientConfig>,
334 name: ServerName<'static>,
335 extensions: ClientExtensionsInput<'static>,
336 ) -> Result<Self, Error> {
337 Ok(Self {
338 inner: UnbufferedConnectionCommon::from(ConnectionCore::for_client(
339 config,
340 name,
341 extensions,
342 Protocol::Tcp,
343 )?),
344 })
345 }
346
347 #[deprecated = "dangerous_extract_secrets() does not support session tickets or \
350 key updates, use dangerous_into_kernel_connection() instead"]
351 pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
352 self.inner.dangerous_extract_secrets()
353 }
354
355 pub fn dangerous_into_kernel_connection(
365 self,
366 ) -> Result<(ExtractedSecrets, KernelConnection<ClientConnectionData>), Error> {
367 self.inner
368 .core
369 .dangerous_into_kernel_connection()
370 }
371
372 pub fn tls13_tickets_received(&self) -> u32 {
374 self.inner.tls13_tickets_received
375 }
376}
377
378impl Deref for UnbufferedClientConnection {
379 type Target = UnbufferedConnectionCommon<ClientConnectionData>;
380
381 fn deref(&self) -> &Self::Target {
382 &self.inner
383 }
384}
385
386impl DerefMut for UnbufferedClientConnection {
387 fn deref_mut(&mut self) -> &mut Self::Target {
388 &mut self.inner
389 }
390}
391
392impl TransmitTlsData<'_, ClientConnectionData> {
393 pub fn may_encrypt_early_data(&mut self) -> Option<MayEncryptEarlyData<'_>> {
398 if self
399 .conn
400 .core
401 .side
402 .early_data
403 .is_enabled()
404 {
405 Some(MayEncryptEarlyData { conn: self.conn })
406 } else {
407 None
408 }
409 }
410}
411
412pub struct MayEncryptEarlyData<'c> {
414 conn: &'c mut UnbufferedConnectionCommon<ClientConnectionData>,
415}
416
417impl MayEncryptEarlyData<'_> {
418 pub fn encrypt(
423 &mut self,
424 early_data: &[u8],
425 outgoing_tls: &mut [u8],
426 ) -> Result<usize, EarlyDataError> {
427 let Some(allowed) = self
428 .conn
429 .core
430 .side
431 .early_data
432 .check_write_opt(early_data.len())
433 else {
434 return Err(EarlyDataError::ExceededAllowedEarlyData);
435 };
436
437 self.conn
438 .core
439 .common_state
440 .write_plaintext(early_data[..allowed].into(), outgoing_tls)
441 .map_err(|e| e.into())
442 }
443}
444
445#[derive(Debug)]
446pub(super) struct EarlyData {
447 state: EarlyDataState,
448 left: usize,
449}
450
451impl EarlyData {
452 fn new() -> Self {
453 Self {
454 left: 0,
455 state: EarlyDataState::Disabled,
456 }
457 }
458
459 pub(super) fn is_enabled(&self) -> bool {
460 matches!(self.state, EarlyDataState::Ready | EarlyDataState::Accepted)
461 }
462
463 #[cfg(feature = "std")]
464 fn is_accepted(&self) -> bool {
465 matches!(
466 self.state,
467 EarlyDataState::Accepted | EarlyDataState::AcceptedFinished
468 )
469 }
470
471 pub(super) fn enable(&mut self, max_data: usize) {
472 assert_eq!(self.state, EarlyDataState::Disabled);
473 self.state = EarlyDataState::Ready;
474 self.left = max_data;
475 }
476
477 pub(super) fn rejected(&mut self) {
478 trace!("EarlyData rejected");
479 self.state = EarlyDataState::Rejected;
480 }
481
482 pub(super) fn accepted(&mut self) {
483 trace!("EarlyData accepted");
484 assert_eq!(self.state, EarlyDataState::Ready);
485 self.state = EarlyDataState::Accepted;
486 }
487
488 pub(super) fn finished(&mut self) {
489 trace!("EarlyData finished");
490 self.state = match self.state {
491 EarlyDataState::Accepted => EarlyDataState::AcceptedFinished,
492 _ => panic!("bad EarlyData state"),
493 }
494 }
495
496 fn check_write_opt(&mut self, sz: usize) -> Option<usize> {
497 match self.state {
498 EarlyDataState::Disabled => unreachable!(),
499 EarlyDataState::Ready | EarlyDataState::Accepted => {
500 let take = if self.left < sz {
501 mem::replace(&mut self.left, 0)
502 } else {
503 self.left -= sz;
504 sz
505 };
506
507 Some(take)
508 }
509 EarlyDataState::Rejected | EarlyDataState::AcceptedFinished => None,
510 }
511 }
512}
513
/// Lifecycle of client early data, as driven by `EarlyData`.
#[derive(Debug, PartialEq)]
enum EarlyDataState {
    /// Initial state: early data has not been enabled.
    Disabled,
    /// Enabled with a byte budget; the server has not yet accepted it.
    Ready,
    /// Accepted by the server; writing is still allowed.
    Accepted,
    /// Accepted earlier; the early-data phase is now over.
    AcceptedFinished,
    /// Rejected; no further early data may be written.
    Rejected,
}
522
523#[non_exhaustive]
525#[derive(Debug)]
526pub enum EarlyDataError {
527 ExceededAllowedEarlyData,
529 Encrypt(EncryptError),
531}
532
533impl From<EncryptError> for EarlyDataError {
534 fn from(v: EncryptError) -> Self {
535 Self::Encrypt(v)
536 }
537}
538
539impl fmt::Display for EarlyDataError {
540 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
541 match self {
542 Self::ExceededAllowedEarlyData => f.write_str("cannot send any more early data"),
543 Self::Encrypt(e) => fmt::Display::fmt(e, f),
544 }
545 }
546}
547
548#[cfg(feature = "std")]
549impl core::error::Error for EarlyDataError {}
550
551#[derive(Debug)]
553pub struct ClientConnectionData {
554 pub(super) early_data: EarlyData,
555 pub(super) ech_status: EchStatus,
556}
557
558impl ClientConnectionData {
559 fn new() -> Self {
560 Self {
561 early_data: EarlyData::new(),
562 ech_status: EchStatus::NotOffered,
563 }
564 }
565}
566
567impl crate::conn::SideData for ClientConnectionData {}