1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use std::{
io,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};
/// Boxed closure that can stand in for one `poll_read` call on the wrapped `TcpStream`.
///
/// It receives the pinned stream, the task context, and the read buffer, and returns the
/// same `Poll<io::Result<()>>` that `poll_read` would.
type ReadFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &mut ReadBuf) -> Poll<io::Result<()>>>;
/// Boxed closure that can stand in for one `poll_write` call on the wrapped `TcpStream`.
///
/// It receives the pinned stream, the task context, and the bytes to write, and returns the
/// same `Poll<io::Result<usize>>` that `poll_write` would.
type WriteFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &[u8]) -> Poll<io::Result<usize>>>;
/// Boxed closure that can stand in for one `poll_shutdown` call on the wrapped `TcpStream`.
///
/// It receives the pinned stream and the task context, and returns the same
/// `Poll<io::Result<()>>` that `poll_shutdown` would.
type ShutdownFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context) -> Poll<io::Result<()>>>;
/// Pending overrides, with one optional slot per intercepted stream operation.
///
/// Each slot is emptied (via `Option::take`) by the first matching poll on `TestStream`,
/// so an installed closure runs at most once. `Default` leaves every slot empty.
#[derive(Default)]
struct OverrideMethods {
/// Closure to run in place of the next `poll_read`, if any.
next_read: Option<ReadFn>,
/// Closure to run in place of the next `poll_write`, if any.
next_write: Option<WriteFn>,
/// Closure to run in place of the next `poll_shutdown`, if any.
next_shutdown: Option<ShutdownFn>,
}
/// Mutex-guarded set of pending overrides for a `TestStream`.
///
/// The mutex lets overrides be installed through a shared reference (for example via the
/// `Arc<Overrides>` handed out by `TestStream::overrides`). `Default` starts with no
/// overrides installed.
#[derive(Default)]
pub struct Overrides(Mutex<OverrideMethods>);
impl Overrides {
    /// Installs (or, with `None`, clears) the closure that replaces the next `poll_read`.
    ///
    /// If the internal mutex is poisoned the request is silently ignored.
    pub fn next_read(&self, input: Option<ReadFn>) {
        if let Ok(mut slots) = self.0.lock() {
            slots.next_read = input;
        }
    }

    /// Installs (or, with `None`, clears) the closure that replaces the next `poll_write`.
    ///
    /// If the internal mutex is poisoned the request is silently ignored.
    pub fn next_write(&self, input: Option<WriteFn>) {
        if let Ok(mut slots) = self.0.lock() {
            slots.next_write = input;
        }
    }

    /// Installs (or, with `None`, clears) the closure that replaces the next `poll_shutdown`.
    ///
    /// If the internal mutex is poisoned the request is silently ignored.
    pub fn next_shutdown(&self, input: Option<ShutdownFn>) {
        if let Ok(mut slots) = self.0.lock() {
            slots.next_shutdown = input;
        }
    }

    /// Reports whether every override slot is currently empty.
    ///
    /// Returns `false` when any override is still pending, and also when the internal
    /// mutex is poisoned and the slots therefore cannot be inspected.
    pub fn is_consumed(&self) -> bool {
        let slots = match self.0.lock() {
            Ok(slots) => slots,
            Err(_) => return false,
        };
        let any_pending = slots.next_read.is_some()
            || slots.next_write.is_some()
            || slots.next_shutdown.is_some();
        !any_pending
    }
}
// SAFETY: NOTE(review) — `Overrides` holds only a `Mutex<OverrideMethods>`, but the boxed
// closures stored in it (`ReadFn`, `WriteFn`, `ShutdownFn`) carry no `Send` bound, so the
// compiler does not derive `Send`/`Sync` automatically. These impls assert it by hand, which
// is only sound if every closure installed is itself safe to move to and drop on another
// thread. Nothing in the types enforces that — confirm against the call sites.
unsafe impl Send for Overrides {}
unsafe impl Sync for Overrides {}
/// A `TcpStream` wrapper whose next read, write, or shutdown poll can be replaced by a closure.
///
/// When no override is pending for an operation, the call is forwarded to the inner stream.
pub struct TestStream {
/// The real stream that performs I/O when no override is pending.
stream: TcpStream,
/// Override slots, shared through an `Arc` so handles returned by `overrides()` see the same state.
overrides: Arc<Overrides>,
}
impl TestStream {
pub fn new(stream: TcpStream) -> Self {
let overrides = Arc::new(Overrides::default());
Self { stream, overrides }
}
pub fn overrides(&self) -> Arc<Overrides> {
self.overrides.clone()
}
}
impl AsyncRead for TestStream {
    /// Polls a read, running the pending read override (if one is installed) instead of
    /// the inner stream's `poll_read`.
    ///
    /// The override is removed from its slot before it is invoked, so it applies to this
    /// call only. A poisoned lock is treated the same as "no override".
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // Take the override inside the closure so the guard is released before any I/O runs.
        let pending = this
            .overrides
            .0
            .lock()
            .ok()
            .and_then(|mut slots| slots.next_read.take());
        let inner = Pin::new(&mut this.stream);
        match pending {
            Some(read_override) => read_override(inner, ctx, buf),
            None => inner.poll_read(ctx, buf),
        }
    }
}
impl AsyncWrite for TestStream {
    /// Polls a write, running the pending write override (if one is installed) instead of
    /// the inner stream's `poll_write`.
    ///
    /// The override is removed from its slot before it is invoked, so it applies to this
    /// call only. A poisoned lock is treated the same as "no override".
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // Take the override inside the closure so the guard is released before any I/O runs.
        let pending = this
            .overrides
            .0
            .lock()
            .ok()
            .and_then(|mut slots| slots.next_write.take());
        let inner = Pin::new(&mut this.stream);
        match pending {
            Some(write_override) => write_override(inner, ctx, buf),
            None => inner.poll_write(ctx, buf),
        }
    }

    /// Flushes the inner stream; flushing has no override slot and is always forwarded.
    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_flush(ctx)
    }

    /// Polls a shutdown, running the pending shutdown override (if one is installed)
    /// instead of the inner stream's `poll_shutdown`.
    ///
    /// The override is removed from its slot before it is invoked, so it applies to this
    /// call only. A poisoned lock is treated the same as "no override".
    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let pending = this
            .overrides
            .0
            .lock()
            .ok()
            .and_then(|mut slots| slots.next_shutdown.take());
        let inner = Pin::new(&mut this.stream);
        match pending {
            Some(shutdown_override) => shutdown_override(inner, ctx),
            None => inner.poll_shutdown(ctx),
        }
    }
}
|