File: executor_launch_params.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (136 lines) | stat: -rw-r--r-- 3,406 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#pragma once

#include <torch/csrc/jit/codegen/cuda/type.h>

#include <cstdint>
#include <cstdlib>
#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class TORCH_CUDA_CU_API LaunchParams {
 public:
  static constexpr int64_t UNINITIALIZED_VAL = -1;

  LaunchParams(
      int64_t gdimx = UNINITIALIZED_VAL,
      int64_t gdimy = UNINITIALIZED_VAL,
      int64_t gdimz = UNINITIALIZED_VAL,
      int64_t bdimx = UNINITIALIZED_VAL,
      int64_t bdimy = UNINITIALIZED_VAL,
      int64_t bdimz = UNINITIALIZED_VAL)
      : gdimx_(gdimx),
        gdimy_(gdimy),
        gdimz_(gdimz),
        bdimx_(bdimx),
        bdimy_(bdimy),
        bdimz_(bdimz) {
    assertValid();
  }

  void assertValid();

  void setSmem(int64_t smem) {
    smem_ = smem;
  }

  int64_t smem() const {
    return smem_;
  }

  int64_t nBlocks() const {
    return std::abs(gdimx_ * gdimy_ * gdimz_);
  }

  int64_t nThreads() const {
    return std::abs(bdimx_ * bdimy_ * bdimz_);
  }

  int64_t bdimx() const {
    return static_cast<int64_t>(bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_);
  }

  int64_t gdimx() const {
    return static_cast<int64_t>(gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_);
  }

  int64_t bdimy() const {
    return static_cast<int64_t>(bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_);
  }

  int64_t gdimy() const {
    return static_cast<int64_t>(gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_);
  }

  int64_t bdimz() const {
    return static_cast<int64_t>(bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_);
  }

  int64_t gdimz() const {
    return static_cast<int64_t>(gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_);
  }

  void checkAndSet(
      const int64_t incoming_val,
      int64_t& class_val,
      std::string val) {
    TORCH_INTERNAL_ASSERT(
        class_val == UNINITIALIZED_VAL || incoming_val == class_val,
        "Tried to set ",
        val,
        " from ",
        class_val,
        " to ",
        incoming_val,
        ", but it was already set and new value does not match.",
        " Thread dims all have to be bound to the same value.");
    TORCH_CHECK(
        incoming_val > 0,
        "Received a thread binding on ",
        val,
        " that is ",
        incoming_val,
        ". Cannot create negative threads.");
    if (class_val == UNINITIALIZED_VAL) {
      class_val = incoming_val;
    }
    assertValid();
  }

  // Binds dim assocaited with p_type to val
  void bind(int64_t val, ParallelType p_type);

  // Adjusted value based on get functions above for each value
  int64_t getDim(ParallelType p_type) const;

  // Returns raw value which may be UNINITIALIZED_VAL
  const int64_t& getRawVal(ParallelType p_type) const;

  // Returns false if value associated with p_type == UNINITIALIZED_VAL
  bool hasDim(ParallelType p_type) const;

  bool operator==(const LaunchParams& other) const;

  void print() const;

  std::string toString() const;

 private:
  // Spell them out because I want signed ints to know if they were initialized
  // or not.
  // TODO: convert to c10::optional
  int64_t gdimx_ = UNINITIALIZED_VAL;
  int64_t gdimy_ = UNINITIALIZED_VAL;
  int64_t gdimz_ = UNINITIALIZED_VAL;
  int64_t bdimx_ = UNINITIALIZED_VAL;
  int64_t bdimy_ = UNINITIALIZED_VAL;
  int64_t bdimz_ = UNINITIALIZED_VAL;

  int64_t smem_ = 0;

  // TODO: Fill in output sizes
  std::vector<std::vector<int64_t>> output_sizes;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch