File: halide_thread_pool.h

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (161 lines) | stat: -rw-r--r-- 4,870 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#ifndef HALIDE_THREAD_POOL_H
#define HALIDE_THREAD_POOL_H

#include <cassert>
#include <condition_variable>
#include <cstdlib>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#else
#include <unistd.h>
#endif

/** \file
 * Define a simple thread pool utility that is modeled on the API of
 * std::async(); since implementation details of std::async
 * can vary considerably, with no control over thread spawning, this class
 * allows us to use the same model but with precise control over thread usage.
 *
 * A ThreadPool is created with a specific number of threads, which will never
 * vary over the life of the ThreadPool. (If created without a specific number
 * of threads, it will attempt to use threads == number-of-cores.)
 *
 * Each async request will go into a queue, and will be serviced by the next
 * available thread from the pool.
 *
 * The ThreadPool's dtor will block until all currently-executing tasks
 * finish (but won't schedule any more).
 *
 * Note that this is a fairly simpleminded ThreadPool, meant for tasks
 * that are fairly coarse (e.g. different tasks in a test); it is specifically
 * *not* intended to be the underlying implementation for Halide runtime threads
 */
namespace Halide {
namespace Tools {

template<typename T>
class ThreadPool {
    // One queued work item: the callable to execute, plus the promise
    // through which its result reaches the caller's std::future.
    struct Job {
        std::function<T()> func;
        std::promise<T> result;

        // Runs func() with the pool lock temporarily released, then fulfills
        // `result`. Must be entered with `unique_lock` held; the lock is held
        // again when it returns. (Defined out-of-line; specialized for void.)
        void run_unlocked(std::unique_lock<std::mutex> &unique_lock);
    };

    // all fields are protected by this mutex.
    std::mutex mutex;

    // Queue of Jobs.
    std::queue<Job> jobs;

    // Broadcast whenever items are added to the Job queue.
    std::condition_variable wakeup_threads;

    // Keep track of threads so they can be joined at shutdown
    std::vector<std::thread> threads;

    // True if the pool is shutting down.
    bool shutting_down{false};

    // Body of each pool thread: pull jobs off the queue and run them,
    // one at a time, until the pool shuts down.
    void worker_thread() {
        std::unique_lock<std::mutex> unique_lock(mutex);
        while (!shutting_down) {
            if (jobs.empty()) {
                // There are no jobs pending. Wait until more jobs are enqueued.
                // (Spurious wakeups are harmless: we simply re-check the queue.)
                wakeup_threads.wait(unique_lock);
            } else {
                // Grab the next job.
                Job cur_job = std::move(jobs.front());
                jobs.pop();
                cur_job.run_unlocked(unique_lock);
            }
        }
    }

public:
    // Best-effort count of logical processors currently online.
    // Always returns a value >= 1 (even if the underlying query fails),
    // so the result is always a valid thread count for the constructor.
    static size_t num_processors_online() {
#ifdef _WIN32
        const char *num_cores = getenv("NUMBER_OF_PROCESSORS");
        if (num_cores != nullptr) {
            const int n = atoi(num_cores);
            if (n > 0) {
                return static_cast<size_t>(n);
            }
        }
        // Fall back to a plausible default if the env var is absent or bogus.
        return 8;
#else
        // sysconf() returns -1 on failure; clamp so we never report zero
        // (which would trip the constructor's assert, or -- with NDEBUG --
        // silently create a pool that can never run anything).
        const long n = sysconf(_SC_NPROCESSORS_ONLN);
        return n > 0 ? static_cast<size_t>(n) : 1;
#endif
    }

    // Default to number of available cores if not specified otherwise
    ThreadPool(size_t desired_num_threads = num_processors_online()) {
        // This file doesn't depend on anything else in libHalide, so
        // we'll use assert, not internal_assert.
        assert(desired_num_threads > 0);

        std::lock_guard<std::mutex> lock(mutex);

        // Create all the threads.
        for (size_t i = 0; i < desired_num_threads; ++i) {
            threads.emplace_back([this] { worker_thread(); });
        }
    }

    ~ThreadPool() {
        // Wake everyone up and tell them the party's over and it's time to go home
        {
            std::lock_guard<std::mutex> lock(mutex);
            shutting_down = true;
            wakeup_threads.notify_all();
        }

        // Wait until they leave
        for (auto &t : threads) {
            t.join();
        }
    }

    // Enqueue func(args...) to run on some pool thread, returning a future
    // for its result. Arguments are captured by copy (never by reference),
    // since the call happens on an arbitrary thread at an arbitrary time.
    template<typename Func, typename... Args>
    std::future<T> async(Func func, Args... args) {
        std::lock_guard<std::mutex> lock(mutex);

        Job job;
        // Don't use std::forward here: we never want args passed by reference,
        // since they will be accessed from an arbitrary thread.
        //
        // Some versions of GCC won't allow capturing variadic arguments in a lambda;
        //
        //     job.func = [func, args...]() -> T { return func(args...); };  // Nope, sorry
        //
        // fortunately, we can use std::bind() to accomplish the same thing.
        job.func = std::bind(func, args...);
        jobs.emplace(std::move(job));
        std::future<T> result = jobs.back().result.get_future();

        // Wake up our threads.
        wakeup_threads.notify_all();

        return result;
    }
};

// Execute the job's callable with the pool lock released, then publish the
// outcome through the promise. If func() throws, the exception is forwarded
// to the caller's future (rethrown from future::get()) instead of escaping
// the worker thread -- an escaped exception would call std::terminate and
// leave the promise broken. The lock is re-held on all return paths, as
// worker_thread() requires.
template<typename T>
inline void ThreadPool<T>::Job::run_unlocked(std::unique_lock<std::mutex> &unique_lock) {
    unique_lock.unlock();
    try {
        T r = func();
        unique_lock.lock();
        result.set_value(std::move(r));
    } catch (...) {
        unique_lock.lock();
        result.set_exception(std::current_exception());
    }
}

// void specialization: there is no value to move into the promise, so just
// run the callable and call set_value(). As with the generic version, an
// exception thrown by func() is forwarded to the caller's future rather than
// escaping the worker thread (which would call std::terminate). The lock is
// re-held on all return paths, as worker_thread() requires.
template<>
inline void ThreadPool<void>::Job::run_unlocked(std::unique_lock<std::mutex> &unique_lock) {
    unique_lock.unlock();
    try {
        func();
        unique_lock.lock();
        result.set_value();
    } catch (...) {
        unique_lock.lock();
        result.set_exception(std::current_exception());
    }
}

}  // namespace Tools
}  // namespace Halide

#endif  // HALIDE_THREAD_POOL_H