1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
|
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use rand::Rng;
use s2n_tls::{
config::Config,
connection::{Connection, ModifiedBuilder},
enums::{ClientAuthType, Mode, Version},
error::{Error, ErrorType},
pool::ConfigPoolBuilder,
security::DEFAULT_TLS13,
};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector};
use std::{collections::VecDeque, time::Duration};
use tokio::time;
pub mod common;
#[tokio::test]
async fn handshake_basic() -> Result<(), Box<dyn std::error::Error>> {
let (server_stream, client_stream) = common::get_streams().await?;
let client = TlsConnector::new(common::client_config()?.build()?);
let server = TlsAcceptor::new(common::server_config()?.build()?);
let (client_result, server_result) =
common::run_negotiate(&client, client_stream, &server, server_stream).await?;
for tls in [client_result, server_result] {
// Security policy ensures TLS1.3.
assert_eq!(tls.as_ref().actual_protocol_version()?, Version::TLS13);
// Handshake types may change, but will at least be negotiated.
assert!(tls.as_ref().handshake_type()?.contains("NEGOTIATED"));
// Cipher suite may change, so just makes sure we can retrieve it.
assert!(tls.as_ref().cipher_suite().is_ok());
assert!(tls.as_ref().selected_curve().is_ok());
}
Ok(())
}
/// Runs `COUNT` handshakes as spawned tasks on a multi-threaded runtime.
///
/// Both sides draw their configs from a `ConfigPool`. The client pool's maximum
/// size is capped at `CLIENT_LIMIT`, which is smaller than `COUNT`. Each task
/// waits a random delay before it starts, so the handshakes interleave
/// differently from run to run.
#[tokio::test(flavor = "multi_thread")]
async fn handshake_with_pool_multithread() -> Result<(), Box<dyn std::error::Error>> {
    const COUNT: usize = 20;
    const CLIENT_LIMIT: usize = 3;

    let client_config = common::client_config()?.build()?;
    let server_config = common::server_config()?.build()?;

    let mut client_pool = ConfigPoolBuilder::new(Mode::Client, client_config);
    client_pool.set_max_pool_size(CLIENT_LIMIT);
    let client_pool = client_pool.build();
    let server_pool = ConfigPoolBuilder::new(Mode::Server, server_config).build();

    // The pools are not used directly after this point, so move them in
    // instead of cloning.
    let client = TlsConnector::new(client_pool);
    let server = TlsAcceptor::new(server_pool);

    let mut tasks = VecDeque::with_capacity(COUNT);
    for _ in 0..COUNT {
        let client = client.clone();
        let server = server.clone();
        tasks.push_back(tokio::spawn(async move {
            // Start each handshake at a randomly determined time.
            let delay = rand::thread_rng().gen_range(0..50);
            time::sleep(Duration::from_millis(delay)).await;
            // The task's error type is the handshake error, so a TCP setup
            // failure cannot be propagated with `?`. Fail loudly with context.
            let (server_stream, client_stream) = common::get_streams()
                .await
                .expect("failed to open a TCP stream pair for the handshake");
            common::run_negotiate(&client, client_stream, &server, server_stream).await
        }));
    }

    // The first `?` surfaces task panics/cancellation (JoinError); the second
    // surfaces handshake failures.
    for task in tasks {
        task.await??;
    }
    Ok(())
}
/// Applies a per-connection modification through `ModifiedBuilder` on both
/// sides and checks that each endpoint negotiated client authentication.
///
/// The client's modifier is a named function and the server's is a closure,
/// so both forms are exercised.
#[tokio::test]
async fn handshake_with_connection_config() -> Result<(), Box<dyn std::error::Error>> {
    // Client-side modifier, supplied as a plain function.
    fn with_client_auth(conn: &mut Connection) -> Result<&mut Connection, Error> {
        conn.set_client_auth_type(ClientAuthType::Optional)
    }

    let client_config = common::client_config()?.build()?;
    let client = TlsConnector::new(ModifiedBuilder::new(client_config, with_client_auth));

    // Server-side modifier, supplied inline as a closure.
    let server_config = common::server_config()?.build()?;
    let server = TlsAcceptor::new(ModifiedBuilder::new(server_config, |conn| {
        conn.set_client_auth_type(ClientAuthType::Optional)
    }));

    let (server_stream, client_stream) = common::get_streams().await?;
    let (client_result, server_result) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    for tls in [client_result, server_result] {
        let handshake = tls.as_ref().handshake_type()?;
        assert!(handshake.contains("CLIENT_AUTH"));
    }

    Ok(())
}
/// Combines `ModifiedBuilder` with a server-side `ConfigPool`.
///
/// The same `with_client_auth` modifier wraps the client config and the pooled
/// server config. Several handshakes then run one after another through the
/// same acceptor, and each must negotiate client authentication on the server
/// side.
#[tokio::test]
async fn handshake_with_connection_config_with_pool() -> Result<(), Box<dyn std::error::Error>> {
    const HANDSHAKES: usize = 5;

    fn with_client_auth(conn: &mut Connection) -> Result<&mut Connection, Error> {
        conn.set_client_auth_type(ClientAuthType::Optional)
    }

    let client_config = common::client_config()?.build()?;
    let client = TlsConnector::new(ModifiedBuilder::new(client_config, with_client_auth));

    let server_config = common::server_config()?.build()?;
    let server_pool = ConfigPoolBuilder::new(Mode::Server, server_config).build();
    let server = TlsAcceptor::new(ModifiedBuilder::new(server_pool, with_client_auth));

    for _ in 0..HANDSHAKES {
        let (server_stream, client_stream) = common::get_streams().await?;
        let (_, server_result) =
            common::run_negotiate(&client, client_stream, &server, server_stream).await?;
        let handshake = server_result.as_ref().handshake_type()?;
        assert!(handshake.contains("CLIENT_AUTH"));
    }

    Ok(())
}
/// A server that cannot choose a cipher suite must fail the handshake with a
/// non-retryable error.
#[tokio::test]
async fn handshake_error() -> Result<(), Box<dyn std::error::Error>> {
    // `Config::default()` includes no RSA certificates, yet it only offers
    // TLS1.2 cipher suites that require RSA authentication. The server therefore
    // fails to choose a cipher suite. The resulting S2N_ERR_CIPHER_NOT_SUPPORTED
    // is specifically excluded from blinding.
    let server_config = Config::default();
    let client_config = common::client_config()?.build()?;
    let client = TlsConnector::new(client_config);
    let server = TlsAcceptor::new(server_config);

    let (server_stream, client_stream) = common::get_streams().await?;
    let outcome = common::run_negotiate(&client, client_stream, &server, server_stream).await;

    let failed_permanently = match outcome {
        Err(e) => !e.is_retryable(),
        Ok(_) => false,
    };
    assert!(failed_permanently);

    Ok(())
}
/// Checks that a blinded handshake failure is reported no sooner than the
/// minimum blinding time, and that it surfaces as a protocol error.
///
/// The test runs with tokio's clock paused (`start_paused = true`). The client
/// config is given a `common::TokioTime` monotonic clock, which is presumably
/// driven by tokio's clock so that the blinding delay passes in virtual time.
/// NOTE(review): confirm against `common`.
#[tokio::test(start_paused = true)]
async fn handshake_error_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
    let clock = common::TokioTime::default();
    // Config::builder() does not include a trust store.
    // The client will reject the server certificate as untrusted.
    let mut bad_config = Config::builder();
    bad_config.set_security_policy(&DEFAULT_TLS13)?;
    bad_config.set_monotonic_clock(clock)?;
    let client_config = bad_config.build()?;
    let server_config = common::server_config()?.build()?;
    // Neither config is used after this point, so move them in rather than clone.
    let client = TlsConnector::new(client_config);
    let server = TlsAcceptor::new(server_config);

    let (server_stream, client_stream) = common::get_streams().await?;
    let time_start = time::Instant::now();
    let result = common::run_negotiate(&client, client_stream, &server, server_stream).await;
    let time_elapsed = time_start.elapsed();

    // Handshake MUST NOT finish faster than minimal blinding time.
    assert!(
        time_elapsed > common::MIN_BLINDING_SECS,
        "handshake finished after {:?}, sooner than the minimum blinding time {:?}",
        time_elapsed,
        common::MIN_BLINDING_SECS
    );

    // Handshake MUST eventually gracefully fail after blinding.
    let error = result.expect_err("handshake MUST fail: the client config has no trust store");
    assert_eq!(error.kind(), ErrorType::ProtocolError);
    Ok(())
}
/// Checks that the underlying IO stream can still be reached through the TLS
/// stream after the handshake, via both `get_ref` and `get_mut`.
///
/// The client socket's local address is recorded before the handshake. It is
/// then compared with the address read back through each accessor.
#[tokio::test]
async fn io_stream_access() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    // Capture the address before `client_stream` is moved into the handshake.
    // `local_addr` is fallible I/O, so propagate the error instead of panicking.
    let client_addr = client_stream.local_addr()?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);
    let (mut client_result, _server_result) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    assert_eq!(client_result.get_ref().local_addr()?, client_addr);
    assert_eq!(client_result.get_mut().local_addr()?, client_addr);
    Ok(())
}
|