1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
#pragma once
#include <torch/csrc/jit/codegen/cuda/type.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
class TORCH_CUDA_API LaunchParams {
public:
static constexpr int64_t UNINITIALIZED_VAL = -1;
LaunchParams(
int64_t gdimx = UNINITIALIZED_VAL,
int64_t gdimy = UNINITIALIZED_VAL,
int64_t gdimz = UNINITIALIZED_VAL,
int64_t bdimx = UNINITIALIZED_VAL,
int64_t bdimy = UNINITIALIZED_VAL,
int64_t bdimz = UNINITIALIZED_VAL)
: gdimx_(gdimx),
gdimy_(gdimy),
gdimz_(gdimz),
bdimx_(bdimx),
bdimy_(bdimy),
bdimz_(bdimz) {}
void setSmem(int64_t smem) {
smem_ = smem;
}
int64_t smem() const {
return smem_;
}
int64_t nBlocks() const {
return gdimx_ * gdimy_ * gdimz_;
}
int64_t nThreads() const {
return bdimx_ * bdimy_ * bdimz_;
}
int64_t bdimx() const {
return static_cast<int64_t>(bdimx_ == UNINITIALIZED_VAL ? 1 : bdimx_);
}
int64_t gdimx() const {
return static_cast<int64_t>(gdimx_ == UNINITIALIZED_VAL ? 1 : gdimx_);
}
int64_t bdimy() const {
return static_cast<int64_t>(bdimy_ == UNINITIALIZED_VAL ? 1 : bdimy_);
}
int64_t gdimy() const {
return static_cast<int64_t>(gdimy_ == UNINITIALIZED_VAL ? 1 : gdimy_);
}
int64_t bdimz() const {
return static_cast<int64_t>(bdimz_ == UNINITIALIZED_VAL ? 1 : bdimz_);
}
int64_t gdimz() const {
return static_cast<int64_t>(gdimz_ == UNINITIALIZED_VAL ? 1 : gdimz_);
}
void checkAndSet(
const int64_t incoming_val,
int64_t& class_val,
std::string val) {
TORCH_INTERNAL_ASSERT(
class_val == UNINITIALIZED_VAL || incoming_val == class_val,
"Tried to set ",
val,
" to ",
incoming_val,
", but it was already set and new value does not match.",
" Thread dims all have to be bound to the same value.");
TORCH_CHECK(
incoming_val > 0,
"Received a thread binding on ",
val,
" that is ",
incoming_val,
". Cannot create negative threads.");
if (class_val == UNINITIALIZED_VAL) {
class_val = incoming_val;
}
}
// Binds dim assocaited with p_type to val
void bind(int64_t val, ParallelType p_type);
// Adjusted value based on get functions above for each value
int64_t getDim(ParallelType p_type) const;
// Returns raw value which may be UNINITIALIZED_VAL
const int64_t& getRawVal(ParallelType p_type) const;
// Returns false if value associated with p_type == UNINITIALIZED_VAL
bool hasDim(ParallelType p_type) const;
bool operator==(const LaunchParams& other) const;
private:
// Spell them out because I want signed ints to know if they were initialized
// or not.
// TODO: convert to c10::optional
int64_t gdimx_ = UNINITIALIZED_VAL;
int64_t gdimy_ = UNINITIALIZED_VAL;
int64_t gdimz_ = UNINITIALIZED_VAL;
int64_t bdimx_ = UNINITIALIZED_VAL;
int64_t bdimy_ = UNINITIALIZED_VAL;
int64_t bdimz_ = UNINITIALIZED_VAL;
int64_t smem_ = 0;
// TODO: Fill in output sizes
std::vector<std::vector<int64_t>> output_sizes;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|