rustls/
stream.rs

1use core::ops::{Deref, DerefMut};
2use std::io::{BufRead, IoSlice, Read, Result, Write};
3
4use crate::conn::{ConnectionCommon, SideData};
5
/// A borrowed pairing of a rustls Connection `C` with the transport `T` it
/// runs over (for example a socket), usable through `io::Read` and
/// `io::Write`.
///
/// All transport I/O is driven by [`ConnectionCommon::complete_io()`].
///
/// With this wrapper a rustls Connection can be used like an ordinary stream.
#[allow(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct Stream<'a, C, T>
where
    C: 'a + ?Sized,
    T: 'a + Read + Write + ?Sized,
{
    /// The TLS connection being driven.
    pub conn: &'a mut C,

    /// The underlying transport, such as a socket.
    pub sock: &'a mut T,
}
21
22impl<'a, C, T, S> Stream<'a, C, T>
23where
24    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
25    T: 'a + Read + Write,
26    S: SideData,
27{
28    /// Make a new Stream using the Connection `conn` and socket-like object
29    /// `sock`.  This does not fail and does no IO.
30    pub fn new(conn: &'a mut C, sock: &'a mut T) -> Self {
31        Self { conn, sock }
32    }
33
34    /// If we're handshaking, complete all the IO for that.
35    /// If we have data to write, write it all.
36    fn complete_prior_io(&mut self) -> Result<()> {
37        if self.conn.is_handshaking() {
38            self.conn.complete_io(self.sock)?;
39        }
40
41        if self.conn.wants_write() {
42            self.conn.complete_io(self.sock)?;
43        }
44
45        Ok(())
46    }
47
48    fn prepare_read(&mut self) -> Result<()> {
49        self.complete_prior_io()?;
50
51        // We call complete_io() in a loop since a single call may read only
52        // a partial packet from the underlying transport. A full packet is
53        // needed to get more plaintext, which we must do if EOF has not been
54        // hit.
55        while self.conn.wants_read() {
56            if self.conn.complete_io(self.sock)?.0 == 0 {
57                break;
58            }
59        }
60
61        Ok(())
62    }
63
64    // Implements `BufRead::fill_buf` but with more flexible lifetimes, so StreamOwned can reuse it
65    fn fill_buf(mut self) -> Result<&'a [u8]>
66    where
67        S: 'a,
68    {
69        self.prepare_read()?;
70        self.conn.reader().into_first_chunk()
71    }
72}
73
74impl<'a, C, T, S> Read for Stream<'a, C, T>
75where
76    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
77    T: 'a + Read + Write,
78    S: SideData,
79{
80    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
81        self.prepare_read()?;
82        self.conn.reader().read(buf)
83    }
84}
85
86impl<'a, C, T, S> BufRead for Stream<'a, C, T>
87where
88    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
89    T: 'a + Read + Write,
90    S: 'a + SideData,
91{
92    fn fill_buf(&mut self) -> Result<&[u8]> {
93        // reborrow to get an owned `Stream`
94        Stream {
95            conn: self.conn,
96            sock: self.sock,
97        }
98        .fill_buf()
99    }
100
101    fn consume(&mut self, amt: usize) {
102        self.conn.reader().consume(amt)
103    }
104}
105
106impl<'a, C, T, S> Write for Stream<'a, C, T>
107where
108    C: 'a + DerefMut + Deref<Target = ConnectionCommon<S>>,
109    T: 'a + Read + Write,
110    S: SideData,
111{
112    fn write(&mut self, buf: &[u8]) -> Result<usize> {
113        self.complete_prior_io()?;
114
115        let len = self.conn.writer().write(buf)?;
116
117        // Try to write the underlying transport here, but don't let
118        // any errors mask the fact we've consumed `len` bytes.
119        // Callers will learn of permanent errors on the next call.
120        let _ = self.conn.complete_io(self.sock);
121
122        Ok(len)
123    }
124
125    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
126        self.complete_prior_io()?;
127
128        let len = self
129            .conn
130            .writer()
131            .write_vectored(bufs)?;
132
133        // Try to write the underlying transport here, but don't let
134        // any errors mask the fact we've consumed `len` bytes.
135        // Callers will learn of permanent errors on the next call.
136        let _ = self.conn.complete_io(self.sock);
137
138        Ok(len)
139    }
140
141    fn flush(&mut self) -> Result<()> {
142        self.complete_prior_io()?;
143
144        self.conn.writer().flush()?;
145        if self.conn.wants_write() {
146            self.conn.complete_io(self.sock)?;
147        }
148        Ok(())
149    }
150}
151
/// An owning pairing of a rustls Connection `C` with the transport `T` it
/// runs over (for example a socket), usable through `io::Read` and
/// `io::Write`.
///
/// All transport I/O is driven by [`ConnectionCommon::complete_io()`].
///
/// With this wrapper a rustls Connection can be used like an ordinary stream.
#[allow(clippy::exhaustive_structs)]
#[derive(Debug)]
pub struct StreamOwned<C, T>
where
    T: Read + Write,
{
    /// The owned connection.
    pub conn: C,

    /// The owned underlying transport, such as a socket.
    pub sock: T,
}
167
168impl<C, T, S> StreamOwned<C, T>
169where
170    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
171    T: Read + Write,
172    S: SideData,
173{
174    /// Make a new StreamOwned taking the Connection `conn` and socket-like
175    /// object `sock`.  This does not fail and does no IO.
176    ///
177    /// This is the same as `Stream::new` except `conn` and `sock` are
178    /// moved into the StreamOwned.
179    pub fn new(conn: C, sock: T) -> Self {
180        Self { conn, sock }
181    }
182
183    /// Get a reference to the underlying socket
184    pub fn get_ref(&self) -> &T {
185        &self.sock
186    }
187
188    /// Get a mutable reference to the underlying socket
189    pub fn get_mut(&mut self) -> &mut T {
190        &mut self.sock
191    }
192
193    /// Extract the `conn` and `sock` parts from the `StreamOwned`
194    pub fn into_parts(self) -> (C, T) {
195        (self.conn, self.sock)
196    }
197}
198
199impl<'a, C, T, S> StreamOwned<C, T>
200where
201    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
202    T: Read + Write,
203    S: SideData,
204{
205    fn as_stream(&'a mut self) -> Stream<'a, C, T> {
206        Stream {
207            conn: &mut self.conn,
208            sock: &mut self.sock,
209        }
210    }
211}
212
213impl<C, T, S> Read for StreamOwned<C, T>
214where
215    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
216    T: Read + Write,
217    S: SideData,
218{
219    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
220        self.as_stream().read(buf)
221    }
222}
223
224impl<C, T, S> BufRead for StreamOwned<C, T>
225where
226    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
227    T: Read + Write,
228    S: 'static + SideData,
229{
230    fn fill_buf(&mut self) -> Result<&[u8]> {
231        self.as_stream().fill_buf()
232    }
233
234    fn consume(&mut self, amt: usize) {
235        self.as_stream().consume(amt)
236    }
237}
238
239impl<C, T, S> Write for StreamOwned<C, T>
240where
241    C: DerefMut + Deref<Target = ConnectionCommon<S>>,
242    T: Read + Write,
243    S: SideData,
244{
245    fn write(&mut self, buf: &[u8]) -> Result<usize> {
246        self.as_stream().write(buf)
247    }
248
249    fn flush(&mut self) -> Result<()> {
250        self.as_stream().flush()
251    }
252}
253
#[cfg(test)]
mod tests {
    use std::io::{BufRead, Read, Write};
    use std::net::TcpStream;

