File: torch/csrc/jit/codegen/cuda/executor_launch_params.h

#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

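// Launch configuration for a generated CUDA kernel: grid dims (gdim*), block
// dims (bdim*), and dynamic shared memory size. Every dimension starts at
// UNINITIALIZED_VAL and can be bound at most once.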
class TORCH_CUDA_API LaunchParams {
 public:
  static constexpr int64_t UNINITIALIZED_VAL = -1;

  LaunchParams(
      int64_t gdimx = UNINITIALIZED_VAL,
      int64_t gdimy = UNINITIALIZED_VAL,
      int64_t gdimz = UNINITIALIZED_VAL,
      int64_t bdimx = UNINITIALIZED_VAL,
      int64_t bdimy = UNINITIALIZED_VAL,
      int64_t bdimz = UNINITIALIZED_VAL)
      : gdimx_(gdimx),
        gdimy_(gdimy),
        gdimz_(gdimz),
        bdimx_(bdimx),
        bdimy_(bdimy),
        bdimz_(bdimz) {}

  void setSmem(int64_t smem) {
    smem_ = smem;
  }

  int64_t smem() const {
    return smem_;
  }

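  // Note: nBlocks()/nThreads() multiply the raw values, so they are only
  // meaningful once all three grid (resp. block) dims have been bound; the
  // per-dim accessors below substitute 1 for unset dims.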
  int64_t nBlocks() const {
    return gdimx_ * gdimy_ * gdimz_;
  }

  int64_t nThreads() const {
    return bdimx_ * bdimy_ * bdimz_;
  }

  int64_t bdimx() const {
    return bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_;
  }

  int64_t gdimx() const {
    return gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_;
  }

  int64_t bdimy() const {
    return bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_;
  }

  int64_t gdimy() const {
    return gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_;
  }

  int64_t bdimz() const {
    return bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_;
  }

  int64_t gdimz() const {
    return gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_;
  }

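  // Enforces write-once semantics per dimension: re-binding to the same
  // value is a no-op, while a different value trips the assert.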
  void checkAndSet(
      const int64_t incoming_val,
      int64_t& class_val,
      const std::string& val) {
    TORCH_INTERNAL_ASSERT(
        class_val == UNINITIALIZED_VAL || incoming_val == class_val,
        "Tried to set ",
        val,
        " to ",
        incoming_val,
        ", but it was already set and the new value does not match.",
        " Each dimension can only be bound to a single value.");
    TORCH_CHECK(
        incoming_val > 0,
        "Received a thread binding on ",
        val,
        " that is ",
        incoming_val,
        ". Thread counts must be positive.");
    if (class_val == UNINITIALIZED_VAL) {
      class_val = incoming_val;
    }
  }

  // Binds the dim associated with p_type to val
  void bind(int64_t val, ParallelType p_type);

  // Returns the adjusted value for the dim associated with p_type, as
  // computed by the accessors above (unset dims read as 1)
  int64_t getDim(ParallelType p_type) const;

  // Returns raw value which may be UNINITIALIZED_VAL
  const int64_t& getRawVal(ParallelType p_type) const;

  // Returns false if value associated with p_type == UNINITIALIZED_VAL
  bool hasDim(ParallelType p_type) const;

  bool operator==(const LaunchParams& other) const;

 private:
  // Spelled out as signed ints so the UNINITIALIZED_VAL sentinel can record
  // whether each dimension was ever set.
  // TODO: convert to c10::optional
  int64_t gdimx_ = UNINITIALIZED_VAL;
  int64_t gdimy_ = UNINITIALIZED_VAL;
  int64_t gdimz_ = UNINITIALIZED_VAL;
  int64_t bdimx_ = UNINITIALIZED_VAL;
  int64_t bdimy_ = UNINITIALIZED_VAL;
  int64_t bdimz_ = UNINITIALIZED_VAL;

  int64_t smem_ = 0;

  // TODO: Fill in output sizes
  std::vector<std::vector<int64_t>> output_sizes;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
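
A minimal usage sketch for the class above (illustrative only). It assumes the
ParallelType values TIDx/BIDx declared in torch/csrc/jit/codegen/cuda/type.h
and the out-of-line bind()/hasDim() definitions in the matching .cpp;
configureLaunch() is a hypothetical helper, not part of the header:

  #include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>

  using namespace torch::jit::fuser::cuda;

  void configureLaunch() {
    LaunchParams lp;                   // every dim starts at UNINITIALIZED_VAL
    lp.bind(128, ParallelType::TIDx);  // blockDim.x = 128
    lp.bind(64, ParallelType::BIDx);   // gridDim.x = 64
    lp.setSmem(48 * 1024);             // 48 KiB of dynamic shared memory

    // Unset dims read back as 1 through the adjusted accessors:
    int64_t threads = lp.bdimx() * lp.bdimy() * lp.bdimz(); // 128 * 1 * 1

    // hasDim() distinguishes bound dims from defaulted ones:
    bool has_tidy = lp.hasDim(ParallelType::TIDy); // false

    // Binding TIDx again with a different value (e.g. 256) would trip the
    // TORCH_INTERNAL_ASSERT inside checkAndSet.
    (void)threads;
    (void)has_tidy;
  }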