File: test.rs

package info (click to toggle)
rust-futures-rustls 0.26.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 260 kB
  • sloc: makefile: 7; sh: 1
file content (148 lines) | stat: -rw-r--r-- 4,576 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use futures_rustls::{TlsAcceptor, TlsConnector};
use futures_util::future::TryFutureExt;
use futures_util::io::{copy, AsyncReadExt, AsyncWriteExt};
use lazy_static::lazy_static;
use rustls::ClientConfig;
use rustls_pemfile::{certs, private_key};
use smol::net::{TcpListener, TcpStream};
use smol::Timer;
use std::convert::TryFrom;
use std::io::{BufReader, Cursor};
use std::net::SocketAddr;
use std::sync::mpsc::channel;
use std::sync::Arc;
use std::{io, thread};

const CERT: &str = include_str!("end.cert");
const CHAIN: &[u8] = include_bytes!("end.chain");
const RSA: &str = include_str!("end.rsa");

lazy_static! {
    /// Lazily started TLS echo server shared by every test in this file.
    ///
    /// The tuple holds the address the server's listener bound to, the DNS name
    /// clients should connect under, and the PEM bytes of the chain a client
    /// should load as its trust roots.
    static ref TEST_SERVER: (SocketAddr, &'static str, &'static [u8]) = {
        let server_chain = certs(&mut BufReader::new(Cursor::new(CERT)))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let server_key = private_key(&mut BufReader::new(Cursor::new(RSA)))
            .unwrap()
            .unwrap();

        let server_config = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(server_chain, server_key)
            .unwrap();
        let tls_acceptor = TlsAcceptor::from(Arc::new(server_config));

        // Used once: the server thread reports the ephemeral port it bound to.
        let (addr_tx, addr_rx) = channel();

        thread::spawn(move || {
            let serve = async move {
                // Port 0 lets the OS pick a free port.
                let bind_addr = SocketAddr::from(([127, 0, 0, 1], 0));
                let listener = TcpListener::bind(&bind_addr).await?;

                // Only published after `bind` has completed.
                addr_tx.send(listener.local_addr()?).unwrap();

                loop {
                    let (tcp, _peer) = listener.accept().await?;
                    let conn_acceptor = tls_acceptor.clone();

                    // Per connection: finish the TLS handshake, then echo every
                    // byte read back to the same peer.
                    let echo = async move {
                        let tls = conn_acceptor.accept(tcp).await?;
                        let (mut rd, mut wr) = tls.split();
                        copy(&mut rd, &mut wr).await?;
                        Ok::<(), io::Error>(())
                    }
                    .unwrap_or_else(|err| eprintln!("server: {:?}", err));

                    smol::spawn(echo).detach();
                }
            }
            .unwrap_or_else(|err: io::Error| eprintln!("server: {:?}", err));

            smol::block_on(serve);
        });

        // Blocks until the server thread has bound its listener.
        let bound = addr_rx.recv().unwrap();
        (bound, "testserver.com", CHAIN)
    };
}

/// Returns the shared test server, initializing `TEST_SERVER` on first use.
///
/// The tuple is `(bound address, DNS name to connect under, PEM chain to trust)`.
fn start_server() -> &'static (SocketAddr, &'static str, &'static [u8]) {
    // Deref coercion turns `&'static TEST_SERVER` into a reference to its value.
    &TEST_SERVER
}

/// Connects to `addr` over TLS as `domain`, writes the contents of
/// `../README.md`, and reads back the same number of bytes, asserting they are
/// identical to what was written.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error if `domain` is not a valid
/// server name, and propagates any error from connecting, the TLS handshake,
/// writing, flushing, or reading.
///
/// # Panics
///
/// Panics if the bytes read back differ from the bytes written.
async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
    const FILE: &[u8] = include_bytes!("../README.md");

    // A malformed name is reported through the `io::Result` this function
    // already returns, rather than aborting the caller with a panic.
    let domain = pki_types::ServerName::try_from(domain)
        .map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid server name: {:?}", domain),
            )
        })?
        .to_owned();
    let config = TlsConnector::from(config);
    let mut buf = vec![0; FILE.len()];

    let stream = TcpStream::connect(&addr).await?;
    let mut stream = config.connect(domain, stream).await?;
    stream.write_all(FILE).await?;
    stream.flush().await?;
    stream.read_exact(&mut buf).await?;

    assert_eq!(buf, FILE);

    Ok(())
}

/// A client that trusts the server's chain and connects under the certificate's
/// DNS name completes the handshake and gets its payload echoed back.
#[test]
fn pass() -> io::Result<()> {
    let fut = async {
        // `start_server` only returns after the server thread has sent the
        // listener's address, which happens once `TcpListener::bind` has
        // completed. Connections can therefore be made immediately, as `fail`
        // already does, so no startup delay is needed.
        let (addr, domain, chain) = start_server();

        let chain = certs(&mut std::io::Cursor::new(*chain))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        let mut root_store = rustls::RootCertStore::empty();
        root_store.add_parsable_certificates(chain);

        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        let config = Arc::new(config);

        start_client(*addr, domain, config).await?;

        Ok(())
    };

    smol::block_on(fut)
}

/// Connecting under a name the server's certificate was not issued for must
/// fail, even though the client trusts the server's chain.
#[test]
fn fail() -> io::Result<()> {
    smol::block_on(async {
        // Every field of the shared server tuple is `Copy`, so take them by value.
        let &(addr, domain, chain) = start_server();

        let trusted = certs(&mut std::io::Cursor::new(chain))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let mut roots = rustls::RootCertStore::empty();
        roots.add_parsable_certificates(trusted);

        let client_config = Arc::new(
            rustls::ClientConfig::builder()
                .with_root_certificates(roots)
                .with_no_client_auth(),
        );

        // The test only proves something if the real server name differs.
        assert_ne!(domain, "google.com");
        let ret = start_client(addr, "google.com", client_config).await;
        assert!(ret.is_err());

        Ok(())
    })
}