
|
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use s2n_tls_tokio::{TlsAcceptor, TlsConnector};
use std::{io, task::Poll::*};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
pub mod common;
/// Short plaintext payload used by the single-record tests.
const TEST_DATA: &[u8] = b"hello world";
/// Payload too large for a single TLS record.
///
/// The maximum TLS record payload is 2^14 bytes; sending twice that amount
/// ensures the data has to be spread over multiple records.
const LARGE_TEST_DATA: &[u8] = &[5; 2 * (1 << 14)];
/// Negotiates a TLS session, writes a short message from the client, and
/// checks the server reads back exactly the same bytes.
#[tokio::test]
async fn send_and_recv_basic() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client_config = common::client_config()?.build()?;
    let connector = TlsConnector::new(client_config);
    let server_config = common::server_config()?.build()?;
    let acceptor = TlsAcceptor::new(server_config);

    let (mut client, mut server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    client.write_all(TEST_DATA).await?;

    let mut received = [0u8; TEST_DATA.len()];
    let read_len = server.read_exact(&mut received).await?;
    assert_eq!(read_len, TEST_DATA.len());
    assert_eq!(TEST_DATA, received);

    Ok(())
}
/// Writes a short message from the client and collects it on the server with
/// `read_buf`, letting a `Vec` grow until the whole payload has arrived.
#[tokio::test]
async fn send_and_recv_into_vec() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client_config = common::client_config()?.build()?;
    let connector = TlsConnector::new(client_config);
    let server_config = common::server_config()?.build()?;
    let acceptor = TlsAcceptor::new(server_config);

    let (mut client, mut server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    client.write_all(TEST_DATA).await?;

    let mut collected: Vec<u8> = Vec::new();
    while collected.len() < TEST_DATA.len() {
        let chunk_len = server.read_buf(&mut collected).await?;
        // A zero-length read means no more data is coming (EOF); fail rather
        // than spin forever waiting for the rest of the payload.
        assert!(chunk_len > 0);
    }
    assert_eq!(TEST_DATA, collected);

    Ok(())
}
/// Sends a payload that spans several TLS records while the server reads it
/// concurrently, then checks every byte arrived.
#[tokio::test]
async fn send_and_recv_multiple_records() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    let connector = TlsConnector::new(common::client_config()?.build()?);
    let acceptor = TlsAcceptor::new(common::server_config()?.build()?);
    let (mut client, mut server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    let mut received = [0u8; LARGE_TEST_DATA.len()];

    // Futures are lazy: nothing is sent or read until `try_join!` polls them.
    let send = client.write_all(LARGE_TEST_DATA);
    let recv = server.read_exact(&mut received);
    // Drive both sides together; the first error short-circuits the join.
    let ((), read_size) = tokio::try_join!(send, recv)?;

    assert_eq!(LARGE_TEST_DATA.len(), read_size);
    assert_eq!(LARGE_TEST_DATA, received);

    Ok(())
}
/// Splits both TLS streams into independent read/write halves and exchanges a
/// multi-record payload in both directions at the same time.
#[tokio::test]
async fn send_and_recv_split() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;
    let connector = TlsConnector::new(common::client_config()?.build()?);
    let acceptor = TlsAcceptor::new(common::server_config()?.build()?);
    let (client, server) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    let (mut client_rx, mut client_tx) = tokio::io::split(client);
    let (mut server_rx, mut server_tx) = tokio::io::split(server);

    // Each side's read buffer receives what the opposite side writes.
    let mut client_received = [0u8; LARGE_TEST_DATA.len()];
    let mut server_received = [0u8; LARGE_TEST_DATA.len()];

    let ((), (), client_bytes, server_bytes) = tokio::try_join!(
        client_tx.write_all(LARGE_TEST_DATA),
        server_tx.write_all(LARGE_TEST_DATA),
        client_rx.read_exact(&mut client_received),
        server_rx.read_exact(&mut server_received)
    )?;

    assert_eq!(client_bytes, LARGE_TEST_DATA.len());
    assert_eq!(server_bytes, LARGE_TEST_DATA.len());
    assert_eq!(LARGE_TEST_DATA, client_received);
    assert_eq!(LARGE_TEST_DATA, server_received);

    Ok(())
}
/// After a successful handshake, makes the client's next transport write fail
/// with `ConnectionReset` and checks that `write_all` surfaces an error.
#[tokio::test]
async fn send_error() -> Result<(), Box<dyn std::error::Error>> {
    let connector = TlsConnector::new(common::client_config()?.build()?);
    let acceptor = TlsAcceptor::new(common::server_config()?.build()?);
    let (server_stream, client_stream) = common::get_streams().await?;

    // Wrap the client transport and keep a handle to its overrides so a
    // failure can be injected after negotiation completes.
    let client_stream = common::TestStream::new(client_stream);
    let overrides = client_stream.overrides();

    // The server-side stream is matched with `_` and not kept by this test.
    let (mut client, _) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    // Setup write to fail
    overrides.next_write(Some(Box::new(|_, _, _| {
        Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset)))
    })));

    // Verify write fails
    assert!(client.write_all(TEST_DATA).await.is_err());

    Ok(())
}
/// After a successful handshake, makes the client's next transport read fail
/// with `ConnectionReset` and checks that `read_exact` surfaces an error.
#[tokio::test]
async fn recv_error() -> Result<(), Box<dyn std::error::Error>> {
    let connector = TlsConnector::new(common::client_config()?.build()?);
    let acceptor = TlsAcceptor::new(common::server_config()?.build()?);
    let (server_stream, client_stream) = common::get_streams().await?;

    // Wrap the client transport and keep a handle to its overrides so a
    // failure can be injected after negotiation completes.
    let client_stream = common::TestStream::new(client_stream);
    let overrides = client_stream.overrides();

    // The server-side stream is matched with `_` and not kept by this test.
    let (mut client, _) =
        common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?;

    // Setup read to fail
    overrides.next_read(Some(Box::new(|_, _, _| {
        Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset)))
    })));

    // Verify read fails; a single byte is enough to trigger a transport read.
    let mut received = [0u8; 1];
    assert!(client.read_exact(&mut received).await.is_err());

    Ok(())
}
|