// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#ifndef KOKKOS_TEST_SIMD_GENERATOR_CTORS_HPP
#define KOKKOS_TEST_SIMD_GENERATOR_CTORS_HPP

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.simd;
import kokkos.simd_impl;
#else
#include <Kokkos_SIMD.hpp>
#endif
#include <SIMDTesting_Utilities.hpp>

template <typename Abi, typename DataType>
inline void host_check_gen_ctor() {
  if constexpr (is_simd_avail_v<DataType, Abi>) {
    using simd_type = Kokkos::Experimental::basic_simd<DataType, Abi>;
    using mask_type = typename simd_type::mask_type;
    constexpr std::size_t lanes = simd_type::size();

    DataType init[lanes];
    DataType expected[lanes];
    bool init_mask[lanes];

    for (std::size_t i = 0; i < lanes; ++i) {
      init_mask[i] = (i % 2 == 0);
      init[i]      = i + 1;
      expected[i]  = (init_mask[i]) ? init[i] * 10 : init[i];
    }

    simd_type rhs = Kokkos::Experimental::simd_unchecked_load<simd_type>(
        init, Kokkos::Experimental::simd_flag_default);
    simd_type blend = Kokkos::Experimental::simd_unchecked_load<simd_type>(
        expected, Kokkos::Experimental::simd_flag_default);

#if !(defined(KOKKOS_ENABLE_CUDA) && defined(KOKKOS_COMPILER_MSVC))
    if constexpr (std::is_same_v<Abi, Kokkos::Experimental::simd_abi::scalar>) {
      simd_type basic(KOKKOS_LAMBDA(std::size_t i) { return init[i]; });
      host_check_equality(basic, rhs, lanes);

      simd_type lhs(KOKKOS_LAMBDA(std::size_t i) { return init[i] * 10; });
      mask_type mask(KOKKOS_LAMBDA(std::size_t i) { return init_mask[i]; });
      simd_type result(
          KOKKOS_LAMBDA(std::size_t i) { return (mask[i]) ? lhs[i] : rhs[i]; });

      host_check_equality(blend, result, lanes);
    } else {
      simd_type basic([=](std::size_t i) { return init[i]; });
      host_check_equality(basic, rhs, lanes);

      simd_type lhs([=](std::size_t i) { return init[i] * 10; });
      mask_type mask([=](std::size_t i) { return init_mask[i]; });
      simd_type result(
          [=](std::size_t i) { return (mask[i]) ? lhs[i] : rhs[i]; });
      host_check_equality(blend, result, lanes);
    }
#endif
  }
}

// Runs host_check_gen_ctor<Abi, T>() for each T in DataTypes, in pack order.
// The argument only carries the type list; its value is not used.
template <typename Abi, typename... DataTypes>
inline void host_check_gen_ctors_all_types(
    Kokkos::Experimental::Impl::data_types<DataTypes...>) {
  (static_cast<void>(host_check_gen_ctor<Abi, DataTypes>()), ...);
}

// Runs the host generator-constructor checks for each ABI in Abis, in pack
// order, pairing every ABI with the types listed in
// Kokkos::Experimental::Impl::data_type_set. The argument only carries the ABI
// list; its value is not used.
template <typename... Abis>
inline void host_check_gen_ctors_all_abis(
    Kokkos::Experimental::Impl::abi_set<Abis...>) {
  (host_check_gen_ctors_all_types<Abis>(
       Kokkos::Experimental::Impl::data_type_set()),
   ...);
}

