1use alloc::vec::Vec;
2use core::ops::{Deref, DerefMut};
3use core::{fmt, mem};
4
5use pki_types::ServerName;
6
7use super::config::ClientConfig;
8use super::hs::ClientHelloInput;
9use crate::client::EchStatus;
10use crate::common_state::{CommonState, EarlyDataEvent, Event, Output, Protocol, Side};
11use crate::conn::{ConnectionCore, UnbufferedConnectionCommon};
12#[cfg(doc)]
13use crate::crypto;
14use crate::enums::ApplicationProtocol;
15use crate::error::Error;
16use crate::kernel::KernelConnection;
17use crate::log::trace;
18use crate::msgs::ClientExtensionsInput;
19use crate::suites::ExtractedSecrets;
20use crate::sync::Arc;
21use crate::unbuffered::{EncryptError, TransmitTlsData};
22
#[cfg(feature = "std")]
mod buffered {
    use alloc::vec::Vec;
    use core::fmt;
    use core::ops::{Deref, DerefMut};
    use std::io;

    use pki_types::ServerName;

    use super::{ClientConnectionData, ClientExtensionsInput};
    use crate::KeyingMaterialExporter;
    use crate::client::EchStatus;
    use crate::client::config::ClientConfig;
    use crate::common_state::Protocol;
    use crate::conn::{ConnectionCommon, ConnectionCore};
    use crate::enums::ApplicationProtocol;
    use crate::error::Error;
    use crate::suites::ExtractedSecrets;
    use crate::sync::Arc;

    /// The buffered (`std`-only) flavour of a client-side TLS connection.
    ///
    /// Dereferences to [`ConnectionCommon`] for the operations shared with server
    /// connections.
    pub struct ClientConnection {
        inner: ConnectionCommon<ClientConnectionData>,
    }

    impl fmt::Debug for ClientConnection {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("ClientConnection")
                .finish_non_exhaustive()
        }
    }

    impl ClientConnection {
        /// Returns a writer for early data, or `None` if early data is not currently
        /// enabled on this connection.
        pub fn early_data(&mut self) -> Option<WriteEarlyData<'_>> {
            if !self
                .inner
                .core
                .side
                .early_data
                .is_enabled()
            {
                return None;
            }
            Some(WriteEarlyData::new(self))
        }

        /// Returns `true` once this connection has recorded that its early data was
        /// accepted (this stays `true` after the early-data phase finishes).
        pub fn is_early_data_accepted(&self) -> bool {
            self.inner.core.is_early_data_accepted()
        }

        /// Consumes the connection and extracts its secrets.
        ///
        /// This exposes secret key material; handle the result with care.
        pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
            self.inner.dangerous_extract_secrets()
        }

        /// Returns the Encrypted Client Hello status last recorded for this connection.
        pub fn ech_status(&self) -> EchStatus {
            self.inner.core.side.ech_status
        }

        /// Queues as much of `data` as the remaining early-data budget allows.
        ///
        /// Fails with `io::ErrorKind::InvalidInput` once early data has been rejected
        /// or has finished.
        fn write_early_data(&mut self, data: &[u8]) -> io::Result<usize> {
            let allowed = self
                .inner
                .core
                .side
                .early_data
                .check_write(data.len())?;
            Ok(self
                .inner
                .send
                .send_early_plaintext(&data[..allowed]))
        }
    }

    impl Deref for ClientConnection {
        type Target = ConnectionCommon<ClientConnectionData>;

        fn deref(&self) -> &Self::Target {
            &self.inner
        }
    }

    impl DerefMut for ClientConnection {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.inner
        }
    }

    #[doc(hidden)]
    impl<'a> TryFrom<&'a mut crate::Connection> for &'a mut ClientConnection {
        type Error = ();

        fn try_from(value: &'a mut crate::Connection) -> Result<Self, Self::Error> {
            match value {
                crate::Connection::Client(conn) => Ok(conn),
                crate::Connection::Server(_) => Err(()),
            }
        }
    }

    impl From<ClientConnection> for crate::Connection {
        fn from(conn: ClientConnection) -> Self {
            Self::Client(conn)
        }
    }

    /// Builder for a [`ClientConnection`].
    ///
    /// Unless overridden with [`Self::with_alpn()`], the ALPN protocols offered are
    /// the ones from the [`ClientConfig`].
    pub struct ClientConnectionBuilder {
        pub(crate) config: Arc<ClientConfig>,
        pub(crate) name: ServerName<'static>,
        pub(crate) alpn_protocols: Option<Vec<ApplicationProtocol<'static>>>,
    }

    impl ClientConnectionBuilder {
        /// Offers `alpn_protocols` instead of the protocols configured in the
        /// [`ClientConfig`].
        pub fn with_alpn(self, alpn_protocols: Vec<ApplicationProtocol<'static>>) -> Self {
            Self {
                alpn_protocols: Some(alpn_protocols),
                ..self
            }
        }

        /// Creates the connection over TCP and starts its handshake.
        ///
        /// # Errors
        ///
        /// Propagates any error from setting up the connection core.
        pub fn build(self) -> Result<ClientConnection, Error> {
            let alpn_protocols = match self.alpn_protocols {
                Some(protocols) => protocols,
                None => self.config.alpn_protocols.clone(),
            };
            let extensions = ClientExtensionsInput::from_alpn(alpn_protocols);
            let core =
                ConnectionCore::for_client(self.config, self.name, extensions, Protocol::Tcp)?;
            Ok(ClientConnection {
                inner: ConnectionCommon::from(core),
            })
        }
    }

    /// Writer for early data, returned by [`ClientConnection::early_data()`].
    ///
    /// Writes are capped by the remaining early-data budget (see
    /// [`Self::bytes_left()`]).
    pub struct WriteEarlyData<'a> {
        sess: &'a mut ClientConnection,
    }

    impl<'a> WriteEarlyData<'a> {
        fn new(sess: &'a mut ClientConnection) -> Self {
            Self { sess }
        }

        /// Remaining early-data budget, in bytes.
        pub fn bytes_left(&self) -> usize {
            self.sess
                .inner
                .core
                .side
                .early_data
                .bytes_left()
        }

        /// Returns the keying-material exporter available during the early-data phase.
        pub fn exporter(&mut self) -> Result<KeyingMaterialExporter, Error> {
            self.sess.core.early_exporter()
        }
    }

    impl io::Write for WriteEarlyData<'_> {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.sess.write_early_data(buf)
        }

        /// No-op; always succeeds.
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    impl super::EarlyData {
        /// Like `check_write_opt`, but reports a closed early-data window as an
        /// `io::ErrorKind::InvalidInput` error.
        fn check_write(&mut self, sz: usize) -> io::Result<usize> {
            match self.check_write_opt(sz) {
                Some(allowed) => Ok(allowed),
                None => Err(io::ErrorKind::InvalidInput.into()),
            }
        }

        /// Remaining early-data budget, in bytes.
        fn bytes_left(&self) -> usize {
            self.left
        }
    }
}
261
262#[cfg(feature = "std")]
263pub use buffered::{ClientConnection, ClientConnectionBuilder, WriteEarlyData};
264
265impl ConnectionCore<ClientConnectionData> {
266 pub(crate) fn for_client(
267 config: Arc<ClientConfig>,
268 name: ServerName<'static>,
269 extra_exts: ClientExtensionsInput,
270 proto: Protocol,
271 ) -> Result<Self, Error> {
272 let mut common_state = CommonState::new(Side::Client, proto);
273 common_state
274 .send
275 .set_max_fragment_size(config.max_fragment_size)?;
276 common_state.fips = config.fips();
277 let mut data = ClientConnectionData::new(common_state);
278
279 let input = ClientHelloInput::new(name, &extra_exts, proto, &mut data, config)?;
280 let state = input.start_handshake(extra_exts, &mut data)?;
281
282 Ok(Self::new(state, data))
283 }
284
285 #[cfg(feature = "std")]
286 pub(crate) fn is_early_data_accepted(&self) -> bool {
287 self.side.early_data.is_accepted()
288 }
289}
290
291pub struct UnbufferedClientConnection {
295 inner: UnbufferedConnectionCommon<ClientConnectionData>,
296}
297
298impl UnbufferedClientConnection {
299 pub fn new(config: Arc<ClientConfig>, name: ServerName<'static>) -> Result<Self, Error> {
302 Self::new_with_extensions(
303 config.clone(),
304 name,
305 ClientExtensionsInput::from_alpn(config.alpn_protocols.clone()),
306 )
307 }
308
309 pub fn new_with_alpn(
311 config: Arc<ClientConfig>,
312 name: ServerName<'static>,
313 alpn_protocols: Vec<ApplicationProtocol<'static>>,
314 ) -> Result<Self, Error> {
315 Self::new_with_extensions(
316 config,
317 name,
318 ClientExtensionsInput::from_alpn(alpn_protocols.clone()),
319 )
320 }
321
322 fn new_with_extensions(
323 config: Arc<ClientConfig>,
324 name: ServerName<'static>,
325 extensions: ClientExtensionsInput,
326 ) -> Result<Self, Error> {
327 Ok(Self {
328 inner: UnbufferedConnectionCommon::from(ConnectionCore::for_client(
329 config,
330 name,
331 extensions,
332 Protocol::Tcp,
333 )?),
334 })
335 }
336
337 #[deprecated = "dangerous_extract_secrets() does not support session tickets or \
340 key updates, use dangerous_into_kernel_connection() instead"]
341 pub fn dangerous_extract_secrets(self) -> Result<ExtractedSecrets, Error> {
342 self.inner.dangerous_extract_secrets()
343 }
344
345 pub fn dangerous_into_kernel_connection(
355 self,
356 ) -> Result<(ExtractedSecrets, KernelConnection<ClientConnectionData>), Error> {
357 self.inner
358 .core
359 .dangerous_into_kernel_connection()
360 }
361
362 pub fn tls13_tickets_received(&self) -> u32 {
364 self.inner.tls13_tickets_received
365 }
366}
367
368impl Deref for UnbufferedClientConnection {
369 type Target = UnbufferedConnectionCommon<ClientConnectionData>;
370
371 fn deref(&self) -> &Self::Target {
372 &self.inner
373 }
374}
375
376impl DerefMut for UnbufferedClientConnection {
377 fn deref_mut(&mut self) -> &mut Self::Target {
378 &mut self.inner
379 }
380}
381
382impl TransmitTlsData<'_, ClientConnectionData> {
383 pub fn may_encrypt_early_data(&mut self) -> Option<MayEncryptEarlyData<'_>> {
388 if self
389 .conn
390 .core
391 .side
392 .early_data
393 .is_enabled()
394 {
395 Some(MayEncryptEarlyData { conn: self.conn })
396 } else {
397 None
398 }
399 }
400}
401
402pub struct MayEncryptEarlyData<'c> {
404 conn: &'c mut UnbufferedConnectionCommon<ClientConnectionData>,
405}
406
407impl MayEncryptEarlyData<'_> {
408 pub fn encrypt(
413 &mut self,
414 early_data: &[u8],
415 outgoing_tls: &mut [u8],
416 ) -> Result<usize, EarlyDataError> {
417 let Some(allowed) = self
418 .conn
419 .core
420 .side
421 .early_data
422 .check_write_opt(early_data.len())
423 else {
424 return Err(EarlyDataError::ExceededAllowedEarlyData);
425 };
426
427 self.conn
428 .core
429 .side
430 .send
431 .write_plaintext(early_data[..allowed].into(), outgoing_tls)
432 .map_err(|e| e.into())
433 }
434}
435
436#[derive(Debug)]
437pub(super) struct EarlyData {
438 state: EarlyDataState,
439 left: usize,
440}
441
442impl EarlyData {
443 fn new() -> Self {
444 Self {
445 state: EarlyDataState::Disabled,
446 left: 0,
447 }
448 }
449
450 fn is_enabled(&self) -> bool {
451 matches!(
452 self.state,
453 EarlyDataState::Ready | EarlyDataState::Sending | EarlyDataState::Accepted
454 )
455 }
456
457 #[cfg(feature = "std")]
458 fn is_accepted(&self) -> bool {
459 matches!(
460 self.state,
461 EarlyDataState::Accepted | EarlyDataState::AcceptedFinished
462 )
463 }
464
465 fn enable(&mut self, max_data: usize) {
466 assert_eq!(self.state, EarlyDataState::Disabled);
467 self.state = EarlyDataState::Ready;
468 self.left = max_data;
469 }
470
471 fn start(&mut self) {
472 assert_eq!(self.state, EarlyDataState::Ready);
473 self.state = EarlyDataState::Sending;
474 }
475
476 fn rejected(&mut self) {
477 trace!("EarlyData rejected");
478 self.state = EarlyDataState::Rejected;
479 }
480
481 fn accepted(&mut self) {
482 trace!("EarlyData accepted");
483 assert_eq!(self.state, EarlyDataState::Sending);
484 self.state = EarlyDataState::Accepted;
485 }
486
487 pub(super) fn finished(&mut self) {
488 trace!("EarlyData finished");
489 self.state = match self.state {
490 EarlyDataState::Accepted => EarlyDataState::AcceptedFinished,
491 _ => panic!("bad EarlyData state"),
492 }
493 }
494
495 fn check_write_opt(&mut self, sz: usize) -> Option<usize> {
496 match self.state {
497 EarlyDataState::Disabled => unreachable!(),
498 EarlyDataState::Ready | EarlyDataState::Sending | EarlyDataState::Accepted => {
499 let take = if self.left < sz {
500 mem::replace(&mut self.left, 0)
501 } else {
502 self.left -= sz;
503 sz
504 };
505
506 Some(take)
507 }
508 EarlyDataState::Rejected | EarlyDataState::AcceptedFinished => None,
509 }
510 }
511}
512
513#[derive(Debug, PartialEq)]
514enum EarlyDataState {
515 Disabled,
516 Ready,
517 Sending,
518 Accepted,
519 AcceptedFinished,
520 Rejected,
521}
522
523#[non_exhaustive]
525#[derive(Debug)]
526pub enum EarlyDataError {
527 ExceededAllowedEarlyData,
529 Encrypt(EncryptError),
531}
532
533impl From<EncryptError> for EarlyDataError {
534 fn from(v: EncryptError) -> Self {
535 Self::Encrypt(v)
536 }
537}
538
539impl fmt::Display for EarlyDataError {
540 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
541 match self {
542 Self::ExceededAllowedEarlyData => f.write_str("cannot send any more early data"),
543 Self::Encrypt(e) => fmt::Display::fmt(e, f),
544 }
545 }
546}
547
548#[cfg(feature = "std")]
549impl core::error::Error for EarlyDataError {}
550
551#[derive(Debug)]
553pub struct ClientConnectionData {
554 common: CommonState,
555 early_data: EarlyData,
556 ech_status: EchStatus,
557}
558
559impl ClientConnectionData {
560 fn new(common: CommonState) -> Self {
561 Self {
562 common,
563 early_data: EarlyData::new(),
564 ech_status: EchStatus::default(),
565 }
566 }
567}
568
569impl crate::conn::SideData for ClientConnectionData {}
570
571impl crate::conn::private::SideData for ClientConnectionData {
572 fn into_common(self) -> CommonState {
573 self.common
574 }
575}
576
577impl Output for ClientConnectionData {
578 fn emit(&mut self, ev: Event<'_>) {
579 match ev {
580 Event::EchStatus(ech) => self.ech_status = ech,
581 Event::EarlyData(EarlyDataEvent::Accepted) => self.early_data.accepted(),
582 Event::EarlyData(EarlyDataEvent::Enable(sz)) => self.early_data.enable(sz),
583 Event::EarlyData(EarlyDataEvent::Finished) => self.early_data.finished(),
584 Event::EarlyData(EarlyDataEvent::Start) => self.early_data.start(),
585 Event::EarlyData(EarlyDataEvent::Rejected) => self.early_data.rejected(),
586 _ => self.common.emit(ev),
587 }
588 }
589}
590
591impl Deref for ClientConnectionData {
592 type Target = CommonState;
593
594 fn deref(&self) -> &Self::Target {
595 &self.common
596 }
597}
598
599impl DerefMut for ClientConnectionData {
600 fn deref_mut(&mut self) -> &mut Self::Target {
601 &mut self.common
602 }
603}