File: stateful_operator.cu

package info (click to toggle)
libthrust 1.17.2-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,900 kB
  • sloc: ansic: 29,519; cpp: 23,989; python: 1,421; sh: 811; perl: 460; makefile: 112
file content (61 lines) | stat: -rw-r--r-- 1,614 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <thrust/detail/config.h>

#if THRUST_CPP_DIALECT >= 2014

#include <async/test_policy_overloads.h>

#include <async/inclusive_scan/mixin.h>

namespace
{

// Custom binary operator for scan that carries state: every pairwise
// combination adds the functor's `offset` on top of the plain sum.
//
// The operation is mathematically associative:
//   (a op b) op c == a op (b op c) == a + b + c + 2 * offset
// For floating-point T this only holds up to rounding.
template <typename T>
struct stateful_operator
{
  // Value added to every pairwise combination; travels with the functor object.
  T offset;

  // const-qualified: evaluation never modifies `offset`, and a const call
  // operator lets the functor be invoked through a const reference or a
  // const copy (a non-const operator() rejects those call sites).
  __host__ __device__ T operator()(T v1, T v2) const { return v1 + v2 + offset; }
};

// Mixin that supplies the trailing ("postfix") arguments for the scan call:
// a single overload whose only postfix argument is a stateful_operator
// configured with an offset of 2.
template <typename value_type>
struct use_stateful_operator
{
  // Outer tuple: one entry per overload. Inner tuple: that overload's
  // postfix arguments (here just the binary operator).
  using postfix_args_type = std::tuple<std::tuple<stateful_operator<value_type>>>;

  static postfix_args_type generate_postfix_args()
  {
    stateful_operator<value_type> bin_op{value_type{2}};
    std::tuple<stateful_operator<value_type>> only_overload{bin_op};
    return postfix_args_type{only_overload};
  }
};

// Test invoker assembled from mixins for the "scan with stateful operator"
// case. Judging by the mixin names (confirm in the mixin headers), it uses
// device_vector input and output storage, the stateful-operator postfix
// arguments defined above, a host synchronous reference inclusive_scan, a
// simple async invocation, and an almost-equal (for floating point) output
// comparison. Base order is preserved deliberately.
template <typename value_type>
struct invoker
    : testing::async::mixin::input::device_vector<value_type>,
      testing::async::mixin::output::device_vector<value_type>,
      use_stateful_operator<value_type>,
      testing::async::inclusive_scan::mixin::invoke_reference::host_synchronous<value_type>,
      testing::async::inclusive_scan::mixin::invoke_async::simple,
      testing::async::mixin::compare_outputs::assert_almost_equal_if_fp_quiet
{
  // Human-readable name of this test case.
  static std::string description()
  {
    const std::string text{"scan with stateful operator"};
    return text;
  }
};

} // namespace

// Per-element-type test functor: runs invoker<T> through every execution
// policy overload provided by test_policy_overloads for `num_values` inputs.
template <typename T>
struct test_stateful_operator
{
  void operator()(std::size_t num_values) const
  {
    using policy_overloads = testing::async::test_policy_overloads<invoker<T>>;
    policy_overloads::run(num_values);
  }
};
DECLARE_GENERIC_SIZED_UNITTEST_WITH_TYPES(test_stateful_operator, NumericTypes);

#endif // C++14