File: reduction_heuristic.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (259 lines) | stat: -rw-r--r-- 11,291 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
#pragma once

#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>

#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Parameters of the reduction heuristic to describe the optimal schedule.
// Warning: the equality check (sameAs) is intended for use in caching the
// kernel associated with these reduction parameters. It does not check if the
// launch parameters are equivalent!
class ReductionParams : public HeuristicParams {
 public:
  // Reducing inner most dimension?
  bool fastest_dim = false;

  // Store input in shared memory or registers to reduce global memory reads
  bool persistent_kernel = false;

  // Project persistent buffers back to inputs to reduce persistent buffer size
  bool project_persistent_buffers = false;

  // Are we treating the scheduling as 3 dimensional, can be useful for patterns
  // like [reduction, iteration, reduction].
  bool schedule_3D = false;

  // For outer reductions we may want to swap the gdimx and gdimy bindings to
  // amortize the cost of the final cleanup in grid reductions.
  bool flip_grid = false;

  // Inner Reduction Domain:

  // Reduce across the block?
  bool cross_block_inner_reduction = false;
  // Reduce across the grid?
  bool cross_grid_inner_reduction = false;
  // Unrolling/Vectorization factor for inner reduction dimension
  int64_t unroll_factor_inner_reduction = 1;
  // vectorize instead of unroll
  bool vectorize_inner_reduction = false;
  // Split grid dim for iteration axis in case it's too large for cuda
  bool split_grid_dim_inner_reduction = false;
  // Pad inner dimension to nearest warp
  bool pad_inner_reduction_to_warp = false;
  // Register persistent buffer size in inner dimension
  int64_t batches_per_block_inner_reduction = 1;

  // Which block parallel dimension should be used for the inner reduction.
  // !!WARNING!! Convenience method, this be unique based on non-parallel type
  // parameters, not used for equivalence/hashing.
  ParallelType block_dim_inner_reduction = ParallelType::Serial;
  // Which grid parallel dimension should be used for the inner reduction.
  // !!WARNING!! Convenience method, this be unique based on non-parallel type
  // parameters, not used for equivalence/hashing.
  ParallelType grid_dim_inner_reduction = ParallelType::Serial;

  // Iteration Domain:

  // Perform multiple reductions per block?
  bool multiple_reds_per_blk = false;
  // Unrolling/Vectorization factor for iteration dimension
  int64_t unroll_factor_iter_dom = 1;
  // vectorize instead of unroll
  bool vectorize_iter_dom = false;
  // Split grid dim for iteration axis in case it's too large for cuda
  bool split_grid_dim_iter_dom = false;

  // Which block parallel dimension should be used for the iter domain.
  // !!WARNING!! Convenience method, this be unique based on non-parallel type
  // parameters, not used for equivalence/hashing.
  ParallelType block_dim_iter_dom = ParallelType::Serial;
  // Which grid parallel dimension should be used for the iter domain.
  // !!WARNING!! Convenience method, this be unique based on non-parallel type
  // parameters, not used for equivalence/hashing.
  ParallelType grid_dim_iter_dom = ParallelType::Serial;

  // Outer Reduction Domain if 3D Scheduled:

  // Reduce across the block?
  bool cross_block_outer_reduction = false;
  // Reduce across the grid?
  bool cross_grid_outer_reduction = false;
  // Split grid dim for iteration axis in case it's too large for cuda
  bool split_grid_dim_outer_reduction = false;
  // Register persistent buffer size in outer dimension
  int64_t batches_per_block_outer_reduction = 1;
  // Unrolling/Vectorization factor for outer reduction factor
  int64_t unroll_factor_outer_reduction = 1;

  // Which block parallel dimension should be used for the outer reduction.
  // !!WARNING!! Convenience method, this be unique based on non-parallel type
  // parameters, not used for equivalence/hashing.
  ParallelType block_dim_outer_reduction = ParallelType::Serial;
  // Which grid parallel dimension should be used for the outer reduction.
  // !!WARNING!! Convenience method, this be unique based on non-parallel type
  // parameters, not used for equivalence/hashing.
  ParallelType grid_dim_outer_reduction = ParallelType::Serial;

  bool isUnrolled() const {
    return unroll_factor_inner_reduction > 1 || unroll_factor_iter_dom > 1 ||
        unroll_factor_outer_reduction > 1;
  }

 public:
  using HeuristicParams::HeuristicParams;

  // Warning: Does not check launch parameters!
  bool sameAs(
      const std::shared_ptr<HeuristicParams>& other_base) const override {
    auto other_casted = std::dynamic_pointer_cast<ReductionParams>(other_base);
    if (other_casted == nullptr) {
      return false;
    }
    const ReductionParams& other = *other_casted;
    bool attr_equal = other.fastest_dim == fastest_dim &&
        other.persistent_kernel == persistent_kernel &&
        other.project_persistent_buffers == project_persistent_buffers &&
        other.schedule_3D == schedule_3D && other.flip_grid == flip_grid &&
        other.cross_block_inner_reduction == cross_block_inner_reduction &&
        other.cross_grid_inner_reduction == cross_grid_inner_reduction &&
        other.unroll_factor_inner_reduction == unroll_factor_inner_reduction &&
        other.vectorize_inner_reduction == vectorize_inner_reduction &&
        other.split_grid_dim_inner_reduction ==
            split_grid_dim_inner_reduction &&
        other.pad_inner_reduction_to_warp == pad_inner_reduction_to_warp &&
        other.batches_per_block_inner_reduction ==
            batches_per_block_inner_reduction &&
        other.multiple_reds_per_blk == multiple_reds_per_blk &&
        other.unroll_factor_iter_dom == unroll_factor_iter_dom &&
        other.vectorize_iter_dom == vectorize_iter_dom &&
        other.split_grid_dim_iter_dom == split_grid_dim_iter_dom &&
        other.cross_block_outer_reduction == cross_block_outer_reduction &&
        other.cross_grid_outer_reduction == cross_grid_outer_reduction &&
        other.unroll_factor_outer_reduction == unroll_factor_outer_reduction &&
        other.split_grid_dim_outer_reduction ==
            split_grid_dim_outer_reduction &&
        other.batches_per_block_outer_reduction ==
            batches_per_block_outer_reduction;
    return attr_equal;
  }

  std::string toString() const override {
    std::stringstream ss;
    ss << "\n===== Reduction Parameters ========\n"
       << (tag == "" ? "" : "Tag: ") << tag << "\n"
       << (fastest_dim ? "Red On Fastest Dim\n" : "Red On Slow Dim\n")
       << (persistent_kernel ? "Persistent Kernel\n" : "")
       << (project_persistent_buffers ? "Project Persistent Buffers\n" : "");
    if (batches_per_block_inner_reduction > 1 || persistent_kernel) {
      ss << "Batches per block: " << batches_per_block_inner_reduction << "\n";
    }

    if (schedule_3D) {
      ss << "3D Schedule\n"
         << "Outer Reduction: ";
      if (cross_block_outer_reduction) {
        ss << "cross block - " << block_dim_outer_reduction << " / ";
      }
      if (cross_grid_outer_reduction) {
        ss << "cross grid - " << grid_dim_outer_reduction << " / ";
        ss << (split_grid_dim_outer_reduction ? "split grid dim / " : "");
      }

      ss << (unroll_factor_outer_reduction > 1 ? "unroll / " : "");
      if (unroll_factor_outer_reduction > 1) {
        ss << "factor " << unroll_factor_outer_reduction << " ";
      }

      if (batches_per_block_outer_reduction > 1 || persistent_kernel) {
        ss << "persistent batch - " << batches_per_block_outer_reduction;
      }
    }

    ss << "\nIteration Domain: ";

    if (grid_dim_iter_dom != ParallelType::Serial) {
      ss << grid_dim_iter_dom << " / "
         << (split_grid_dim_iter_dom ? "split grid dimension / " : "");
    }
    if (block_dim_iter_dom != ParallelType::Serial) {
      ss << block_dim_iter_dom << " / ";
    }
    ss << (multiple_reds_per_blk ? "multiple reductions per block / " : "")
       << (vectorize_iter_dom ? "vectorize / " : "")
       << (unroll_factor_iter_dom > 1 && !vectorize_iter_dom ? "unroll / "
                                                             : "");
    if (unroll_factor_iter_dom > 1) {
      ss << "factor " << unroll_factor_iter_dom;
    }

    ss << "\nInner Reduction Domain: ";

    if (cross_block_inner_reduction) {
      ss << "cross block - " << block_dim_inner_reduction << " / ";
      ss << (pad_inner_reduction_to_warp ? " pad to warp / " : "");
    }
    if (cross_grid_inner_reduction) {
      ss << "cross grid - " << grid_dim_inner_reduction << " / ";
      ss << (split_grid_dim_inner_reduction ? "split grid dim / " : "");
    }
    if (batches_per_block_inner_reduction > 1 || persistent_kernel) {
      ss << "persistent batch - " << batches_per_block_inner_reduction << " / ";
    }
    ss << (cross_grid_inner_reduction && split_grid_dim_inner_reduction
               ? "split grid dimension / "
               : "")
       << (vectorize_inner_reduction ? "vectorize / " : "")
       << (unroll_factor_inner_reduction > 1 && !vectorize_inner_reduction
               ? "unroll / "
               : "");
    if (unroll_factor_inner_reduction > 1) {
      ss << "factor " << unroll_factor_inner_reduction;
    }

    ss << "\n" << lparams.toString() << "\n";
    ss << "====================================\n";
    return ss.str();
  }

  // Warning: Hash is not based on launch parameters!
  size_t hash() const override {
    constexpr size_t bits = sizeof(std::size_t) * 8;
    size_t attr_hash = static_cast<size_t>(fastest_dim) << (bits - 1) ^
        static_cast<size_t>(persistent_kernel) << (bits - 2) ^
        static_cast<size_t>(project_persistent_buffers) << (bits - 3) ^
        static_cast<size_t>(schedule_3D) << (bits - 4) ^
        static_cast<size_t>(flip_grid) << (bits - 5) ^
        static_cast<size_t>(cross_block_inner_reduction) << (bits - 6) ^
        static_cast<size_t>(cross_grid_inner_reduction) << (bits - 7) ^
        static_cast<size_t>(unroll_factor_inner_reduction) << (bits - 8) ^
        static_cast<size_t>(vectorize_inner_reduction) << (bits - 9) ^
        static_cast<size_t>(split_grid_dim_inner_reduction) << (bits - 10) ^
        static_cast<size_t>(pad_inner_reduction_to_warp) << (bits - 11) ^
        static_cast<size_t>(batches_per_block_inner_reduction) << (bits - 12) ^
        static_cast<size_t>(multiple_reds_per_blk) << (bits - 13) ^
        static_cast<size_t>(unroll_factor_iter_dom) << (bits - 14) ^
        static_cast<size_t>(vectorize_iter_dom) << (bits - 15) ^
        static_cast<size_t>(split_grid_dim_iter_dom) << (bits - 16) ^
        static_cast<size_t>(cross_block_outer_reduction) << (bits - 17) ^
        static_cast<size_t>(cross_grid_outer_reduction) << (bits - 18) ^
        static_cast<size_t>(split_grid_dim_outer_reduction) << (bits - 19) ^
        static_cast<size_t>(batches_per_block_outer_reduction) << (bits - 20) ^
        static_cast<size_t>(unroll_factor_outer_reduction) << (bits - 21);
    return attr_hash;
  }

  std::shared_ptr<HeuristicParams> clone() const override {
    return std::make_shared<ReductionParams>(*this);
  }
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch