#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
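
// Validate the current launch configuration against CUDA hardware limits:
// the total thread count of a block may not exceed the device's
// maxThreadsPerBlock, gridDim.x is capped at 2^31 - 1, and gridDim.y/z
// at 65535.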
void LaunchParams::assertValid() {
  TORCH_INTERNAL_ASSERT(
      bdimx() * bdimy() * bdimz() > 0 &&
          bdimx() * bdimy() * bdimz() <=
              (int64_t)at::cuda::getCurrentDeviceProperties()
                  ->maxThreadsPerBlock,
      "Selected invalid number of threads for cuda: ",
      bdimx() * bdimy() * bdimz());
  TORCH_INTERNAL_ASSERT(
      gdimx() > 0 && gdimx() <= (std::int64_t(1) << 31) - 1,
      "Invalid number of blocks in x direction: ",
      gdimx());
  TORCH_INTERNAL_ASSERT(
      gdimy() > 0 && gdimy() <= 65535,
      "Invalid number of blocks in y direction: ",
      gdimy());
  TORCH_INTERNAL_ASSERT(
      gdimz() > 0 && gdimz() <= 65535,
      "Invalid number of blocks in z direction: ",
      gdimz());
}
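
// Bind a concrete value to one of the six launch dimensions. checkAndSet
// (declared in the header) is expected to reject a conflicting re-binding of
// a dimension that already holds a different value; the full configuration
// is re-validated after every bind.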
void LaunchParams::bind(int64_t val, ParallelType p_type) {
  switch (p_type) {
    case ParallelType::TIDx:
      checkAndSet(val, bdimx_, "blockDim.x");
      break;
    case ParallelType::BIDx:
      checkAndSet(val, gdimx_, "gridDim.x");
      break;
    case ParallelType::TIDy:
      checkAndSet(val, bdimy_, "blockDim.y");
      break;
    case ParallelType::BIDy:
      checkAndSet(val, gdimy_, "gridDim.y");
      break;
    case ParallelType::TIDz:
      checkAndSet(val, bdimz_, "blockDim.z");
      break;
    case ParallelType::BIDz:
      checkAndSet(val, gdimz_, "gridDim.z");
      break;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to bind invalid parallel type in launch config: ",
          p_type);
  }
  assertValid();
}
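
// Illustrative usage sketch (assumes the default constructor declared in the
// header leaves every dimension at UNINITIALIZED_VAL):
//
//   LaunchParams lp;
//   lp.bind(128, ParallelType::TIDx); // blockDim.x = 128
//   lp.bind(256, ParallelType::BIDx); // gridDim.x = 256
//   lp.getDim(ParallelType::TIDx);    // returns 128
//   lp.hasDim(ParallelType::TIDy);    // false: never bound

// Return the value bound to the given parallel type, asserting on any type
// that is not a launch dimension.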
int64_t LaunchParams::getDim(ParallelType p_type) const {
  switch (p_type) {
    case ParallelType::TIDx:
      return bdimx();
    case ParallelType::BIDx:
      return gdimx();
    case ParallelType::TIDy:
      return bdimy();
    case ParallelType::BIDy:
      return gdimy();
    case ParallelType::TIDz:
      return bdimz();
    case ParallelType::BIDz:
      return gdimz();
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to get with invalid parallel type in launch config: ",
          p_type);
  }
}
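
// True if a value has been explicitly bound to the given parallel type.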
bool LaunchParams::hasDim(ParallelType p_type) const {
  return getRawVal(p_type) != UNINITIALIZED_VAL;
}
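
// Return a reference to the underlying storage for the given parallel type,
// including the UNINITIALIZED_VAL sentinel for unbound dimensions.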
const int64_t& LaunchParams::getRawVal(ParallelType p_type) const {
  switch (p_type) {
    case ParallelType::TIDx:
      return bdimx_;
    case ParallelType::BIDx:
      return gdimx_;
    case ParallelType::TIDy:
      return bdimy_;
    case ParallelType::BIDy:
      return gdimy_;
    case ParallelType::TIDz:
      return bdimz_;
    case ParallelType::BIDz:
      return gdimz_;
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "Tried to get with invalid parallel type in launch config: ",
          p_type);
  }
}
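
// Two LaunchParams compare equal when all six launch dimensions and the
// shared-memory size match.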
bool LaunchParams::operator==(const LaunchParams& other) const {
  return gdimx_ == other.gdimx_ && gdimy_ == other.gdimy_ &&
      gdimz_ == other.gdimz_ && bdimx_ == other.bdimx_ &&
      bdimy_ == other.bdimy_ && bdimz_ == other.bdimz_ &&
      smem_ == other.smem_;
}
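
// Convenience wrapper that writes toString() to stdout.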
void LaunchParams::print() const {
  std::cout << toString();
}
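
// Render the launch configuration, printing -1 for any unbound dimension.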
std::string LaunchParams::toString() const {
  std::stringstream ss;
  ss << "Launch Parameters: "
     << "BlockDim.x = " << (bdimx_ == UNINITIALIZED_VAL ? -1 : bdimx_) << ", "
     << "BlockDim.y = " << (bdimy_ == UNINITIALIZED_VAL ? -1 : bdimy_) << ", "
     << "BlockDim.z = " << (bdimz_ == UNINITIALIZED_VAL ? -1 : bdimz_) << ", "
     << "GridDim.x = " << (gdimx_ == UNINITIALIZED_VAL ? -1 : gdimx_) << ", "
     << "GridDim.y = " << (gdimy_ == UNINITIALIZED_VAL ? -1 : gdimy_) << ", "
     << "GridDim.z = " << (gdimz_ == UNINITIALIZED_VAL ? -1 : gdimz_) << ", "
     << "Smem Size = " << smem() << "\n";
  return ss.str();
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch