rustls/msgs/message/outbound.rs

1use alloc::vec::Vec;
2
3use super::{HEADER_SIZE, MAX_PAYLOAD, MessageError, PlainMessage};
4use crate::enums::{ContentType, ProtocolVersion};
5use crate::msgs::base::Payload;
6use crate::msgs::codec::{Codec, Reader};
7use crate::record_layer::RecordLayer;
8
9/// A TLS frame, named `TLSPlaintext` in the standard.
10///
11/// This outbound type borrows its "to be encrypted" payload from the "user".
12/// It is used for fragmenting and is consumed by encryption.
13#[allow(clippy::exhaustive_structs)]
14#[derive(Debug)]
15pub struct OutboundPlainMessage<'a> {
16    pub typ: ContentType,
17    pub version: ProtocolVersion,
18    pub payload: OutboundChunks<'a>,
19}
20
21impl OutboundPlainMessage<'_> {
22    pub(crate) fn encoded_len(&self, record_layer: &RecordLayer) -> usize {
23        HEADER_SIZE + record_layer.encrypted_len(self.payload.len())
24    }
25
26    pub(crate) fn to_unencrypted_opaque(&self) -> OutboundOpaqueMessage {
27        let mut payload = PrefixedPayload::with_capacity(self.payload.len());
28        payload.extend_from_chunks(&self.payload);
29        OutboundOpaqueMessage {
30            version: self.version,
31            typ: self.typ,
32            payload,
33        }
34    }
35}
36
/// A collection of borrowed plaintext slices.
///
/// Warning: `OutboundChunks` does not guarantee that the simplest variant is used.
/// `Multiple` can also hold a payload that is unfragmented or empty.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum OutboundChunks<'a> {
    /// A single byte slice. Unlike `Multiple`, this needs only one pointer indirection.
    Single(&'a [u8]),
    /// A collection of chunks (byte slices), plus cursors selecting the byte range
    /// `start..end` of their concatenation.
    ///
    /// `OutboundChunks` assumes that `start <= end`.
    Multiple {
        chunks: &'a [&'a [u8]],
        start: usize,
        end: usize,
    },
}

impl<'a> OutboundChunks<'a> {
    /// Create a payload from a slice of byte slices.
    ///
    /// Exactly one slice yields `Single`. Any other count (including zero) yields
    /// `Multiple` with cursors covering everything: `start = 0`, `end = total length`.
    pub fn new(chunks: &'a [&'a [u8]]) -> Self {
        if let [single] = chunks {
            Self::Single(*single)
        } else {
            Self::Multiple {
                chunks,
                start: 0,
                end: chunks
                    .iter()
                    .map(|chunk| chunk.len())
                    .sum(),
            }
        }
    }

    /// Create a payload holding a single empty slice.
    pub fn new_empty() -> Self {
        Self::Single(&[])
    }

    /// Flatten the selected bytes into a freshly allocated vector.
    pub fn to_vec(&self) -> Vec<u8> {
        let mut vec = Vec::with_capacity(self.len());
        self.copy_to_vec(&mut vec);
        vec
    }

    /// Append all selected bytes, in order, to `vec`.
    pub fn copy_to_vec(&self, vec: &mut Vec<u8>) {
        match *self {
            Self::Single(chunk) => vec.extend_from_slice(chunk),
            Self::Multiple { chunks, start, end } => {
                // `offset` is the position of the current chunk within the
                // concatenation of all chunks.
                let mut offset = 0;
                for chunk in chunks {
                    let chunk_start = offset;
                    let chunk_end = offset + chunk.len();
                    offset = chunk_end;

                    if chunk_end <= start {
                        // Entirely before the selected range.
                        continue;
                    }
                    if chunk_start >= end {
                        // This and every later chunk lie past the range.
                        break;
                    }

                    let from = start.saturating_sub(chunk_start);
                    let to = Ord::min(end - chunk_start, chunk.len());
                    vec.extend_from_slice(&chunk[from..to]);
                }
            }
        }
    }

    /// Split self in two around an index.
    ///
    /// Works like `split_at` in the core library, except that it never panics:
    /// an index past the end (however large, even `usize::MAX`) is clamped, so
    /// the second half is empty.
    pub fn split_at(&self, mid: usize) -> (Self, Self) {
        match *self {
            Self::Single(chunk) => {
                let mid = Ord::min(mid, chunk.len());
                let (before, after) = chunk.split_at(mid);
                (Self::Single(before), Self::Single(after))
            }
            Self::Multiple { chunks, start, end } => {
                // `saturating_add` keeps a huge `mid` from overflowing, which would
                // panic (debug) or wrap to a cursor before `start` (release).
                let mid = Ord::min(start.saturating_add(mid), end);
                (
                    Self::Multiple {
                        chunks,
                        start,
                        end: mid,
                    },
                    Self::Multiple {
                        chunks,
                        start: mid,
                        end,
                    },
                )
            }
        }
    }

    /// Returns true if the payload is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of selected bytes across all chunks.
    pub fn len(&self) -> usize {
        match *self {
            Self::Single(chunk) => chunk.len(),
            Self::Multiple { start, end, .. } => end - start,
        }
    }
}

