1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOS_TEST_SIMD_CONDITION_HPP
#define KOKKOS_TEST_SIMD_CONDITION_HPP
#include <cstddef>

#include <Kokkos_SIMD.hpp>
#include <SIMDTesting_Utilities.hpp>
template <typename Abi, typename DataType>
inline void host_check_condition() {
if constexpr (is_simd_avail_v<DataType, Abi>) {
using simd_type = typename Kokkos::Experimental::basic_simd<DataType, Abi>;
using mask_type = typename simd_type::mask_type;
auto condition_op = [](mask_type const& mask, simd_type const& a,
simd_type const& b) {
return Kokkos::Experimental::condition(mask, a, b);
};
simd_type value_a(16);
simd_type value_b(20);
auto condition_result = condition_op(mask_type(false), value_a, value_b);
EXPECT_TRUE(all_of(condition_result == value_b));
condition_result = condition_op(mask_type(true), value_a, value_b);
EXPECT_TRUE(all_of(condition_result == value_a));
}
}
// Runs host_check_condition for a single ABI across every data type in the
// data_types<...> list. The argument is only a tag used to deduce the pack.
template <typename Abi, typename... DataTypes>
inline void host_check_condition_all_types(
    Kokkos::Experimental::Impl::data_types<DataTypes...> /*type_tag*/) {
  // Left fold over the comma operator; checks still run in list order.
  (..., host_check_condition<Abi, DataTypes>());
}
// Runs the host condition checks for every ABI in the given abi_set<...>.
// Each ABI is paired with the full data_type_set.
template <typename... Abis>
inline void host_check_condition_all_abis(
    Kokkos::Experimental::Impl::abi_set<Abis...> /*abi_tag*/) {
  using all_data_types = Kokkos::Experimental::Impl::data_type_set;
  (..., host_check_condition_all_types<Abis>(all_data_types{}));
}
/// Device-callable check of Kokkos::Experimental::condition for one
/// (Abi, DataType) pair. Results go through kokkos_checker rather than gtest
/// macros so the check can run inside a kernel. Skipped unless
/// is_type_v<basic_simd<DataType, Abi>> holds.
/// NOTE(review): the host variant guards with is_simd_avail_v<DataType, Abi>;
/// confirm the two guards are intended to select the same types.
///
/// Checks all-false, all-true and alternating masks. The alternating mask is
/// what verifies per-lane selection.
template <typename Abi, typename DataType>
KOKKOS_INLINE_FUNCTION void device_check_condition() {
  if constexpr (is_type_v<Kokkos::Experimental::basic_simd<DataType, Abi>>) {
    using simd_type = typename Kokkos::Experimental::basic_simd<DataType, Abi>;
    using mask_type = typename simd_type::mask_type;
    kokkos_checker checker;
    auto condition_op = [](mask_type const& mask, simd_type const& a,
                           simd_type const& b) {
      return Kokkos::Experimental::condition(mask, a, b);
    };
    simd_type value_a(16);
    simd_type value_b(20);

    // Uniform masks: every lane comes from the same operand.
    auto condition_result = condition_op(mask_type(false), value_a, value_b);
    checker.truth(all_of(condition_result == value_b));
    condition_result = condition_op(mask_type(true), value_a, value_b);
    checker.truth(all_of(condition_result == value_a));

    // Mixed mask: even lanes must take value_a, odd lanes value_b.
    mask_type const alternating_mask(
        [](std::size_t lane) { return lane % 2 == 0; });
    simd_type const expected([](std::size_t lane) {
      return (lane % 2 == 0) ? DataType(16) : DataType(20);
    });
    condition_result = condition_op(alternating_mask, value_a, value_b);
    checker.truth(all_of(condition_result == expected));
  }
}
// Device-callable: runs device_check_condition for a single ABI across every
// data type in the data_types<...> list. The argument is only a type tag.
template <typename Abi, typename... DataTypes>
KOKKOS_INLINE_FUNCTION void device_check_condition_all_types(
    Kokkos::Experimental::Impl::data_types<DataTypes...> /*type_tag*/) {
  // Left fold over the comma operator; checks still run in list order.
  (..., device_check_condition<Abi, DataTypes>());
}
// Device-callable: runs the condition checks for every ABI in the given
// abi_set<...>. Each ABI is paired with the full data_type_set.
template <typename... Abis>
KOKKOS_INLINE_FUNCTION void device_check_condition_all_abis(
    Kokkos::Experimental::Impl::abi_set<Abis...> /*abi_tag*/) {
  using all_data_types = Kokkos::Experimental::Impl::data_type_set;
  (..., device_check_condition_all_types<Abis>(all_data_types{}));
}
// Kernel body for the device_condition test. Every invocation runs the
// condition checks for all ABIs in device_abi_set; the index is ignored.
class simd_device_condition_functor {
 public:
  KOKKOS_INLINE_FUNCTION void operator()(int /*iteration*/) const {
    using device_abis = Kokkos::Experimental::Impl::device_abi_set;
    device_check_condition_all_abis(device_abis{});
  }
};
// Host test: runs the condition checks for every ABI in host_abi_set.
TEST(simd, host_condition) {
  using host_abis = Kokkos::Experimental::Impl::host_abi_set;
  host_check_condition_all_abis(host_abis{});
}
// Device test: launches a single-iteration kernel that runs the condition
// checks for every ABI in device_abi_set.
// Kokkos::parallel_for may return before the kernel completes, so fence
// before leaving the test. Any failure raised inside the kernel (presumably
// via kokkos_checker, defined elsewhere) then surfaces while this test is
// still running, rather than being attributed to a later test.
TEST(simd, device_condition) {
  Kokkos::parallel_for(Kokkos::RangePolicy<Kokkos::IndexType<int>>(0, 1),
                       simd_device_condition_functor());
  Kokkos::fence();
}
#endif
|