File: TestStdAlgorithmsForEach.cpp

package info (click to toggle)
kokkos 4.7.01-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 16,636 kB
  • sloc: cpp: 223,676; sh: 2,446; makefile: 2,437; python: 91; fortran: 4; ansic: 2
file content (135 lines) | stat: -rw-r--r-- 4,755 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
//@HEADER
// ************************************************************************
//
//                        Kokkos v. 4.0
//       Copyright (2022) National Technology & Engineering
//               Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <TestStdAlgorithmsCommon.hpp>
#include <algorithm>

namespace Test {
namespace stdalgos {
namespace ForEach {

namespace KE = Kokkos::Experimental;

// Checks every KE::for_each entry point against std::for_each.
//
// Each scenario applies the same callable to `view` through KE::for_each and
// to a host reference view `expected` through std::for_each, then compares the
// two element-wise with compare_views. Covered combinations:
//   - view overload (one call also passes a label) and iterator overload,
//   - mutating callables (take a non-const ref) and non-mutating callables
//     (take a const ref),
//   - functor objects and KOKKOS_LAMBDA closures.
// Non-mutating callables are always replayed on the host reference through
// const iterators, so both sides exercise the same constness.
//
// NOTE(review): the initial compare_views() assumes `view` starts out equal to
// the freshly allocated `expected` (i.e. zero-filled) -- confirm against
// create_view.
template <class ViewType>
void test_for_each(const ViewType view) {
  using value_t           = typename ViewType::value_type;
  using view_host_space_t = Kokkos::View<value_t*, Kokkos::HostSpace>;

  // Host-side reference that mirrors every operation applied to `view`.
  view_host_space_t expected("for_each_expected", view.extent(0));
  compare_views(expected, view);

  const auto mod_functor = IncrementElementWiseFunctor<value_t>();

  // pass view (labeled overload), functor takes non-const ref
  KE::for_each("label", exespace(), view, mod_functor);
  std::for_each(KE::begin(expected), KE::end(expected), mod_functor);
  compare_views(expected, view);

  // pass iterators, functor takes non-const ref
  KE::for_each(exespace(), KE::begin(view), KE::end(view), mod_functor);
  std::for_each(KE::begin(expected), KE::end(expected), mod_functor);
  compare_views(expected, view);

  const auto non_mod_functor = NoOpNonMutableFunctor<value_t>();

  // pass view, functor takes const ref
  KE::for_each(exespace(), view, non_mod_functor);
  std::for_each(KE::cbegin(expected), KE::cend(expected), non_mod_functor);
  compare_views(expected, view);

  // pass const iterators, functor takes const ref
  KE::for_each(exespace(), KE::cbegin(view), KE::cend(view), non_mod_functor);
  std::for_each(KE::cbegin(expected), KE::cend(expected), non_mod_functor);
  compare_views(expected, view);

  const auto mod_lambda = KOKKOS_LAMBDA(value_t & i) { ++i; };

  // pass view, lambda takes non-const ref
  KE::for_each(exespace(), view, mod_lambda);
  std::for_each(KE::begin(expected), KE::end(expected), mod_lambda);
  compare_views(expected, view);

  // pass iterators, lambda takes non-const ref
  KE::for_each(exespace(), KE::begin(view), KE::end(view), mod_lambda);
  std::for_each(KE::begin(expected), KE::end(expected), mod_lambda);
  compare_views(expected, view);

  const auto non_mod_lambda = KOKKOS_LAMBDA(const value_t& i) { (void)i; };

  // pass view, lambda takes const ref
  KE::for_each(exespace(), view, non_mod_lambda);
  std::for_each(KE::cbegin(expected), KE::cend(expected), non_mod_lambda);
  compare_views(expected, view);

  // pass const iterators, lambda takes const ref
  KE::for_each(exespace(), KE::cbegin(view), KE::cend(view), non_mod_lambda);
  std::for_each(KE::cbegin(expected), KE::cend(expected), non_mod_lambda);
  compare_views(expected, view);
}

// Exercises KE::for_each_n through the iterator and view overloads (one call
// also passes a label), checking both the returned iterator (`first + count`)
// and the contents of `view` afterwards.
//
// Unlike test_for_each, results are not replayed through std::for_each_n on a
// host reference; instead the contents are checked against known values with
// verify_values.
// NOTE(review): Kokkos 4 already requires C++17, so a std::for_each_n host
// replay may now be possible -- confirm standard-library support on all
// supported toolchains before switching.
// NOTE(review): assumes `view` starts out zero-filled and that verify_values
// checks every element against the given value -- confirm against the
// helpers in TestStdAlgorithmsCommon.hpp.
template <class ViewType>
void test_for_each_n(const ViewType view) {
  using value_t       = typename ViewType::value_type;
  const std::size_t n = view.extent(0);

  const auto non_mod_functor = NoOpNonMutableFunctor<value_t>();

  // pass const iterators, functor takes const ref
  ASSERT_EQ(KE::cbegin(view) + n,
            KE::for_each_n(exespace(), KE::cbegin(view), n, non_mod_functor));
  verify_values(value_t{0}, view);

  // pass view, functor takes const ref
  ASSERT_EQ(KE::begin(view) + n,
            KE::for_each_n(exespace(), view, n, non_mod_functor));
  verify_values(value_t{0}, view);

  // A count smaller than the extent must stop early: the returned iterator is
  // first + count. `half` is 0 for extents 0 and 1, which also covers the
  // zero-count path. The functor is non-mutating, so contents stay unchanged.
  const std::size_t half = n / 2;
  ASSERT_EQ(
      KE::cbegin(view) + half,
      KE::for_each_n(exespace(), KE::cbegin(view), half, non_mod_functor));
  verify_values(value_t{0}, view);

  ASSERT_EQ(KE::begin(view) + half,
            KE::for_each_n(exespace(), view, half, non_mod_functor));
  verify_values(value_t{0}, view);

  // pass iterators, functor takes non-const ref
  const auto mod_functor = IncrementElementWiseFunctor<value_t>();
  ASSERT_EQ(KE::begin(view) + n,
            KE::for_each_n(exespace(), KE::begin(view), n, mod_functor));
  verify_values(value_t{1}, view);

  // pass view (labeled overload), functor takes non-const ref
  ASSERT_EQ(KE::begin(view) + n,
            KE::for_each_n("label", exespace(), view, n, mod_functor));
  verify_values(value_t{2}, view);
}

// For every entry of default_scenarios, builds a fresh view (via create_view,
// sized by scenario.second -- presumably the extent) for each algorithm under
// test and hands it to the matching check. Each view is a temporary that is
// released as soon as its check returns.
template <class Tag, class ValueType>
void run_all_scenarios() {
  for (const auto& scenario : default_scenarios) {
    const auto& extent = scenario.second;
    test_for_each(create_view<ValueType>(Tag{}, extent, "for_each"));
    test_for_each_n(create_view<ValueType>(Tag{}, extent, "for_each_n"));
  }
}

// Runs the full for_each / for_each_n scenario set once per view tag, pairing
// each tag with a different element type.
TEST(std_algorithms_for_each_test, test) {
  run_all_scenarios<DynamicTag, double>();
  run_all_scenarios<StridedTwoTag, int>();
  run_all_scenarios<StridedThreeTag, unsigned int>();
}

}  // namespace ForEach
}  // namespace stdalgos
}  // namespace Test