1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use s2n_tls_tokio::{TlsAcceptor, TlsConnector};
use std::{io, task::Poll::*};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
pub mod common;
/// Short payload used by the single-message tests; well under the 2^14-byte
/// maximum TLS record payload.
const TEST_DATA: &[u8] = b"hello world";
/// The maximum TLS record payload is 2^14 bytes, so this 2^15-byte payload
/// cannot fit in a single record and must be split across several.
const LARGE_TEST_DATA: &[u8] = &[5; 1 << 15];
/// Negotiates a TLS session, then checks that a short message written by the
/// client is read back byte-for-byte by the server.
#[tokio::test]
async fn send_and_recv_basic() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    let client_config = common::client_config()?.build()?;
    let server_config = common::server_config()?.build()?;
    let connector = TlsConnector::new(client_config);
    let acceptor = TlsAcceptor::new(server_config);

    let (mut client, mut server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    client.write_all(TEST_DATA).await?;

    // `read_exact` only returns once the whole buffer has been filled.
    let mut buffer = [0u8; TEST_DATA.len()];
    let filled = server.read_exact(&mut buffer).await?;
    assert_eq!(filled, TEST_DATA.len());
    assert_eq!(TEST_DATA, buffer);
    Ok(())
}
/// Checks that data written by the client can be collected on the server
/// side with `read_buf` into a growable `Vec`.
#[tokio::test]
async fn send_and_recv_into_vec() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    let client_config = common::client_config()?.build()?;
    let server_config = common::server_config()?.build()?;
    let connector = TlsConnector::new(client_config);
    let acceptor = TlsAcceptor::new(server_config);

    let (mut client, mut server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    client.write_all(TEST_DATA).await?;

    // A single `read_buf` may return only part of the message, so keep
    // appending until everything has arrived; a zero-length read would mean
    // the stream ended early.
    let mut collected = Vec::new();
    while collected.len() < TEST_DATA.len() {
        let n = server.read_buf(&mut collected).await?;
        assert!(n > 0);
    }

    assert_eq!(TEST_DATA, collected);
    Ok(())
}
/// Sends a payload larger than one TLS record and checks it is reassembled
/// intact on the server side.
#[tokio::test]
async fn send_and_recv_multiple_records() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    let client_config = common::client_config()?.build()?;
    let server_config = common::server_config()?.build()?;
    let connector = TlsConnector::new(client_config);
    let acceptor = TlsAcceptor::new(server_config);

    let (mut client, mut server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    let mut buffer = [0u8; LARGE_TEST_DATA.len()];

    // The writer and reader are driven concurrently rather than one after the
    // other; presumably so the write cannot stall waiting for a reader when
    // the transport can't buffer the whole payload.
    let write = client.write_all(LARGE_TEST_DATA);
    let read = server.read_exact(&mut buffer);
    let ((), filled) = tokio::try_join!(write, read)?;

    assert_eq!(LARGE_TEST_DATA.len(), filled);
    assert_eq!(LARGE_TEST_DATA, buffer);
    Ok(())
}
/// Splits both TLS streams into independent read/write halves and exchanges a
/// multi-record payload in both directions at the same time.
#[tokio::test]
async fn send_and_recv_split() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    let client_config = common::client_config()?.build()?;
    let server_config = common::server_config()?.build()?;
    let connector = TlsConnector::new(client_config);
    let acceptor = TlsAcceptor::new(server_config);

    let (client, server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    let (mut client_reader, mut client_writer) = tokio::io::split(client);
    let (mut server_reader, mut server_writer) = tokio::io::split(server);

    // `from_server` is filled by the client's read half, `from_client` by the
    // server's read half.
    let mut from_server = [0u8; LARGE_TEST_DATA.len()];
    let mut from_client = [0u8; LARGE_TEST_DATA.len()];

    // All four halves run concurrently so each direction can make progress
    // independently of the other.
    let ((), (), client_filled, server_filled) = tokio::try_join!(
        client_writer.write_all(LARGE_TEST_DATA),
        server_writer.write_all(LARGE_TEST_DATA),
        client_reader.read_exact(&mut from_server),
        server_reader.read_exact(&mut from_client)
    )?;

    assert_eq!(client_filled, LARGE_TEST_DATA.len());
    assert_eq!(server_filled, LARGE_TEST_DATA.len());
    assert_eq!(LARGE_TEST_DATA, from_server);
    assert_eq!(LARGE_TEST_DATA, from_client);
    Ok(())
}
/// Verifies that an I/O error injected into the client's underlying transport
/// surfaces as an error from `write_all` on the TLS stream.
#[tokio::test]
async fn send_error() -> Result<(), Box<dyn std::error::Error>> {
    let connector = TlsConnector::new(common::client_config()?.build()?);
    let acceptor = TlsAcceptor::new(common::server_config()?.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let client_stream = common::TestStream::new(client_stream);
    let overrides = client_stream.overrides();
    // Bind the server half to a named variable instead of `_`. A `_` pattern
    // does not bind, so the server stream would be dropped at the end of this
    // statement, and a failure could then come from the closed peer rather
    // than from the injected error.
    let (mut client, _server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    // Make the next write on the client's transport fail.
    overrides.next_write(Some(Box::new(|_, _, _| {
        Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset)))
    })));

    // The TLS layer must report the transport failure to the caller.
    let result = client.write_all(TEST_DATA).await;
    assert!(result.is_err());
    Ok(())
}
/// Verifies that an I/O error injected into the client's underlying transport
/// surfaces as an error from `read_exact` on the TLS stream.
#[tokio::test]
async fn recv_error() -> Result<(), Box<dyn std::error::Error>> {
    let connector = TlsConnector::new(common::client_config()?.build()?);
    let acceptor = TlsAcceptor::new(common::server_config()?.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let client_stream = common::TestStream::new(client_stream);
    let overrides = client_stream.overrides();
    // Keep the server half alive for the whole test. A `_` pattern would drop
    // it at the end of this statement. The client's read could then end with
    // `UnexpectedEof`, and the assertion below would pass even if the
    // injected error were never observed.
    let (mut client, _server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    // Make the next read on the client's transport fail.
    overrides.next_read(Some(Box::new(|_, _, _| {
        Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset)))
    })));

    // The TLS layer must report the transport failure to the caller.
    let mut received = [0; 1];
    let result = client.read_exact(&mut received).await;
    assert!(result.is_err());
    Ok(())
}
|