#pragma once
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <cstdint>
#include <cstdlib>
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
class TORCH_CUDA_CU_API LaunchParams {
public:
static constexpr int64_t UNINITIALIZED_VAL = -1;
LaunchParams(
int64_t gdimx = UNINITIALIZED_VAL,
int64_t gdimy = UNINITIALIZED_VAL,
int64_t gdimz = UNINITIALIZED_VAL,
int64_t bdimx = UNINITIALIZED_VAL,
int64_t bdimy = UNINITIALIZED_VAL,
int64_t bdimz = UNINITIALIZED_VAL)
: gdimx_(gdimx),
gdimy_(gdimy),
gdimz_(gdimz),
bdimx_(bdimx),
bdimy_(bdimy),
bdimz_(bdimz) {
assertValid();
}
void assertValid();
void setSmem(int64_t smem) {
smem_ = smem;
}
int64_t smem() const {
return smem_;
}
int64_t nBlocks() const {
return std::abs(gdimx_ * gdimy_ * gdimz_);
}
int64_t nThreads() const {
return std::abs(bdimx_ * bdimy_ * bdimz_);
}
int64_t bdimx() const {
return static_cast<int64_t>(bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_);
}
int64_t gdimx() const {
return static_cast<int64_t>(gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_);
}
int64_t bdimy() const {
return static_cast<int64_t>(bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_);
}
int64_t gdimy() const {
return static_cast<int64_t>(gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_);
}
int64_t bdimz() const {
return static_cast<int64_t>(bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_);
}
int64_t gdimz() const {
return static_cast<int64_t>(gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_);
}
void checkAndSet(
const int64_t incoming_val,
int64_t& class_val,
std::string val) {
TORCH_INTERNAL_ASSERT(
class_val == UNINITIALIZED_VAL || incoming_val == class_val,
"Tried to set ",
val,
" from ",
class_val,
" to ",
incoming_val,
", but it was already set and new value does not match.",
" Thread dims all have to be bound to the same value.");
TORCH_CHECK(
incoming_val > 0,
"Received a thread binding on ",
val,
" that is ",
incoming_val,
". Cannot create negative threads.");
if (class_val == UNINITIALIZED_VAL) {
class_val = incoming_val;
}
assertValid();
}
// Binds dim assocaited with p_type to val
void bind(int64_t val, ParallelType p_type);
// Adjusted value based on get functions above for each value
int64_t getDim(ParallelType p_type) const;
// Returns raw value which may be UNINITIALIZED_VAL
const int64_t& getRawVal(ParallelType p_type) const;
// Returns false if value associated with p_type == UNINITIALIZED_VAL
bool hasDim(ParallelType p_type) const;
bool operator==(const LaunchParams& other) const;
void print() const;
std::string toString() const;
private:
// Spell them out because I want signed ints to know if they were initialized
// or not.
// TODO: convert to c10::optional
int64_t gdimx_ = UNINITIALIZED_VAL;
int64_t gdimy_ = UNINITIALIZED_VAL;
int64_t gdimz_ = UNINITIALIZED_VAL;
int64_t bdimx_ = UNINITIALIZED_VAL;
int64_t bdimy_ = UNINITIALIZED_VAL;
int64_t bdimz_ = UNINITIALIZED_VAL;
int64_t smem_ = 0;
// TODO: Fill in output sizes
std::vector<std::vector<int64_t>> output_sizes;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch