File: pointwise_heuristic.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (109 lines) | stat: -rw-r--r-- 3,667 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#pragma once

#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>

#include <sstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Parameters of the pointwise heuristic to describe the optimal schedule.
// Warning: equal operator is intended for use in caching the kernel associated
// with these pointwise parameters. It does not check if the launch parameters
// are equivalent!
class PointwiseParams : public HeuristicParams {
 public:
  // vectorize if true, otherwise unroll
  bool vectorize = false;

  // Treat pointwise operation as 2-Dimensional, this is the location where we
  // split from left side of the domain to right. i.e. 0 means problem is
  // treated as 1-D, 1 of 3 would mean we treat the first dimension as the outer
  // dimension, and all the others as an inner dimension.
  int break_point = 0;

  // Split block across left and right dimension
  bool split_block = false;

  // Split grid y dimension, if otherwise it would be too large
  bool split_grid_y_dim = false;

  // For many instances having BIDx on the inner most dimension is the most
  // performant parallel binding. However, if we're broadcasting the outer
  // dimension with a large inner dimension, it can be more performant to bind
  // BIDy on the inner most dimension.
  bool flip_grid_binding = false;

  // Unroll or vectorization factor
  size_t unroll_factor = 1;

  using HeuristicParams::HeuristicParams;

  // Warning: Does not check launch parameters!
  bool sameAs(
      const std::shared_ptr<HeuristicParams>& other_base) const override {
    auto other_casted = std::dynamic_pointer_cast<PointwiseParams>(other_base);
    if (other_casted == nullptr) {
      return false;
    }
    const PointwiseParams& other = *other_casted;
    bool attr_equal = other.vectorize == vectorize &&
        other.break_point == break_point && other.split_block == split_block &&
        other.split_grid_y_dim == split_grid_y_dim &&
        other.unroll_factor == unroll_factor &&
        other.flip_grid_binding == flip_grid_binding;
    return attr_equal;
  }

  std::string toString() const override {
    std::stringstream ss;
    ss << "\n===== Pointwise Parameters ========\n"
       << (tag == "" ? "" : "Tag: ") << tag << " Pointwise Characteristics:\n"
       << " Gridx: " << lparams.gdimx() << " BlckY: " << lparams.bdimy()
       << " BlckX: " << lparams.bdimx() << "\n";
    if (break_point) {
      ss << "2D Schedule\n"
         << "  Bcast break point: " << break_point << "\n";
      if (split_block) {
        ss << "Split block into y-dim\n";
      }
      if (split_grid_y_dim) {
        ss << "  Split y grid dim\n";
      }
    }
    if (unroll_factor > 1) {
      if (vectorize) {
        ss << "Vectorize, Factor: " << unroll_factor << "\n";
      } else {
        ss << "Unroll, Factor: " << unroll_factor << "\n";
      }
    }
    if (flip_grid_binding) {
      ss << "Flip BIDx/BIDy bindings\n";
    }
    ss << "====================================\n";
    return ss.str();
  }

  // Warning: Hash is not based on launch parameters!
  size_t hash() const override {
    size_t attr_hash = static_cast<size_t>(vectorize) ^
        static_cast<size_t>(break_point) << 4 ^
        static_cast<size_t>(split_block) << 5 ^
        static_cast<size_t>(split_grid_y_dim) << 6 ^
        static_cast<size_t>(unroll_factor) << 9 ^
        static_cast<size_t>(flip_grid_binding) << 10;
    return attr_hash;
  }

  std::shared_ptr<HeuristicParams> clone() const override {
    return std::make_shared<PointwiseParams>(*this);
  }
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch