1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
// -*- C++ -*-
// -*-===----------------------------------------------------------------------===//
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//
#ifndef _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
#define _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
#include "parallel_invoke.h"
namespace __pstl
{
namespace __omp_backend
{
// Chooses where to split a run of __m tiles in the scan tree.
// Returns the largest power of two strictly smaller than __m when __m >= 2,
// and 1 otherwise. The left part of every split therefore always holds a
// power-of-two number of tiles.
template <typename _Index>
_Index
__split(_Index __m)
{
    _Index __pow2 = 1;
    for (; 2 * __pow2 < __m; __pow2 *= 2)
    {
        // Keep doubling while doubling once more would still stay below __m.
    }
    return __pow2;
}
// Up-sweep (reduction) phase of the parallel strict scan.
//
// Treats the __m tiles starting at tile index __i as the leaves of an implicit
// binary tree and writes partial reductions into __r[0 .. __m-1]:
//  - a single tile is reduced with __reduce(first_element, element_count);
//  - otherwise the tiles are split into a left part of __split(__m) tiles and
//    a right part holding the remainder. The two parts are handed to
//    __parallel_invoke_body and written into disjoint slices of __r.
//    When both parts have the same size, which only happens for power-of-two
//    __m, the slot __r[__m-1] is widened to cover all __m tiles.
//
// Every tile spans __tilesize elements, except the last tile of this range,
// which spans __lastsize elements.
template <typename _Index, typename _Tp, typename _Rp, typename _Cp>
void
__upsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Rp __reduce, _Cp __combine)
{
    if (__m == 1)
    {
        // Leaf: reduce the elements of the single tile owned by this node.
        __r[0] = __reduce(__i * __tilesize, __lastsize);
        return;
    }

    const _Index __left = __split(__m);
    const _Index __right = __m - __left;

    __omp_backend::__parallel_invoke_body(
        // Left part: only full-size tiles.
        [=] { __omp_backend::__upsweep(__i, __left, __tilesize, __r, __tilesize, __reduce, __combine); },
        // Right part: owns this range's last, possibly shorter, tile.
        [=] {
            __omp_backend::__upsweep(__i + __left, __right, __tilesize, __r + __left, __lastsize, __reduce,
                                     __combine);
        });

    // This read of __r[__left - 1] relies on both callables having completed
    // by the time __parallel_invoke_body returns.
    if (__right == __left)
        __r[__m - 1] = __combine(__r[__left - 1], __r[__m - 1]);
}
// Down-sweep (prefix) phase of the parallel strict scan.
//
// Walks the same implicit tree as __upsweep over the __m tiles starting at
// tile index __i. Every tile receives the prefix value that precedes it, and a
// leaf calls __scan(first_element, element_count, prefix).
//  - The left part inherits this node's prefix (__initial) unchanged.
//  - The right part's prefix is __combine(__initial, __r[__left - 1]). That
//    slot is expected to hold the left part's total, as produced by __upsweep.
//
// Every tile spans __tilesize elements, except the last tile of this range,
// which spans __lastsize elements.
template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
void
__downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
            _Sp __scan)
{
    if (__m == 1)
    {
        // Leaf: scan this tile's elements starting from the inherited prefix.
        __scan(__i * __tilesize, __lastsize, __initial);
        return;
    }

    const _Index __left = __split(__m);

    __omp_backend::__parallel_invoke_body(
        [=] { __omp_backend::__downsweep(__i, __left, __tilesize, __r, __tilesize, __initial, __combine, __scan); },
        // Assumes that __combine never throws.
        // TODO: Consider adding a requirement for user functors to be constant.
        [=, &__combine]
        {
            __omp_backend::__downsweep(__i + __left, __m - __left, __tilesize, __r + __left, __lastsize,
                                       __combine(__initial, __r[__left - 1]), __combine, __scan);
        });
}
// Runs the tiled up-sweep / down-sweep scan over the elements [0, __n).
//
// The elements are cut into tiles, roughly __slack tiles per thread of the
// current OpenMP team. Per-tile reductions are stored in a temporary buffer.
// __apex receives __combine(__initial, total of all elements) between the two
// phases, and the down-sweep then calls __scan once per tile.
// _ExecutionPolicy is not used in the body. It is kept because callers pass
// it explicitly.
template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
          typename _Ap>
void
__parallel_strict_scan_body(_Index __n, _Tp __initial, _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
{
    const _Index __slack = 4;
    const _Index __threads = omp_get_num_threads();

    // ceil(__n / (__slack * __threads)) elements per tile.
    const _Index __tilesize = (__n - 1) / (__slack * __threads) + 1;
    // Index of the last tile; there are __m + 1 tiles in total.
    const _Index __m = (__n - 1) / __tilesize;
    // The last tile holds whatever is left after the full tiles.
    const _Index __lastsize = __n - __m * __tilesize;

    __buffer<_Tp> __buf(__m + 1);
    _Tp* const __r = __buf.get();

    __omp_backend::__upsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __lastsize, __reduce, __combine);

    // After the up-sweep, every 1-bit in the tile count corresponds to one
    // complete power-of-two subtree whose total sits at __r[prefix - 1].
    // Fold those totals right-to-left by repeatedly clearing the lowest set bit.
    std::size_t __bits = __m + 1;
    _Tp __total = __r[__bits - 1];
    for (__bits &= __bits - 1; __bits != 0; __bits &= __bits - 1)
        __total = __combine(__r[__bits - 1], __total);

    __apex(__combine(__initial, __total));

    __omp_backend::__downsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __lastsize, __initial, __combine,
                               __scan);
}
// OpenMP backend entry point for a strict (two-pass) scan over [0, __n).
//
// Contract as used here:
//  - __reduce(first, count) reduces a sub-range,
//  - __combine(a, b) merges two partial results,
//  - __apex(total) receives __combine(__initial, reduction of everything)
//    (just __initial when __n is 0),
//  - __scan(first, count, prefix) performs the final pass over a sub-range.
// Inputs no larger than __default_chunk_size are handled serially on the
// calling thread. Larger inputs use __parallel_strict_scan_body. That function
// reuses the current team when already inside a parallel region, and otherwise
// opens a new team in which a single thread drives the scan.
template <class _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp, typename _Ap>
void
__parallel_strict_scan(__pstl::__internal::__openmp_backend_tag, _ExecutionPolicy&&, _Index __n, _Tp __initial,
                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
{
    if (__n <= __default_chunk_size)
    {
        _Tp __sum = __initial;
        if (__n == 0)
        {
            // Nothing to reduce or scan: only report the initial value.
            __apex(__sum);
            return;
        }
        __sum = __combine(__sum, __reduce(_Index(0), __n));
        __apex(__sum);
        __scan(_Index(0), __n, __initial);
        return;
    }

    if (!omp_in_parallel())
    {
        // No enclosing team: create one. A single thread runs the scan body,
        // and the tasks it spawns may be executed by the rest of the team.
        _PSTL_PRAGMA(omp parallel)
        _PSTL_PRAGMA(omp single nowait)
        {
            __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
                                                                                  __scan, __apex);
        }
        return;
    }

    // Already inside a parallel region: run on the current team.
    __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine, __scan,
                                                                          __apex);
}
} // namespace __omp_backend
} // namespace __pstl
#endif // _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
|