File: test_fn_executor.cpp

package info (click to toggle)
gridtools 2.3.9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 29,480 kB
  • sloc: cpp: 228,792; python: 17,561; javascript: 9,164; ansic: 4,101; sh: 850; makefile: 231; f90: 201
file content (107 lines) | stat: -rw-r--r-- 3,533 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*
 * GridTools
 *
 * Copyright (c) 2014-2023, ETH Zurich
 * All rights reserved.
 *
 * Please, refer to the LICENSE file in the root directory.
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include <gridtools/fn/executor.hpp>

#include <gtest/gtest.h>

#include <gridtools/fn/backend/naive.hpp>
#include <gridtools/fn/column_stage.hpp>

namespace gridtools::fn {
    namespace {
        using namespace literals;
        using sid::property;

        // Shorthand for a compile-time int constant; used below both as hymap keys
        // for the 2D domain and as the vertical-dimension tag of the vertical executor.
        template <int Value>
        using int_t = integral_constant<int, Value>;

        // Pointwise stencil: yields twice the value found under the iterator it is given.
        struct stencil {
            GT_FUNCTION constexpr auto operator()() const {
                return [](auto const &it) {
                    return 2 * *it;
                };
            }
        };

        // Forward column scan: the step lambda adds the current input value to the
        // running accumulator, and the projection lambda emits the accumulator unchanged.
        struct fwd_sum_scan : fwd {
            static GT_FUNCTION constexpr auto body() {
                return scan_pass(
                    [](auto sum, auto const &it) { return sum + *it; }, // fold the current level into the sum
                    [](auto sum) { return sum; });                      // output value == accumulator
            }
        };

        // Backward column scan: identical step/projection to fwd_sum_scan, but derived
        // from `bwd`, so the vertical_executor test expects the sum to run from the last level down.
        struct bwd_sum_scan : bwd {
            static GT_FUNCTION constexpr auto body() {
                return scan_pass(
                    [](auto sum, auto const &it) { return sum + *it; }, // fold the current level into the sum
                    [](auto sum) { return sum; });                      // output value == accumulator
            }
        };

        // Test stand-in for the iterator factory handed to the executors: the returned
        // callable looks up the entry stored under `tag` in its second argument and ignores
        // the third. NOTE(review): the second argument is presumably the per-argument
        // pointer map built by the executor, so the "iterator" is that raw entry; confirm.
        struct make_iterator_mock {
            GT_FUNCTION auto operator()() const {
                return [](auto tag, auto const &ptrs, auto const & /*unused*/) {
                    using tag_t = decltype(tag);
                    return at_key<tag_t>(ptrs);
                };
            }
        };

        // Smoke test for the stencil executor on a 2x3 domain with the naive backend:
        // b is computed from c, then a from the freshly computed b; c is only read.
        TEST(stencil_executor, smoke) {
            using backend_t = backend::naive;
            auto domain = hymap::keys<int_t<0>, int_t<1>>::make_values(2_c, 3_c);

            // a and b are outputs (zero-initialized); c is the input field.
            int a[2][3] = {}, b[2][3] = {}, c[2][3];
            for (int i = 0; i < 2; ++i)
                for (int j = 0; j < 3; ++j)
                    c[i][j] = 3 * i + j;

            // Argument indices follow the .arg() order: 0 -> a, 1 -> b, 2 -> c.
            // assign(dst, stencil, src) writes dst from src; the second assignment reads
            // the b produced by the first, so a is expected to end up as 4 * c.
            make_stencil_executor(backend_t(), domain, std::tuple<>(), make_iterator_mock())
                .arg(a)
                .arg(b)
                .arg(c)
                .assign(1_c, stencil(), 2_c)
                .assign(0_c, stencil(), 1_c)
                .execute();

            for (int i = 0; i < 2; ++i)
                for (int j = 0; j < 3; ++j) {
                    EXPECT_EQ(a[i][j], (3 * i + j) * 4);
                    EXPECT_EQ(b[i][j], (3 * i + j) * 2);
                    EXPECT_EQ(c[i][j], (3 * i + j) * 1);
                }
        }

        // Smoke test for the vertical executor on a 2x3 domain whose vertical dimension
        // is int_t<1>: b is a forward running sum of c seeded with 42, and a is a
        // backward running sum of b seeded with 8; c is only read.
        TEST(vertical_executor, smoke) {
            using backend_t = backend::naive;
            auto domain = hymap::keys<int_t<0>, int_t<1>>::make_values(2_c, 3_c);

            // a and b are outputs (zero-initialized); c is the input field.
            int a[2][3] = {}, b[2][3] = {}, c[2][3];
            for (int row = 0; row < 2; ++row)
                for (int k = 0; k < 3; ++k)
                    c[row][k] = 3 * row + k;

            // Argument indices follow the .arg() order: 0 -> a, 1 -> b, 2 -> c.
            // assign(dst, scan, seed, src): the third argument is the initial accumulator.
            make_vertical_executor<int_t<1>>(backend_t(), domain, std::tuple<>(), make_iterator_mock())
                .arg(a)
                .arg(b)
                .arg(c)
                .assign(1_c, fwd_sum_scan(), 42, 2_c)
                .assign(0_c, bwd_sum_scan(), 8, 1_c)
                .execute();

            for (int row = 0; row < 2; ++row) {
                // b[row][k] == 42 + c[row][0] + ... + c[row][k]
                int fwd_acc = 42;
                for (int k = 0; k < 3; ++k) {
                    EXPECT_EQ(c[row][k], 3 * row + k);
                    fwd_acc += c[row][k];
                    EXPECT_EQ(b[row][k], fwd_acc);
                }
                // a[row][k] == 8 + b[row][k] + ... + b[row][2]
                int bwd_acc = 8;
                for (int k = 2; k >= 0; --k) {
                    bwd_acc += b[row][k];
                    EXPECT_EQ(a[row][k], bwd_acc);
                }
            }
        }
    } // namespace
} // namespace gridtools::fn