File: utils.rs

package info (click to toggle)
rust-tokio-rustls 0.26.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 324 kB
  • sloc: makefile: 2
file content (154 lines) | stat: -rw-r--r-- 5,241 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
mod utils {
    use std::collections::VecDeque;
    use std::io::IoSlice;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use rustls::{
        pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer},
        ClientConfig, RootCertStore, ServerConfig,
    };
    use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};

    /// Builds a server/client TLS configuration pair from the PEM fixtures under `certs/`.
    ///
    /// The server side presents the chain from `certs/chain.pem` together with the key
    /// from `certs/end.key` and does not request client authentication. The client side
    /// trusts only the certificates found in `certs/root.pem` and offers no client
    /// certificate.
    ///
    /// # Panics
    ///
    /// Panics if any fixture fails to parse, or if rustls rejects the certificate/key
    /// pair or one of the root certificates.
    #[allow(dead_code)]
    pub(crate) fn make_configs() -> (ServerConfig, ClientConfig) {
        // Trust anchor for CHAIN; handed to the client out-of-band via its root store.
        const ROOT: &str = include_str!("certs/root.pem");
        // End-entity server certificate followed by the intermediate that issued it.
        const CHAIN: &str = include_str!("certs/chain.pem");
        // Private key belonging to the end-entity certificate in CHAIN.
        const EE_KEY: &str = include_str!("certs/end.key");

        // Server side: parse the whole chain first, then the matching key.
        let parsed_chain: Result<Vec<_>, _> =
            CertificateDer::pem_slice_iter(CHAIN.as_bytes()).collect();
        let server_chain = parsed_chain.unwrap();
        let server_key = PrivateKeyDer::from_pem_slice(EE_KEY.as_bytes()).unwrap();
        let server = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(server_chain, server_key)
            .unwrap();

        // Client side: trust every certificate listed in ROOT and nothing else.
        let mut trust_anchors = RootCertStore::empty();
        CertificateDer::pem_slice_iter(ROOT.as_bytes())
            .for_each(|anchor| trust_anchors.add(anchor.unwrap()).unwrap());
        let client = ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth();

        (server, client)
    }

    /// Writes all of `data` to `w`, optionally through the vectored-write path.
    ///
    /// When `vectored` is `false` this is simply `write_all`. When it is `true`, the
    /// remaining bytes are repeatedly split into several slices and handed to
    /// `write_vectored` until everything has been accepted.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the writer. In vectored mode, a write that accepts
    /// zero bytes while data is still outstanding is reported as
    /// [`io::ErrorKind::WriteZero`] instead of retrying forever.
    #[allow(dead_code)]
    pub(crate) async fn write<W: AsyncWrite + Unpin>(
        w: &mut W,
        data: &[u8],
        vectored: bool,
    ) -> io::Result<()> {
        if !vectored {
            return w.write_all(data).await;
        }

        let mut remaining = data;

        while !remaining.is_empty() {
            // Each slice is a quarter of the remaining length (rounded down), but at
            // least one byte, so the writer always sees more than one slice when there
            // is more than one byte left.
            let chunk_size = (remaining.len() / 4).max(1);
            let vectors = remaining
                .chunks(chunk_size)
                .map(IoSlice::new)
                .collect::<Vec<_>>();
            let written = w.write_vectored(&vectors).await?;
            if written == 0 {
                // No progress on a non-empty buffer: bail out rather than loop forever.
                return Err(io::ErrorKind::WriteZero.into());
            }
            remaining = &remaining[written..];
        }

        Ok(())
    }

    /// Host name constant for the test server (going by its name; its users are outside
    /// this file).
    ///
    /// NOTE(review): presumably this has to match a name covered by the end-entity
    /// certificate in the `certs/` fixtures; confirm before changing it.
    #[allow(dead_code)]
    pub(crate) const TEST_SERVER_DOMAIN: &str = "foobar.com";

    /// An IO wrapper that never flushes when writing, and always returns pending on first flush.
    ///
    /// This is used to test that rustls always flushes to completion during handshake.
    ///
    /// Writes are only recorded in memory; bytes reach the wrapped stream exclusively
    /// while a flush or shutdown is being polled.
    pub(crate) struct FlushWrapper<S> {
        /// The wrapped stream. Reads are delegated to it directly.
        stream: S,
        /// Buffers accepted by `poll_write`, one entry per call, oldest at the front.
        buf: VecDeque<Vec<u8>>,
        /// Bytes taken from the front of `buf` that are currently being written to
        /// `stream`; empty when no write-out is in progress.
        queued: Vec<u8>,
    }

    impl<S> FlushWrapper<S> {
        /// Wraps `stream`, starting with nothing buffered and nothing queued.
        #[allow(dead_code)]
        pub(crate) fn new(stream: S) -> Self {
            let buf = VecDeque::new();
            let queued = Vec::new();
            Self { stream, buf, queued }
        }
    }

    impl<S: AsyncRead + Unpin> AsyncRead for FlushWrapper<S> {
        /// Reads are passed straight through to the wrapped stream; only the write
        /// side of the wrapper has special behaviour.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let this = self.get_mut();
            let inner = Pin::new(&mut this.stream);
            inner.poll_read(cx, buf)
        }
    }

    impl<S: AsyncWrite + Unpin> FlushWrapper<S> {
        /// Drives buffered data into the wrapped stream, then finishes with `flush_inner`.
        ///
        /// Each poll proceeds as follows:
        ///
        /// * If `queued` holds bytes, they are written to the wrapped stream until the
        ///   queue is empty (or the stream is not ready / fails).
        /// * Otherwise, if `buf` has a pending buffer, it is moved into `queued`, the
        ///   task is woken and `Poll::Pending` is returned *without* writing, so every
        ///   buffered write costs the caller at least one extra poll.
        /// * Otherwise `flush_inner` is invoked on the wrapped stream and its result is
        ///   returned; `poll_flush` and `poll_shutdown` pass the matching inner call.
        ///
        /// # Errors
        ///
        /// Propagates errors from the wrapped stream. A write of a non-empty queue that
        /// accepts zero bytes yields [`io::ErrorKind::WriteZero`] rather than spinning.
        fn poll_flush_inner<F>(
            &mut self,
            cx: &mut Context<'_>,
            flush_inner: F,
        ) -> Poll<Result<(), io::Error>>
        where
            F: FnOnce(Pin<&mut S>, &mut Context<'_>) -> Poll<Result<(), io::Error>>,
        {
            loop {
                let stream = Pin::new(&mut self.stream);
                if !self.queued.is_empty() {
                    // write out the queued data
                    let n = std::task::ready!(stream.poll_write(cx, &self.queued))?;
                    if n == 0 {
                        // The stream accepted nothing from a non-empty buffer; looping
                        // again would busy-spin inside this poll without ever yielding.
                        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                    }
                    // Drop the written prefix in place instead of reallocating the tail.
                    self.queued.drain(..n);
                } else if let Some(buf) = self.buf.pop_front() {
                    // queue the flush, but don't trigger the write immediately.
                    self.queued = buf;
                    // Nothing else will wake the task, so request an immediate re-poll.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                } else {
                    // nothing more to flush to the inner stream, flush the inner stream instead.
                    return flush_inner(stream, cx);
                }
            }
        }
    }

    impl<S: AsyncWrite + Unpin> AsyncWrite for FlushWrapper<S> {
        /// Accepts the whole of `buf` immediately by copying it onto the in-memory
        /// queue; the wrapped stream is not touched here.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let accepted = buf.len();
            let this = self.get_mut();
            this.buf.push_back(Vec::from(buf));
            Poll::Ready(Ok(accepted))
        }

        /// Writes out everything buffered so far, then flushes the wrapped stream.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            let this = self.get_mut();
            this.poll_flush_inner(cx, |inner, cx| inner.poll_flush(cx))
        }

        /// Writes out everything buffered so far, then shuts down the wrapped stream.
        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            let this = self.get_mut();
            this.poll_flush_inner(cx, |inner, cx| inner.poll_shutdown(cx))
        }
    }
}