// Device-callable test of the generator constructors of
// basic_simd<DataType, Abi> and of its mask type. Does nothing when
// is_type_v<basic_simd<DataType, Abi>> is false.
//
// Reference data: every lane holds 7. Lanes whose index is a multiple of 3
// are selected by the mask and are expected to hold 7 * 9 after the blend;
// the remaining lanes are expected to keep their value.
template <typename Abi, typename DataType>
KOKKOS_INLINE_FUNCTION void device_check_gen_ctor() {
  if constexpr (is_type_v<Kokkos::Experimental::basic_simd<DataType, Abi>>) {
    using simd_type = Kokkos::Experimental::basic_simd<DataType, Abi>;
    using mask_type = typename simd_type::mask_type;
    constexpr std::size_t width = simd_type::size();

    DataType values[width];
    DataType blended[width];
    bool selected[width];

    for (std::size_t lane = 0; lane < width; ++lane) {
      selected[lane] = (lane % 3 == 0);
      values[lane]   = 7;
      blended[lane]  = selected[lane] ? values[lane] * 9 : values[lane];
    }

    // Reference vectors loaded straight from the arrays; the generator-built
    // vectors below are compared against these.
    simd_type rhs = Kokkos::Experimental::simd_unchecked_load<simd_type>(
        values, Kokkos::Experimental::simd_flag_default);
    simd_type blend = Kokkos::Experimental::simd_unchecked_load<simd_type>(
        blended, Kokkos::Experimental::simd_flag_default);

    // A generator that returns the array element must reproduce the load.
    simd_type generated(KOKKOS_LAMBDA(std::size_t i) { return values[i]; });
    device_check_equality(generated, rhs, width);

    // Generators may read other simd/mask objects: blend lane by lane.
    mask_type select(KOKKOS_LAMBDA(std::size_t i) { return selected[i]; });
    simd_type scaled(KOKKOS_LAMBDA(std::size_t i) { return values[i] * 9; });
    simd_type merged(
        KOKKOS_LAMBDA(std::size_t i) { return select[i] ? scaled[i] : rhs[i]; });
    device_check_equality(merged, blend, width);
  }
}

// Runs device_check_gen_ctor<Abi, T>() for each T in DataTypes, in pack order.
// The argument only carries the type list; its value is not used.
template <typename Abi, typename... DataTypes>
KOKKOS_INLINE_FUNCTION void device_check_gen_ctors_all_types(
    Kokkos::Experimental::Impl::data_types<DataTypes...>) {
  (static_cast<void>(device_check_gen_ctor<Abi, DataTypes>()), ...);
}

// Runs the device generator-constructor checks for each ABI in Abis, in pack
// order, pairing every ABI with the types listed in
// Kokkos::Experimental::Impl::data_type_set. The argument only carries the ABI
// list; its value is not used.
template <typename... Abis>
KOKKOS_INLINE_FUNCTION void device_check_gen_ctors_all_abis(
    Kokkos::Experimental::Impl::abi_set<Abis...>) {
  using all_data_types = Kokkos::Experimental::Impl::data_type_set;
  (device_check_gen_ctors_all_types<Abis>(all_data_types()), ...);
}

// Functor whose call operator runs the device generator-constructor checks
// over every ABI in Kokkos::Experimental::Impl::device_abi_set. The int
// argument of the call operator is ignored.
class simd_device_gen_ctor_functor {
 public:
  KOKKOS_INLINE_FUNCTION void operator()(int) const {
    using device_abis = Kokkos::Experimental::Impl::device_abi_set;
    device_check_gen_ctors_all_abis(device_abis());
  }
};

// Runs the host generator-constructor checks over every ABI in
// Kokkos::Experimental::Impl::host_abi_set.
TEST(simd, host_gen_ctors) {
  using host_abis = Kokkos::Experimental::Impl::host_abi_set;
  host_check_gen_ctors_all_abis(host_abis());
}

// Launches a single work item that runs the device generator-constructor
// checks via simd_device_gen_ctor_functor.
TEST(simd, device_gen_ctors) {
  Kokkos::parallel_for(1, simd_device_gen_ctor_functor());
  // parallel_for may return before the kernel has finished. Fence here so
  // the kernel completes, and any failure it raises is reported, while this
  // test is still the one running.
  Kokkos::fence("simd device_gen_ctors: wait for device checks");
}

#endif
