File: horizontal_diffusion_fused.cpp

package info (click to toggle)
gridtools 2.3.9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 29,480 kB
  • sloc: cpp: 228,792; python: 17,561; javascript: 9,164; ansic: 4,101; sh: 850; makefile: 231; f90: 201
file content (102 lines) | stat: -rw-r--r-- 3,442 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
 * GridTools
 *
 * Copyright (c) 2014-2023, ETH Zurich
 * All rights reserved.
 *
 * Please, refer to the LICENSE file in the root directory.
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <gridtools/stencil/cartesian.hpp>

#include <stencil_select.hpp>
#include <test_environment.hpp>

#include "horizontal_diffusion_repository.hpp"

namespace {
    using namespace gridtools;
    using namespace stencil;
    using namespace cartesian;

    /// Five-point discrete Laplacian with the sign convention
    ///   out = 4 * in(0,0) - (in(1,0) + in(0,1) + in(-1,0) + in(0,-1)).
    ///
    /// Only the direct i/j neighbours of the current point are read, hence the
    /// +-1 extent in both horizontal directions.
    struct lap_function {
        using out = inout_accessor<0>;
        using in = in_accessor<1, extent<-1, 1, -1, 1>>;

        using param_list = make_param_list<out, in>;

        template <typename Evaluation>
        GT_FUNCTION static void apply(Evaluation eval) {
            // Value type of the output field; named value_t to avoid shadowing ::float_t.
            using value_t = std::decay_t<decltype(eval(out()))>;
            // Kept as one expression so the evaluation/rounding is exactly as before.
            eval(out()) = value_t{4} * eval(in()) -
                          (eval(in(1, 0)) + eval(in(0, 1)) + eval(in(-1, 0)) + eval(in(0, -1)));
        }
    };

    /// Limited flux in the i direction between the current point and its i+1 neighbour.
    ///
    /// The raw flux is lap(i+1) - lap(i), with the Laplacian from `lap_function`.
    /// It is replaced by zero whenever its product with the input difference
    /// in(i+1) - in(i) is positive. Evaluating the Laplacian at i+1 reads up to
    /// i+2, which is why the upper i-extent is 2.
    struct flx_function {
        using out = inout_accessor<0>;
        using in = in_accessor<1, extent<-1, 2, -1, 1>>;

        using param_list = make_param_list<out, in>;

        template <typename Evaluation>
        GT_FUNCTION static void apply(Evaluation eval) {
            auto const delta_lap = call<lap_function>::with(eval, in(1, 0)) - call<lap_function>::with(eval, in(0, 0));
            auto const delta_in = eval(in(1, 0)) - eval(in(0, 0));
            // Flux limiter: drop the flux when it has the same sign as the input gradient.
            eval(out()) = (delta_lap * delta_in > 0) ? 0 : delta_lap;
        }
    };

    /// Limited flux in the j direction between the current point and its j+1 neighbour.
    ///
    /// The raw flux is lap(j+1) - lap(j), with the Laplacian from `lap_function`.
    /// It is replaced by zero whenever its product with the input difference
    /// in(j+1) - in(j) is positive. Evaluating the Laplacian at j+1 reads up to
    /// j+2, which is why the upper j-extent is 2.
    struct fly_function {
        using out = inout_accessor<0>;
        using in = in_accessor<1, extent<-1, 1, -1, 2>>;

        using param_list = make_param_list<out, in>;

        template <typename Evaluation>
        GT_FUNCTION static void apply(Evaluation eval) {
            auto const delta_lap = call<lap_function>::with(eval, in(0, 1)) - call<lap_function>::with(eval, in(0, 0));
            auto const delta_in = eval(in(0, 1)) - eval(in(0, 0));
            // Flux limiter: drop the flux when it has the same sign as the input gradient.
            eval(out()) = (delta_lap * delta_in > 0) ? 0 : delta_lap;
        }
    };

    /// Fused horizontal diffusion update:
    ///   out = in - coeff * ((flx(i) - flx(i-1)) + (fly(j) - fly(j-1)))
    /// where flx/fly are the limited fluxes from `flx_function` / `fly_function`.
    ///
    /// The flux stencils are called at the current point and one point back in
    /// each direction. Combined with their own extents, `in` is therefore needed
    /// on a +-2 halo in both i and j.
    struct out_function {
        using out = inout_accessor<0>;
        using in = in_accessor<1, extent<-2, 2, -2, 2>>;
        using coeff = in_accessor<2>;

        using param_list = make_param_list<out, in, coeff>;

        template <typename Evaluation>
        GT_FUNCTION static void apply(Evaluation eval) {
            // i-fluxes between (i, i+1) and between (i-1, i).
            auto const flux_i_fwd = call<flx_function>::with(eval, in(0, 0));
            auto const flux_i_bwd = call<flx_function>::with(eval, in(-1, 0));

            // j-fluxes between (j, j+1) and between (j-1, j).
            auto const flux_j_fwd = call<fly_function>::with(eval, in(0, 0));
            auto const flux_j_bwd = call<fly_function>::with(eval, in(0, -1));

            eval(out()) = eval(in()) - eval(coeff()) * (flux_i_fwd - flux_i_bwd + flux_j_fwd - flux_j_bwd);
        }
    };

    /// Regression test for the fused horizontal diffusion stencil.
    ///
    /// Runs `out_function` once, compares the result with `repo.out` via
    /// `TypeParam::verify`, then hands the same computation to `TypeParam::benchmark`.
    GT_REGRESSION_TEST(horizontal_diffusion_fused, test_environment<2>, stencil_backend_t) {
        auto out = TypeParam::make_storage();

        horizontal_diffusion_repository repo(TypeParam::d(0), TypeParam::d(1), TypeParam::d(2));

        // The grid and input storages are created once and owned by the closure (init-captures).
        // Only the output storage is shared by reference so it can be verified afterwards.
        auto run_stencil = [&out,
                               grid = TypeParam::make_grid(),
                               in = TypeParam::make_storage(repo.in),
                               coeff = TypeParam::make_storage(repo.coeff)] {
            run_single_stage(out_function(), stencil_backend_t(), grid, out, in, coeff);
        };

        run_stencil();
        TypeParam::verify(repo.out, out);
        TypeParam::benchmark("horizontal_diffusion_fused", run_stencil);
    }
} // namespace