impl<'a> From<&'a [u8]> for OutboundChunks<'a> {
    fn from(payload: &'a [u8]) -> Self {
        Self::Single(payload)
    }
}
152
153/// A TLS frame, named `TLSPlaintext` in the standard.
154///
155/// This outbound type owns all memory for its interior parts.
156/// It results from encryption and is used for io write.
157#[allow(clippy::exhaustive_structs)]
158#[derive(Clone, Debug)]
159pub struct OutboundOpaqueMessage {
160    pub typ: ContentType,
161    pub version: ProtocolVersion,
162    pub payload: PrefixedPayload,
163}
164
165impl OutboundOpaqueMessage {
166    /// Construct a new `OpaqueMessage` from constituent fields.
167    ///
168    /// `body` is moved into the `payload` field.
169    pub fn new(typ: ContentType, version: ProtocolVersion, payload: PrefixedPayload) -> Self {
170        Self {
171            typ,
172            version,
173            payload,
174        }
175    }
176
177    /// Construct by decoding from a [`Reader`].
178    ///
179    /// `MessageError` allows callers to distinguish between valid prefixes (might
180    /// become valid if we read more data) and invalid data.
181    pub fn read(r: &mut Reader<'_>) -> Result<Self, MessageError> {
182        let (typ, version, len) = read_opaque_message_header(r)?;
183
184        let content = r
185            .take(len as usize)
186            .ok_or(MessageError::TooShortForLength)?;
187
188        Ok(Self {
189            typ,
190            version,
191            payload: PrefixedPayload::from(content),
192        })
193    }
194
195    pub fn encode(self) -> Vec<u8> {
196        let length = self.payload.len() as u16;
197        let mut encoded_payload = self.payload.0;
198        encoded_payload[0] = self.typ.into();
199        encoded_payload[1..3].copy_from_slice(&self.version.to_array());
200        encoded_payload[3..5].copy_from_slice(&(length).to_be_bytes());
201        encoded_payload
202    }
203
204    /// Force conversion into a plaintext message.
205    ///
206    /// This should only be used for messages that are known to be in plaintext. Otherwise, the
207    /// `OutboundOpaqueMessage` should be decrypted into a `PlainMessage` using a `MessageDecrypter`.
208    pub fn into_plain_message(self) -> PlainMessage {
209        PlainMessage {
210            version: self.version,
211            typ: self.typ,
212            payload: Payload::Owned(self.payload.as_ref().to_vec()),
213        }
214    }
215}
216
217#[derive(Clone, Debug)]
218pub struct PrefixedPayload(Vec<u8>);
219
220impl PrefixedPayload {
221    pub fn with_capacity(capacity: usize) -> Self {
222        let mut prefixed_payload = Vec::with_capacity(HEADER_SIZE + capacity);
223        prefixed_payload.resize(HEADER_SIZE, 0);
224        Self(prefixed_payload)
225    }
226
227    pub fn extend_from_slice(&mut self, slice: &[u8]) {
228        self.0.extend_from_slice(slice)
229    }
230
231    pub fn extend_from_chunks(&mut self, chunks: &OutboundChunks<'_>) {
232        chunks.copy_to_vec(&mut self.0)
233    }
234
235    pub fn truncate(&mut self, len: usize) {
236        self.0.truncate(len + HEADER_SIZE)
237    }
238
239    fn len(&self) -> usize {
240        self.0.len() - HEADER_SIZE
241    }
242}
243
244impl AsRef<[u8]> for PrefixedPayload {
245    fn as_ref(&self) -> &[u8] {
246        &self.0[HEADER_SIZE..]
247    }
248}
249
250impl AsMut<[u8]> for PrefixedPayload {
251    fn as_mut(&mut self) -> &mut [u8] {
252        &mut self.0[HEADER_SIZE..]
253    }
254}
255
256impl<'a> Extend<&'a u8> for PrefixedPayload {
257    fn extend<T: IntoIterator<Item = &'a u8>>(&mut self, iter: T) {
258        self.0.extend(iter)
259    }
260}
261
262impl From<&[u8]> for PrefixedPayload {
263    fn from(content: &[u8]) -> Self {
264        let mut payload = Vec::with_capacity(HEADER_SIZE + content.len());
265        payload.extend(&[0u8; HEADER_SIZE]);
266        payload.extend(content);
267        Self(payload)
268    }
269}
270
271impl<const N: usize> From<&[u8; N]> for PrefixedPayload {
272    fn from(content: &[u8; N]) -> Self {
273        Self::from(&content[..])
274    }
275}
276
277pub(crate) fn read_opaque_message_header(
278    r: &mut Reader<'_>,
279) -> Result<(ContentType, ProtocolVersion, u16), MessageError> {
280    let typ = ContentType::read(r).map_err(|_| MessageError::TooShortForHeader)?;
281    // Don't accept any new content-types.
282    if let ContentType::Unknown(_) = typ {
283        return Err(MessageError::InvalidContentType);
284    }
285
286    let version = ProtocolVersion::read(r).map_err(|_| MessageError::TooShortForHeader)?;
287    // Accept only versions 0x03XX for any XX.
288    match &version {
289        ProtocolVersion::Unknown(v) if (v & 0xff00) != 0x0300 => {
290            return Err(MessageError::UnknownProtocolVersion);
291        }
292        _ => {}
293    };
294
295    let len = u16::read(r).map_err(|_| MessageError::TooShortForHeader)?;
296
297    // Reject undersize messages
298    //  implemented per section 5.1 of RFC8446 (TLSv1.3)
299    //              per section 6.2.1 of RFC5246 (TLSv1.2)
300    if typ != ContentType::ApplicationData && len == 0 {
301        return Err(MessageError::InvalidEmptyPayload);
302    }
303
304    // Reject oversize messages
305    if len >= MAX_PAYLOAD {
306        return Err(MessageError::MessageTooLarge);
307    }
308
309    Ok((typ, version, len))
310}
311
/// Unit tests for `OutboundChunks`: flattening via `to_vec` and splitting via
/// `split_at`, across `Single`, `Multiple` and empty payloads.
#[cfg(test)]
mod tests {
    use std::{println, vec};

    use super::*;

    /// Splitting a `Single` payload at an in-bounds index gives the same halves
    /// as slicing the underlying bytes.
    #[test]
    fn split_at_with_single_slice() {
        let owner: &[u8] = &[0, 1, 2, 3, 4, 5, 6, 7];
        let borrowed_payload = OutboundChunks::Single(owner);

        let (before, after) = borrowed_payload.split_at(6);
        println!("before:{before:?}\nafter:{after:?}");
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5]);
        assert_eq!(after.to_vec(), &[6, 7]);
    }

    /// Splitting a `Multiple` payload at indices that fall inside the first,
    /// third and fourth chunks. Both halves must flatten to the expected bytes.
    #[test]
    fn split_at_with_multiple_slices() {
        let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]];
        let borrowed_payload = OutboundChunks::new(&owner);

        let (before, after) = borrowed_payload.split_at(3);
        println!("before:{before:?}\nafter:{after:?}");
        assert_eq!(before.to_vec(), &[0, 1, 2]);
        assert_eq!(after.to_vec(), &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);

        let (before, after) = borrowed_payload.split_at(8);
        println!("before:{before:?}\nafter:{after:?}");
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(after.to_vec(), &[8, 9, 10, 11, 12]);

        let (before, after) = borrowed_payload.split_at(11);
        println!("before:{before:?}\nafter:{after:?}");
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(after.to_vec(), &[11, 12]);
    }

    /// An index past the end clamps instead of panicking: the first half keeps
    /// everything and the second half is empty, for `Single`, `Multiple` and
    /// empty payloads alike.
    #[test]
    fn split_out_of_bounds() {
        let owner: Vec<&[u8]> = vec![&[0, 1, 2, 3], &[4, 5], &[6, 7, 8], &[9, 10, 11, 12]];

        let single_payload = OutboundChunks::Single(owner[0]);
        let (before, after) = single_payload.split_at(17);
        println!("before:{before:?}\nafter:{after:?}");
        assert_eq!(before.to_vec(), &[0, 1, 2, 3]);
        assert!(after.is_empty());

        let multiple_payload = OutboundChunks::new(&owner);
        let (before, after) = multiple_payload.split_at(17);
        println!("before:{before:?}\nafter:{after:?}");
        assert_eq!(before.to_vec(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        assert!(after.is_empty());

        let empty_payload = OutboundChunks::new_empty();
        let (before, after) = empty_payload.split_at(17);
        println!("before:{before:?}\nafter:{after:?}");
        assert!(before.is_empty());
        assert!(after.is_empty());
    }

    /// Repeatedly peels 2-byte fragments off a payload interleaved with empty
    /// chunks. The loop must produce exactly the expected fragments and then stop.
    #[test]
    fn empty_slices_mixed() {
        let owner: Vec<&[u8]> = vec![&[], &[], &[0], &[], &[1, 2], &[], &[3], &[4], &[], &[]];
        let mut borrowed_payload = OutboundChunks::new(&owner);
        let mut fragment_count = 0;
        let mut fragment;
        let expected_fragments: &[&[u8]] = &[&[0, 1], &[2, 3], &[4]];

        while !borrowed_payload.is_empty() {
            (fragment, borrowed_payload) = borrowed_payload.split_at(2);
            println!("{fragment:?}");
            assert_eq!(&expected_fragments[fragment_count], &fragment.to_vec());
            fragment_count += 1;
        }
        assert_eq!(fragment_count, expected_fragments.len());
    }

    /// Builds 7 chunks of sizes 1, 2, 4, ..., 64 covering bytes 0..127. For every
    /// sub-range `start..end`, it checks that a further split at `mid` matches
    /// `split_at` on the equivalent flat slice.
    #[test]
    fn exhaustive_splitting() {
        let owner: Vec<u8> = (0..127).collect();
        // Chunk i spans owner[2^i - 1 .. 2^(i+1) - 1].
        let slices = (0..7)
            .map(|i| &owner[((1 << i) - 1)..((1 << (i + 1)) - 1)])
            .collect::<Vec<_>>();
        let payload = OutboundChunks::new(&slices);

        assert_eq!(payload.to_vec(), owner);
        println!("{payload:#?}");

        for start in 0..128 {
            for end in start..128 {
                for mid in 0..(end - start) {
                    let witness = owner[start..end].split_at(mid);
                    // Narrow to start..end via two splits, then split at mid.
                    let split_payload = payload
                        .split_at(end)
                        .0
                        .split_at(start)
                        .1
                        .split_at(mid);
                    assert_eq!(
                        witness.0,
                        split_payload.0.to_vec(),
                        "start: {start}, mid:{mid}, end:{end}"
                    );
                    assert_eq!(
                        witness.1,
                        split_payload.1.to_vec(),
                        "start: {start}, mid:{mid}, end:{end}"
                    );
                }
            }
        }
    }
}