1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
use {
super::{num_clients, num_concurrent_clients, TestResult, WrapErrExt},
color_eyre::eyre::{bail, Context},
std::{future::Future, sync::Arc},
tokio::{
sync::{
oneshot::{channel, Sender},
Semaphore,
},
task, try_join,
},
};
pub fn test_wrapper(f: impl Future<Output = TestResult> + Send + 'static) -> TestResult {
super::test_wrapper(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.opname("Tokio runtime spawn")?;
rt.block_on(f)
})
}
/// Waits for the leader closure to reach a point where it sends a message for the follower closure,
/// then runs the follower. Captures Eyre errors on both sides and panics if any occur, reporting
/// which side produced the error.
pub async fn drive_pair<T, Ld, Ldf, Fl, Flf>(
leader: Ld,
leader_name: &str,
follower: Fl,
follower_name: &str,
) -> TestResult
where
Ld: FnOnce(Sender<T>) -> Ldf,
Ldf: Future<Output = TestResult>,
Fl: FnOnce(T) -> Flf,
Flf: Future<Output = TestResult>,
{
let (sender, receiver) = channel();
let leading_task = async {
leader(sender).await.with_context(|| format!("{leader_name} exited early with error"))
};
let following_task = async {
let msg = receiver.await?;
follower(msg).await.with_context(|| format!("{follower_name} exited early with error"))
};
try_join!(leading_task, following_task).map(|((), ())| ())
}
pub async fn drive_server_and_multiple_clients<T, Srv, Srvf, Clt, Cltf>(
server: Srv,
client: Clt,
) -> TestResult
where
T: Send + Sync + ?Sized + 'static,
Srv: FnOnce(Sender<Arc<T>>, u32) -> Srvf + Send + 'static,
Srvf: Future<Output = TestResult>,
Clt: Fn(Arc<T>) -> Cltf + Send + Sync + 'static,
Cltf: Future<Output = TestResult> + Send,
{
let num_clients = num_clients();
let client_wrapper = |msg| async move {
let client = Arc::new(client);
let choke = Arc::new(Semaphore::new(num_concurrent_clients().try_into().unwrap()));
let mut client_tasks = Vec::with_capacity(num_clients.try_into().unwrap());
for _ in 0..num_clients {
let permit = Arc::clone(&choke).acquire_owned().await.unwrap();
let clientc = Arc::clone(&client);
let msgc = Arc::clone(&msg);
let jhndl = task::spawn(async move {
let _prm = permit; // Send to other thread to drop when client finishes
clientc(msgc).await
});
client_tasks.push(jhndl);
}
for client in client_tasks {
let Ok(rslt) = client.await else {
bail!("client task panicked");
};
rslt?; // Early-return the first error; context not necessary as drive_pair does it
}
Ok::<(), color_eyre::eyre::Error>(())
};
let server_wrapper = move |sender: Sender<Arc<T>>| server(sender, num_clients);
drive_pair(server_wrapper, "server", client_wrapper, "client").await
}
|