File: task_group_extensions_reduction.cpp

package info (click to toggle)
onetbb 2022.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 19,440 kB
  • sloc: cpp: 129,228; ansic: 9,745; python: 808; xml: 183; objc: 176; makefile: 66; sh: 66; awk: 41; javascript: 37
file content (98 lines) | stat: -rw-r--r-- 3,091 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*
    Copyright (c) 2025 UXL Foundation Contributors

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>

// Ranges with fewer elements than this are summed with a plain loop by
// reduce_task instead of being split into further tasks.
namespace {
constexpr std::size_t serial_threshold{16};
}  // namespace

/*begin_task_group_extensions_reduction_example*/
#define TBB_PREVIEW_TASK_GROUP_EXTENSIONS 1
#include "oneapi/tbb/task_group.h"

struct reduce_task {

    struct join_task {
        void operator()() const {
            result = *left + *right;
        }

        std::size_t& result;
        std::unique_ptr<std::size_t> left;
        std::unique_ptr<std::size_t> right;
    };

    tbb::task_handle operator()() const {
        tbb::task_handle next_task;

        std::size_t size = end - begin;
        if (size < serial_threshold) {
            // Perform serial reduction
            for (std::size_t i = begin; i < end; ++i) {
                result += i;
            }
        } else {
            // The range is too large to process directly
            // Divide it into smaller segments for parallel execution
            std::size_t middle = begin + size / 2;

            auto left_result = std::make_unique<std::size_t>(0);
            auto right_result = std::make_unique<std::size_t>(0);

            
            tbb::task_handle left_leaf = tg.defer(reduce_task{begin, middle, *left_result, tg});
            tbb::task_handle right_leaf = tg.defer(reduce_task{middle, end, *right_result, tg});

            tbb::task_handle join = tg.defer(join_task{result, std::move(left_result), std::move(right_result)});

            tbb::task_group::set_task_order(left_leaf, join);
            tbb::task_group::set_task_order(right_leaf, join);

            tbb::task_group::transfer_this_task_completion_to(join);

            // Save the left leaf for further bypassing
            next_task = std::move(left_leaf);

            tg.run(std::move(right_leaf));
            tg.run(std::move(join));
        }

        return next_task;
    }

    std::size_t begin;
    std::size_t end;
    std::size_t& result;
    tbb::task_group& tg;
};

std::size_t calculate_parallel_sum(std::size_t begin, std::size_t end) {
    tbb::task_group tg;

    std::size_t reduce_result = 0;
    tg.run_and_wait(reduce_task{begin, end, reduce_result, tg});

    return reduce_result;
}
/*end_task_group_extensions_reduction_example*/

// Checks the parallel reduction of [0, N) against the closed-form sum and
// reports a mismatch through both stderr and the process exit status, so
// automated runs can detect an incorrect result.
int main() {
    constexpr std::size_t N = 10000;
    // 0 + 1 + ... + (N - 1)
    constexpr std::size_t serial_sum = N * (N - 1) / 2;
    const std::size_t parallel_sum = calculate_parallel_sum(0, N);

    if (serial_sum != parallel_sum) {
        std::cerr << "Incorrect reduction result" << '\n';
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}