File: executor_launch_params.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (134 lines) | stat: -rw-r--r-- 3,889 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>

#include <ATen/cuda/CUDAContext.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// Asserts that the current block and grid extents describe a launchable CUDA
// configuration. Raises through TORCH_INTERNAL_ASSERT when:
//   - the total number of threads in a block is not positive or exceeds the
//     current device's per-block thread limit,
//   - gridDim.x is outside [1, 2^31 - 1],
//   - gridDim.y or gridDim.z is outside [1, 65535].
void LaunchParams::assertValid() {
  // Compute the block size once; it is used both in the check and the message.
  const auto threads_per_block = bdimx() * bdimy() * bdimz();
  // The limit that applies to a single launch is maxThreadsPerBlock.
  // maxThreadsPerMultiProcessor is the number of threads that can be resident
  // on one SM (possibly from several blocks) and is larger than what a single
  // block may contain, so it would accept configurations that cannot launch.
  TORCH_INTERNAL_ASSERT(
      threads_per_block > 0 &&
          threads_per_block <=
              static_cast<int64_t>(
                  at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock),
      "Selected invalid number of threads for cuda: ",
      threads_per_block);
  // The x-dimension of a grid is limited to 2^31 - 1 blocks.
  TORCH_INTERNAL_ASSERT(
      gdimx() > 0 && gdimx() <= (std::int64_t(1) << 31) - 1,
      "Invalid number of blocks in x direction: ",
      gdimx());
  // The y- and z-dimensions of a grid are limited to 65535 blocks each.
  TORCH_INTERNAL_ASSERT(
      gdimy() > 0 && gdimy() <= 65535,
      "Invalid number of blocks in y direction: ",
      gdimy());
  TORCH_INTERNAL_ASSERT(
      gdimz() > 0 && gdimz() <= 65535,
      "Invalid number of blocks in z direction: ",
      gdimz());
}

// Binds `val` as the extent of the launch dimension selected by `p_type`.
//
// The value is forwarded to checkAndSet together with the member that stores
// that dimension and a human-readable label for it (checkAndSet is defined
// outside this file section; presumably it rejects conflicting re-binds --
// confirm against the header). After the bind, the whole configuration is
// re-validated with assertValid().
//
// Only the thread (TIDx/y/z) and block (BIDx/y/z) parallel types are accepted;
// any other parallel type triggers an internal assert.
void LaunchParams::bind(int64_t val, ParallelType p_type) {
  switch (p_type) {
    case ParallelType::TIDx:
      checkAndSet(val, bdimx_, "blockDim.x");
      break;
    case ParallelType::BIDx:
      checkAndSet(val, gdimx_, "gridDim.x");
      break;
    case ParallelType::TIDy:
      checkAndSet(val, bdimy_, "blockDim.y");
      break;
    case ParallelType::BIDy:
      checkAndSet(val, gdimy_, "gridDim.y");
      break;
    case ParallelType::TIDz:
      // Label spelled like the x/y labels ("blockDim", capital D).
      checkAndSet(val, bdimz_, "blockDim.z");
      break;
    case ParallelType::BIDz:
      checkAndSet(val, gdimz_, "gridDim.z");
      break;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to bind invalid parallel type in launch config: ",
          p_type);
  }
  // Make sure the newly bound value keeps the configuration launchable.
  assertValid();
}

// Returns the extent reported by the accessor that corresponds to `p_type`
// (bdimx/y/z for TIDx/y/z, gdimx/y/z for BIDx/y/z). Any other parallel type
// triggers an internal assert.
int64_t LaunchParams::getDim(ParallelType p_type) const {
  int64_t extent = 0;
  switch (p_type) {
    // Block (thread) dimensions.
    case ParallelType::TIDx:
      extent = bdimx();
      break;
    case ParallelType::TIDy:
      extent = bdimy();
      break;
    case ParallelType::TIDz:
      extent = bdimz();
      break;
    // Grid (block) dimensions.
    case ParallelType::BIDx:
      extent = gdimx();
      break;
    case ParallelType::BIDy:
      extent = gdimy();
      break;
    case ParallelType::BIDz:
      extent = gdimz();
      break;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to get with invalid parallel type in launch config: ",
          p_type);
  }
  return extent;
}

// Reports whether the dimension selected by `p_type` holds a value other than
// the UNINITIALIZED_VAL sentinel.
bool LaunchParams::hasDim(ParallelType p_type) const {
  const int64_t& stored = getRawVal(p_type);
  const bool is_set = (stored != UNINITIALIZED_VAL);
  return is_set;
}

// Returns a reference to the stored member for the dimension selected by
// `p_type`, without going through the bdim*/gdim* accessors. Any parallel type
// other than TIDx/y/z or BIDx/y/z triggers an internal assert.
const int64_t& LaunchParams::getRawVal(ParallelType p_type) const {
  // Block (thread) dimensions.
  if (p_type == ParallelType::TIDx) {
    return bdimx_;
  }
  if (p_type == ParallelType::TIDy) {
    return bdimy_;
  }
  if (p_type == ParallelType::TIDz) {
    return bdimz_;
  }
  // Grid (block) dimensions.
  if (p_type == ParallelType::BIDx) {
    return gdimx_;
  }
  if (p_type == ParallelType::BIDy) {
    return gdimy_;
  }
  if (p_type == ParallelType::BIDz) {
    return gdimz_;
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "Tried to get with invalid parallel type in launch config: ",
      p_type);
}

// Two launch configurations are equal when all three grid extents, all three
// block extents, and the shared-memory size match. The z extents take part in
// the comparison so that configurations differing only in gridDim.z or
// blockDim.z are not treated as the same launch.
bool LaunchParams::operator==(const LaunchParams& other) const {
  return gdimx_ == other.gdimx_ && gdimy_ == other.gdimy_ &&
      gdimz_ == other.gdimz_ && bdimx_ == other.bdimx_ &&
      bdimy_ == other.bdimy_ && bdimz_ == other.bdimz_ &&
      smem_ == other.smem_;
}

// Writes the textual form produced by toString() to standard output.
void LaunchParams::print() const {
  const std::string description = toString();
  std::cout << description;
}

// Builds a one-line, newline-terminated summary of the launch configuration.
// Each block/grid extent that still holds UNINITIALIZED_VAL is printed as -1;
// the shared-memory size is taken from smem().
std::string LaunchParams::toString() const {
  // Map the sentinel to -1 for each stored extent before formatting.
  const int64_t block_x = (bdimx_ == UNINITIALIZED_VAL) ? -1 : bdimx_;
  const int64_t block_y = (bdimy_ == UNINITIALIZED_VAL) ? -1 : bdimy_;
  const int64_t block_z = (bdimz_ == UNINITIALIZED_VAL) ? -1 : bdimz_;
  const int64_t grid_x = (gdimx_ == UNINITIALIZED_VAL) ? -1 : gdimx_;
  const int64_t grid_y = (gdimy_ == UNINITIALIZED_VAL) ? -1 : gdimy_;
  const int64_t grid_z = (gdimz_ == UNINITIALIZED_VAL) ? -1 : gdimz_;

  std::stringstream out;
  out << "Launch Parameters: ";
  out << "BlockDim.x = " << block_x << ", ";
  out << "BlockDim.y = " << block_y << ", ";
  out << "BlockDim.z = " << block_z << ", ";
  out << "GridDim.x = " << grid_x << ", ";
  out << "GridDim.y = " << grid_y << ", ";
  out << "GridDim.z = " << grid_z << ", ";
  out << "Smem Size = " << smem() << "\n";
  return out.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch