File: TestDynRankView_TeamScratch.hpp

package info (click to toggle)
kokkos 5.0.2-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 15,148 kB
  • sloc: cpp: 225,388; sh: 1,250; python: 78; makefile: 16; fortran: 4; ansic: 2
file content (66 lines) | stat: -rw-r--r-- 2,442 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#include <gtest/gtest.h>

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.core;
import kokkos.dyn_rank_view;
#else
#include <Kokkos_Core.hpp>
#include <Kokkos_DynRankView.hpp>
#endif

namespace {

// Verifies that a rank-3 DynRankView can be constructed in level-0 team
// scratch memory, reports the requested rank and extents, and supports a
// write / team_barrier / read-back round trip over a TeamThreadMDRange.
//
// Device-side failures are accumulated as bit flags in an atomic View and
// decoded on the host afterwards:
//   1u   rank mismatch
//   2u   extent 0 mismatch
//   4u   extent 1 mismatch
//   8u   extent 2 mismatch
//   16u  a value read back differs from the value written
//   256u sentinel set unconditionally at the end of the kernel body, so its
//        absence means the kernel never ran
void test_dyn_rank_view_team_scratch() {
  using execution_space = TEST_EXECSPACE;
  using memory_space    = execution_space::scratch_memory_space;
  using drv_type        = Kokkos::DynRankView<int, memory_space>;
  using policy_type     = Kokkos::TeamPolicy<execution_space>;
  using team_type       = policy_type::member_type;

  size_t N0 = 10, N1 = 4, N2 = 3;
  // shmem_size is only required to cover the raw payload (checked with >=,
  // not ==), so any extra it requests is tolerated here.
  size_t shmem_size = drv_type::shmem_size(N0, N1, N2);
  ASSERT_GE(shmem_size, N0 * N1 * N2 * sizeof(int))
      << "shmem_size is smaller than N0*N1*N2 ints";

  // Atomic memory trait: every team thread may OR flags in concurrently.
  Kokkos::View<unsigned, execution_space, Kokkos::MemoryTraits<Kokkos::Atomic>>
      errors("errors");
  // One team of automatically chosen size, given exactly the level-0 scratch
  // that the DynRankView reported it needs.
  auto policy = policy_type(1, Kokkos::AUTO)
                    .set_scratch_size(0, Kokkos::PerTeam(shmem_size));
  Kokkos::parallel_for(
      policy, KOKKOS_LAMBDA(const team_type& team) {
        drv_type scr(team.team_scratch(0), N0, N1, N2);
        // Check that the scratch-backed view reports the requested shape.
        if (scr.rank() != 3) errors() |= 1u;
        if (scr.extent(0) != N0) errors() |= 2u;
        if (scr.extent(1) != N1) errors() |= 4u;
        if (scr.extent(2) != N2) errors() |= 8u;
        // The team cooperatively writes a unique, index-derived value to
        // every element...
        Kokkos::parallel_for(
            Kokkos::TeamThreadMDRange(team, N0, N1, N2),
            [=](int i, int j, int k) { scr(i, j, k) = i * 100 + j * 10 + k; });
        // ...and all writes must be complete before any thread reads back,
        // since the read loop may visit elements another thread wrote.
        team.team_barrier();
        Kokkos::parallel_for(Kokkos::TeamThreadMDRange(team, N0, N1, N2),
                             [=](int i, int j, int k) {
                               if (scr(i, j, k) != i * 100 + j * 10 + k)
                                 errors() |= 16u;
                             });
        // Control that the code ran at all: this sentinel is always set.
        errors() |= 256u;
      });
  unsigned h_errors = 0;
  Kokkos::deep_copy(h_errors, errors);

  ASSERT_EQ((h_errors & 1u), 0u) << "Rank mismatch";
  ASSERT_EQ((h_errors & 2u), 0u) << "extent 0 mismatch";
  ASSERT_EQ((h_errors & 4u), 0u) << "extent 1 mismatch";
  ASSERT_EQ((h_errors & 8u), 0u) << "extent 2 mismatch";
  ASSERT_EQ((h_errors & 16u), 0u) << "data access incorrect";
  // Only the sentinel may remain: catches a kernel that never executed
  // (256 bit missing) as well as any unexpected flag bits.
  ASSERT_EQ(h_errors, 256u)
      << "kernel did not run or reported unexpected error bits";
}

TEST(TEST_CATEGORY, dyn_rank_view_team_scratch) {
  test_dyn_rank_view_team_scratch();
}

}  // namespace