File: TestStdAlgorithmsHelperFunctors.hpp

package info (click to toggle)
kokkos 5.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 15,140 kB
  • sloc: cpp: 225,293; sh: 1,250; python: 78; makefile: 16; fortran: 4; ansic: 2
file content (164 lines) | stat: -rw-r--r-- 4,178 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#ifndef KOKKOS_ALGORITHMS_UNITTESTS_TEST_STD_ALGOS_HELPERS_FUNCTORS_HPP
#define KOKKOS_ALGORITHMS_UNITTESTS_TEST_STD_ALGOS_HELPERS_FUNCTORS_HPP

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.core;
#else
#include <Kokkos_Core.hpp>
#endif
#include <type_traits>

namespace Test {
namespace stdalgos {

// Element-wise copy between two objects indexable with a single index.
// Calling the functor with index i assigns m_view_from(i) to m_view_to(i).
template <class ViewTypeFrom, class ViewTypeTo>
struct CopyFunctor {
  ViewTypeFrom m_view_from;  // source, read through operator()(i)
  ViewTypeTo m_view_to;      // destination, written through operator()(i)

  // Copy the entry at position i from the source into the destination.
  KOKKOS_INLINE_FUNCTION
  void operator()(int i) const {
    m_view_to(i) = m_view_from(i);
  }

  // A source and a destination must always be supplied.
  CopyFunctor() = delete;

  CopyFunctor(const ViewTypeFrom source, const ViewTypeTo destination)
      : m_view_from(source), m_view_to(destination) {}
};

// Copies one entry between two objects indexable with two indices, driven
// by a single flattened index k. The flat index is split into (row, column)
// using the source's extent along dimension 1.
template <class ViewTypeFrom, class ViewTypeTo>
struct CopyFunctorRank2 {
  ViewTypeFrom m_view_from;  // source, read through operator()(i, j)
  ViewTypeTo m_view_to;      // destination, written through operator()(i, j)

  // A source and a destination must always be supplied.
  CopyFunctorRank2() = delete;

  CopyFunctorRank2(const ViewTypeFrom source, const ViewTypeTo destination)
      : m_view_from(source), m_view_to(destination) {}

  // Decompose k as k = row * extent(1) + col and copy that single entry.
  KOKKOS_INLINE_FUNCTION
  void operator()(int k) const {
    const auto num_cols = m_view_from.extent(1);
    const auto row      = k / num_cols;
    const auto col      = k % num_cols;
    m_view_to(row, col) = m_view_from(row, col);
  }
};

// Writes the value an iterator refers to into a destination that is
// accessed with no indices (m_view_to()).
template <class ItTypeFrom, class ViewTypeTo>
struct CopyFromIteratorFunctor {
  ItTypeFrom m_it_from;  // iterator that is dereferenced on every call
  ViewTypeTo m_view_to;  // destination, written through operator()()

  // The index argument is ignored: every call stores *m_it_from.
  KOKKOS_INLINE_FUNCTION
  void operator()(int) const {
    m_view_to() = *m_it_from;
  }

  CopyFromIteratorFunctor(const ItTypeFrom source_it,
                          const ViewTypeTo destination)
      : m_it_from(source_it), m_view_to(destination) {}
};

// Unary functor that pre-increments the referenced element in place.
template <class ValueType>
struct IncrementElementWiseFunctor {
  KOKKOS_INLINE_FUNCTION
  void operator()(ValueType& element) const {
    ++element;
  }
};

// Sets one entry of the wrapped view to zero per invocation. Zero is
// produced by converting the literal 0 to ViewType::value_type.
template <class ViewType>
struct FillZeroFunctor {
  ViewType m_view;  // target, written through operator()(index)

  KOKKOS_INLINE_FUNCTION
  FillZeroFunctor(ViewType target) : m_view(target) {}

  // Overwrite the entry at the given index with zero.
  KOKKOS_INLINE_FUNCTION
  void operator()(int index) const {
    using element_type = typename ViewType::value_type;
    m_view(index)      = static_cast<element_type>(0);
  }
};

// Unary functor that takes an element by const reference and does nothing
// with it.
template <class ValueType>
struct NoOpNonMutableFunctor {
  KOKKOS_INLINE_FUNCTION
  void operator()(const ValueType& /*val*/) const {
    // Intentionally empty: the argument is neither read nor modified.
  }
};

// Stores each index into the view entry at that index, after converting
// the index to ViewType::value_type.
template <class ViewType>
struct AssignIndexFunctor {
  ViewType m_view;  // target, written through operator()(i)

  // Write the converted index i into entry i.
  KOKKOS_INLINE_FUNCTION
  void operator()(int i) const {
    using element_type = typename ViewType::value_type;
    m_view(i)          = element_type(i);
  }

  AssignIndexFunctor(ViewType target) : m_view(target) {}
};

// Unary predicate that reports whether an integral value is even.
//
// ValueType may be any integral type, as enforced by the static_assert
// below. Negative odd values yield a nonzero remainder and are therefore
// reported as not even.
template <class ValueType>
struct IsEvenFunctor {
  // The message now matches the check: any integral type is accepted,
  // not just int.
  static_assert(
      std::is_integral_v<ValueType>,
      "IsEvenFunctor uses operator%, so ValueType must be an integral type");

  // Returns true when val is divisible by 2.
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType val) const { return (val % 2 == 0); }
};

// Unary predicate: true when the value compares strictly greater than 0.
template <class ValueType>
struct IsPositiveFunctor {
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType val) const {
    const bool above_zero = val > 0;
    return above_zero;
  }
};

// Unary predicate: true when the value compares strictly less than 0.
template <class ValueType>
struct IsNegativeFunctor {
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType val) const {
    const bool below_zero = val < 0;
    return below_zero;
  }
};

// Unary predicate: true when the value compares unequal to 0.
template <class ValueType>
struct NotEqualsZeroFunctor {
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType val) const {
    const bool is_nonzero = val != 0;
    return is_nonzero;
  }
};

// Unary predicate that compares its argument against a value fixed at
// construction time.
template <class ValueType>
struct EqualsValFunctor {
  const ValueType m_value;  // reference value every argument is compared to

  // True when val compares equal to the stored reference value.
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType val) const {
    const bool matches = val == m_value;
    return matches;
  }

  EqualsValFunctor(ValueType target) : m_value(target) {}
};

// Binary comparator forwarding to operator<. The two operand types may
// differ; the second defaults to the first.
template <class ValueType1, class ValueType2 = ValueType1>
struct CustomLessThanComparator {
  KOKKOS_INLINE_FUNCTION
  CustomLessThanComparator() {}

  // True when a compares less than b.
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType1& a, const ValueType2& b) const {
    const bool a_before_b = a < b;
    return a_before_b;
  }
};

// Binary predicate forwarding to operator== for two operands of one type.
template <class ValueType>
struct CustomEqualityComparator {
  // True when a compares equal to b.
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType& a, const ValueType& b) const {
    const bool same = a == b;
    return same;
  }
};

// Binary predicate forwarding to operator==. The two operand types may
// differ; the second defaults to the first.
template <class ValueType1, class ValueType2 = ValueType1>
struct IsEqualFunctor {
  // True when a compares equal to b.
  KOKKOS_INLINE_FUNCTION
  bool operator()(const ValueType1& a, const ValueType2& b) const {
    const bool same = (a == b);
    return same;
  }
};

}  // namespace stdalgos
}  // namespace Test

#endif