File: with-metadata.rs

package info (click to toggle)
rust-async-task 4.7.1-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 368 kB
  • sloc: makefile: 2; sh: 1
file content (145 lines) | stat: -rw-r--r-- 3,836 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
//! A single threaded executor that uses shortest-job-first scheduling.

use std::cell::RefCell;
use std::collections::BinaryHeap;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::thread;
use std::time::{Duration, Instant};
use std::{cell::Cell, future::Future};

use async_task::{Builder, Runnable, Task};
use pin_project_lite::pin_project;
use smol::{channel, future};

/// Run-queue entry that orders a [`Runnable`] by the poll-duration estimate
/// stored in its [`DurationMetadata`].
///
/// The comparison is reversed because [`BinaryHeap`] pops its *greatest*
/// element: the entry with the smallest duration compares greatest, so the
/// queue hands out the shortest job first.
struct ByDuration(Runnable<DurationMetadata>);

impl ByDuration {
    /// Reads the current duration estimate out of the runnable's metadata.
    fn duration(&self) -> Duration {
        let ByDuration(runnable) = self;
        runnable.metadata().inner.get()
    }
}

impl PartialEq for ByDuration {
    fn eq(&self, other: &Self) -> bool {
        // Consistent with `Ord`: two entries are equal exactly when their
        // duration estimates are equal.
        other.duration() == self.duration()
    }
}

impl Eq for ByDuration {}

impl PartialOrd for ByDuration {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for ByDuration {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped (rather than calling `.reverse()`) so that a
        // shorter duration compares as greater and is popped first.
        other.duration().cmp(&self.duration())
    }
}

pin_project! {
    /// Future adapter that times every `poll` of the wrapped future and folds
    /// the elapsed wall-clock time into a shared duration estimate.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    struct MeasureRuntime<'a, F> {
        // The wrapped future; structurally pinned.
        #[pin]
        f: F,
        // Estimate updated after each poll (lives in the task's metadata).
        duration: &'a Cell<Duration>
    }
}

impl<'a, F: Future> Future for MeasureRuntime<'a, F> {
    type Output = F::Output;

    /// Polls the inner future once, then replaces the estimate with the
    /// average of the previous estimate and the time this poll took.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let estimate: &Cell<Duration> = projected.duration;

        let started_at = Instant::now();
        let outcome = projected.f.poll(cx);
        let elapsed = started_at.elapsed();

        let previous = estimate.get();
        estimate.set(previous / 2 + elapsed / 2);
        outcome
    }
}

/// Per-task metadata attached to every task created by [`spawn`].
///
/// The default value (a zero estimate) matches what a freshly spawned task
/// starts with.
#[derive(Debug, Default)]
pub struct DurationMetadata {
    /// Running estimate of how long one poll of the task's future takes. It is
    /// rewritten after every poll and read when ordering the run queue.
    inner: Cell<Duration>,
}

thread_local! {
    /// This thread's run queue of scheduled tasks. Because `ByDuration`
    /// reverses its ordering, popping yields the task with the shortest
    /// duration estimate first.
    static QUEUE: RefCell<BinaryHeap<ByDuration>> = RefCell::new(BinaryHeap::default());
}

/// Returns the constructor that `Builder::spawn_unchecked` calls with the
/// task's metadata. The constructor wraps `future` in a [`MeasureRuntime`]
/// whose estimate is the metadata's duration cell, so every poll is timed
/// into the task's own metadata.
fn make_future_fn<'a, F>(
    future: F,
) -> impl (FnOnce(&'a DurationMetadata) -> MeasureRuntime<'a, F>) {
    move |metadata| {
        let estimate = &metadata.inner;
        MeasureRuntime {
            f: future,
            duration: estimate,
        }
    }
}

/// Identity function that only exists to make the compiler prove the schedule
/// function is `Send + Sync + 'static`.
///
/// The `unsafe` spawn does not check these bounds on its own; requiring them
/// here means the schedule function (and hence the task's wakers) needs no
/// extra thread-affinity reasoning at the `unsafe` call site.
fn ensure_safe_schedule<F>(schedule: F) -> F
where
    F: Send + Sync + 'static,
{
    schedule
}

/// Spawns a future on the executor.
pub fn spawn<F, T>(future: F) -> Task<T, DurationMetadata>
where
    F: Future<Output = T> + 'static,
    T: 'static,
{
    let spawn_thread_id = thread::current().id();
    // Create a task that is scheduled by pushing it into the queue.
    let schedule = ensure_safe_schedule(move |runnable| {
        if thread::current().id() != spawn_thread_id {
            panic!("Task would be run on a different thread than spawned on.");
        }
        QUEUE.with(move |queue| queue.borrow_mut().push(ByDuration(runnable)));
    });
    let future_fn = make_future_fn(future);
    let (runnable, task) = unsafe {
        Builder::new()
            .metadata(DurationMetadata {
                inner: Cell::new(Duration::default()),
            })
            .spawn_unchecked(future_fn, schedule)
    };

    // Schedule the task by pushing it into the queue.
    runnable.schedule();

    task
}

pub fn block_on<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    let task = spawn(future);
    while !task.is_finished() {
        let Some(runnable) = QUEUE.with(|queue| queue.borrow_mut().pop()) else {
            thread::yield_now();
            continue;
        };
        runnable.0.run();
    }
}

fn main() {
    // Two tasks hand off through a one-slot channel: `tail` waits for a message
    // before printing, and `head` prints right after sending it.
    block_on(async {
        let (tx, rx) = channel::bounded(1);
        let tail = spawn(async move {
            rx.recv().await.unwrap();
            println!("world.");
        });
        let head = spawn(async move {
            tx.send(()).await.unwrap();
            print!("Hello, ");
        });
        future::zip(head, tail).await;
    });
}