File: threadpool.h

package info (click to toggle)
bowtie2 2.5.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 27,556 kB
  • sloc: cpp: 64,301; perl: 7,232; sh: 1,131; python: 993; makefile: 606; ansic: 122
file content (134 lines) | stat: -rw-r--r-- 3,789 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#ifndef _THREAD_POOL_H_
#define _THREAD_POOL_H_

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <map>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

/// A minimal mutex-protected FIFO queue.
///
/// Every public operation holds the internal mutex for its whole duration, so
/// individual calls may be made concurrently from several threads. There is no
/// blocking pop: consumers poll with try_pop().
template<typename T>
class threadsafe_queue {
private:
	// Guards data_queue. `mutable` so const observers such as size() can lock it.
	mutable std::mutex mut;
	std::queue<T> data_queue;

public:
	threadsafe_queue() {}

	/// Append a value to the back of the queue, taking ownership of it.
	/// @param new_value  rvalue that is moved into the queue.
	void push(T &&new_value) {
		std::lock_guard<std::mutex> lk(mut);
		// A named rvalue reference is itself an lvalue: without std::move this
		// would copy (and would not compile at all for move-only T).
		data_queue.emplace(std::move(new_value));
	}

	/// Move the front element into `value` if the queue is non-empty.
	/// @param value  receives the popped element; untouched when the queue is empty.
	/// @return true if an element was popped, false if the queue was empty.
	bool try_pop(T& value) {
		std::lock_guard<std::mutex> lk(mut);
		if (data_queue.empty())
			return false;
		value = std::move(data_queue.front());
		data_queue.pop();
		return true;
	}

	/// Number of queued elements. This is a snapshot: other threads may push or
	/// pop as soon as the lock is released.
	size_t size() const {
		std::lock_guard<std::mutex> lk(mut);
		return data_queue.size();
	}
};

/// Fixed-size pool of worker threads that execute submitted callables.
///
/// Workers are started by the constructor and joined by the destructor. Queued
/// tasks are taken in FIFO order by whichever idle worker pops them first.
/// Tasks still queued when the pool is destroyed may never run; their futures
/// then report std::future_errc::broken_promise.
class thread_pool
{
	// Set by the destructor (or a failed constructor) to make workers exit.
	std::atomic_bool done;
	// Number of workers requested; also used by parallel_for to size blocks.
	int nthreads;
	// Maps each worker's std::thread::id to its index in [0, nthreads).
	// Written only during construction; only read afterwards.
	std::map<std::thread::id, int> thread_id;
	threadsafe_queue<std::function<void()>> work_queue;
	std::vector<std::thread> threads;
	// cv/m only put idle workers to sleep; the queue has its own lock.
	std::condition_variable cv;
	std::mutex m;

	/// Worker loop: run queued tasks until `done` is set, sleeping on `cv`
	/// while the queue is empty.
	void worker_thread() {
		while (!done) {
			std::function<void()> task;
			if (work_queue.try_pop(task)) {
				task();
			} else {
				std::unique_lock<std::mutex> lock(m);
				// The predicate is evaluated under `m`, and every notifier takes
				// `m` before notifying, so a wakeup cannot be lost.
				cv.wait(lock, [&] { return work_queue.size() != 0 || done; });
			}
		}
	}

	/// Tell every worker to exit and join all threads started so far.
	void stop_and_join() {
		done = true;
		{
			std::lock_guard<std::mutex> lock(m);
			cv.notify_all();
		}
		for (std::thread &thread : threads) {
			if (thread.joinable())
				thread.join();
		}
	}
public:
	/// Start `nthr` worker threads.
	/// @param nthr  number of workers; values <= 0 start no workers, in which
	///              case submitted tasks are never executed.
	/// @throws whatever thread creation or allocation throws; any workers that
	///         were already started are stopped and joined before rethrowing.
	thread_pool(int nthr):
		done(false), nthreads(nthr)
	{
		try {
			if (nthreads > 0)
				threads.reserve(static_cast<size_t>(nthreads));
			for (int i = 0; i < nthreads; ++i) {
				// Construct in place into reserved storage: if the std::thread
				// constructor throws, no joinable temporary is left behind.
				threads.emplace_back(&thread_pool::worker_thread, this);
				thread_id[threads.back().get_id()] = i;
			}
		} catch (...) {
			// The destructor does not run for a partially constructed object,
			// and destroying a joinable std::thread calls std::terminate, so
			// shut down whatever was started before propagating the error.
			stop_and_join();
			throw;
		}
	}

	/// Stop and join all workers (see the class comment about queued tasks).
	~thread_pool() {
		stop_and_join();
	}

	/// Queue `f(args...)` for execution on a worker thread.
	/// The callable and arguments are captured via std::bind, i.e. stored as
	/// decayed copies (moved from rvalues).
	/// @return a future that yields f's result, or rethrows the exception f threw.
	template<typename Function, typename... Args>
	std::future<typename std::result_of<Function(Args...)>::type>
	submit(Function &&f, Args&&... args) {
		using result_type = typename std::result_of<Function(Args...)>::type;
		// packaged_task is move-only but std::function needs a copyable target,
		// hence the shared_ptr wrapper.
		auto task = std::make_shared<std::packaged_task<result_type()>>(
			std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
		std::future<result_type> res(task->get_future());
		work_queue.push([task] { (*task)(); });
		std::unique_lock<std::mutex> lock(m);
		cv.notify_one();
		return res;
	}

	/// Number of worker threads requested at construction.
	int size() const {
		return nthreads;
	}

	/// Index in [0, size()) of the worker whose std::thread::id is `id`.
	/// Returns 0 for ids that do not belong to one of this pool's workers.
	int thread_id_to_int(std::thread::id id) const {
		// find() rather than operator[]: operator[] inserts on a miss, which is a
		// data race when several threads query concurrently.
		auto it = thread_id.find(id);
		return it == thread_id.end() ? 0 : it->second;
	}

	/// Split [start, end) into contiguous blocks and run
	/// `f(block_start, block_end, stride)` for each block on the pool, blocking
	/// until every block has finished. `f` must return void.
	///
	/// Blocks are (end - start) / size() elements long (the whole range if that
	/// is zero); a trailing shorter block picks up any remainder. `stride` is
	/// passed through unchanged. If the pool has no workers, f(start, end, stride)
	/// is called on the calling thread instead.
	/// @throws the exception of the first failing block (in block order), but only
	///         after all blocks have completed.
	template<typename T, typename Function>
	void parallel_for(T start, T end, T stride, Function &&f) {
		if (nthreads <= 0) {
			// Nobody would ever run queued blocks, and range / nthreads below
			// would divide by zero.
			if (start < end)
				f(start, end, stride);
			return;
		}
		T range = end - start;
		T block_size = range / (nthreads);
		T block_start = start;
		T block_end = block_start + block_size;
		if (block_size == 0)
			block_end = end;
		std::vector<std::future<void>> res;
		while (block_start < end) {
			res.emplace_back(submit(f, block_start, block_end, stride));
			block_start = block_end;
			block_end = block_end + block_size;
			if (block_end >= end)
				block_end = end;
		}
		// Wait for every block before any get(): otherwise an exception from an
		// early block would return to the caller while later blocks may still be
		// using state that `f` references.
		for (size_t i = 0; i < res.size(); i++)
			res[i].wait();
		for (size_t i = 0; i < res.size(); i++)
			res[i].get();
	}
};

#endif