    use super::{Stream, StreamOwned};
    use crate::client::ClientConnection;
    use crate::server::ServerConnection;

    /// Compiles only if `X` implements `Read`, `BufRead` and `Write`.
    ///
    /// A type alias alone requires none of the generic I/O impls in this
    /// module to hold. Calling this makes each test actually check that
    /// those impls apply to the given connection/transport pairing.
    fn assert_is_io<X: Read + BufRead + Write>() {}

    #[test]
    fn stream_can_be_created_for_connection_and_tcpstream() {
        type _Test<'a> = Stream<'a, ClientConnection, TcpStream>;
        assert_is_io::<Stream<'static, ClientConnection, TcpStream>>();
    }

    #[test]
    fn stream_can_be_created_for_server_and_tcpstream() {
        type _Test<'a> = Stream<'a, ServerConnection, TcpStream>;
        assert_is_io::<Stream<'static, ServerConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_client_and_tcpstream() {
        type _Test = StreamOwned<ClientConnection, TcpStream>;
        assert_is_io::<StreamOwned<ClientConnection, TcpStream>>();
    }

    #[test]
    fn streamowned_can_be_created_for_server_and_tcpstream() {
        type _Test = StreamOwned<ServerConnection, TcpStream>;
        assert_is_io::<StreamOwned<ServerConnection, TcpStream>>();
    }
}