#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <ATen/ATen.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/codegen.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/codegen/fuser/tensor_desc.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <atomic>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
namespace {
std::mutex& fusionBackendLock() {
  static std::mutex fusion_backends_lock_{};
  return fusion_backends_lock_;
}
} // namespace
namespace torch {
namespace jit {
namespace fuser {
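// Registry of fused-kernel constructors, keyed by device type
// (populated via registerFusionBackend).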
static std::unordered_map<at::Device::Type, FusedKernelConstructor>&
getFusionBackends() {
  static std::unordered_map<at::Device::Type, FusedKernelConstructor>
      fusion_backends;
  return fusion_backends;
}
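// Registers the constructor used to build fused kernels for the given
// device type. Guarded by the fusion backend lock.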
void registerFusionBackend(
    at::Device::Type backend_type,
    FusedKernelConstructor ctor) {
  std::lock_guard<std::mutex> guard(fusionBackendLock());
  getFusionBackends()[backend_type] = std::move(ctor);
}
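// Returns true if a fusion backend is registered for the given device type.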
bool hasFusionBackend(at::Device::Type backend_type) {
  std::lock_guard<std::mutex> guard(fusionBackendLock());
  return getFusionBackends().count(backend_type);
}
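// Returns the registered constructor for the given device type
// (throws if no backend was registered).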
const FusedKernelConstructor& getConstructor(at::Device::Type backend_type) {
  std::lock_guard<std::mutex> guard(fusionBackendLock());
  return getFusionBackends().at(backend_type);
}
// Counter for number of kernels compiled, used for debugging and
// creating arbitrary kernel names.
static std::atomic<size_t> next_kernel_id{0};
static int debug_fusion{-1};
size_t nCompiledKernels() {
  return next_kernel_id.load();
}
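// Reads and caches the PYTORCH_FUSION_DEBUG environment variable
// (0 if unset); callers treat a nonzero value as enabling debug output.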
int debugFuser() {
  if (debug_fusion < 0) {
    const char* debug_env = getenv("PYTORCH_FUSION_DEBUG");
    debug_fusion = debug_env ? atoi(debug_env) : 0;
  }
  return debug_fusion;
}
// If the given value is used exactly once, by a prim::ConstantChunk node,
// returns that node. Returns nullptr otherwise.
static const Node* usedInFusedChunk(const Value* input) {
  const auto& uses = input->uses();
  if (uses.size() == 1) {
    const Node* user = uses[0].user;
    if (user->kind() == prim::ConstantChunk) {
      return user;
    }
  }
  return nullptr;
}
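// Records a chunk descriptor for each tensor input: inputs consumed by a
// prim::ConstantChunk node get that node's (chunks, dim); all others get
// the trivial descriptor (1, 0).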
static void setInputChunkDescriptors(KernelSpec& spec) {
  // There are only as many chunk descriptors as tensor inputs;
  // furthermore, the tensor inputs are known to come first
  // among the fusion group's inputs.
  spec.inputChunks().reserve(spec.nTensorInputs());
  for (const auto i : c10::irange(spec.nTensorInputs())) {
    const Value* input = spec.graph()->inputs()[i];
    if (const Node* chunk = usedInFusedChunk(input)) {
      spec.inputChunks().emplace_back(
          chunk->i(attr::chunks), chunk->i(attr::dim));
    } else {
      spec.inputChunks().emplace_back(1, 0);
    }
  }
}
// Run a DFS traversal to find all inputs that affect a given output value
static std::vector<int64_t> getInputDependencies(const Value* output) {
  std::vector<const Value*> queue{output};
  std::unordered_set<const Value*> inputs;
  std::unordered_set<const Value*> seen;
  while (!queue.empty()) {
    const Value* val = queue.back();
    queue.pop_back();
    const Node* producer = val->node();
    // Here we assume that only tensor inputs are used in
    // the computation of the outputs.
    // This is currently true: besides Tensors, the only inputs are
    // sizes (for _grad_sum_to_size as the derivative of broadcasts),
    // and those are only used after the fusion kernel.
    // This needs to be revisited once other inputs,
    // e.g. non-constant scalars, are allowed.
    if (producer->kind() == prim::Param &&
        val->type()->isSubtypeOf(*TensorType::get())) {
      inputs.insert(val);
      continue;
    }
    for (const Value* input : producer->inputs()) {
      if (/*bool inserted = */ seen.insert(input).second) {
        queue.push_back(input);
      }
    }
  }
  // Convert Value* into offsets into the graph's input list
  std::vector<int64_t> offsets;
  offsets.reserve(inputs.size());
  for (const Value* input : inputs) {
    offsets.push_back(input->offset());
  }
  std::sort(offsets.begin(), offsets.end());
  return offsets;
}
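// Computes the spec's input broadcast groups: each output (or each input of
// a prim::FusedConcat output) contributes the set of tensor inputs it depends
// on, and duplicate groups are merged via the unordered_set.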
static void setInputBroadcastGroups(KernelSpec& spec) {
  std::unordered_set<std::vector<int64_t>, c10::hash<std::vector<int64_t>>>
      broadcast_groups;
  for (const Value* output : (spec.graph())->outputs()) {
    if (output->node()->kind() == prim::FusedConcat) {
      for (const Value* concat_input : output->node()->inputs()) {
        broadcast_groups.insert(getInputDependencies(concat_input));
      }
    } else {
      broadcast_groups.insert(getInputDependencies(output));
    }
  }
  std::copy(
      broadcast_groups.begin(),
      broadcast_groups.end(),
      std::back_inserter(spec.inputBroadcastGroups()));
}
// Performs "upfront" compilation where storage is known but shapes are not.
// Currently identifies how to expand all tensors so that all intermediate
// tensors are the same shape, simplifying code generation.
// Broadcast groups and chunks are identified without shape information
// using logical properties of how each works. In particular, tensors
// are always expandable to the outputs of pointwise operations they
// or their descendants are involved in, which means that in a DAG of
// pointwise operations all tensors are expandable to the (single) output.
// Note: The logic is slightly complicated by concatenation and chunking.
static void upfrontCompilation(KernelSpec& spec) {
  setInputBroadcastGroups(spec);
  setInputChunkDescriptors(spec);
}
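// Normalizes the fusion group's subgraph, registers it in the kernel cache
// (reusing an existing spec when possible), runs upfront compilation on
// newly created specs, and returns the cache key.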
int64_t registerFusion(const Node* fusion_group) {
  auto graph = normalizeGraphForCache(fusion_group->g(attr::Subgraph));
  // Don't re-register the fusion if we can use a pre-existing one
  const auto maybe_spec = lookupGraph(graph);
  if (maybe_spec) {
    return (*maybe_spec)->key();
  }
  // Unconditionally create and register the fusion
  // This is necessary to support our global disable fusions flag: if someone
  // runs some code under no-fusions mode and then runs some code with fusions
  // enabled, the second time around the returned spec from the cache should
  // be a valid spec (must have had upfrontCompilation run on it).
  const auto key = store(graph);
  const auto maybe_retrieved_spec = retrieve(key);
  AT_ASSERT(maybe_retrieved_spec);
  upfrontCompilation(**maybe_retrieved_spec);
  return key;
}
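// Compiles a fused kernel specialized to the argument descriptions in
// arg_spec and the given map size, dispatching to the registered CPU or
// CUDA backend constructor.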
std::shared_ptr<FusedKernel> compileKernel(
    const KernelSpec& spec,
    const ArgSpec& arg_spec,
    const std::vector<int64_t>& map_size,
    const at::Device device) {
  const std::vector<TensorDesc>& input_desc = arg_spec.descs();
  auto graph = spec.graph()->copy();
  for (const auto i : c10::irange(input_desc.size())) {
    const auto& desc = input_desc[i];
    // TODO: can't get rid of this use of TensorType
    // until we switch to ProfilingGraphExecutor, so we don't have to
    // run PropagateInputShapes below
    graph->inputs()[i]->setType(TensorType::create(
        desc.scalar_type,
        device,
        {desc.nDim()},
        false)); // TODO: nDim is bad, as it is collapsed
  }
  PropagateInputShapes(graph);
  // Creates chunk and flattened input descriptions
  std::vector<PartitionDesc> chunk_desc;
  std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>
      flat_inputs;
  {
    size_t input_index = 0;
    for (const auto& p : graph->inputs()) {
      if (p->type()->isSubtypeOf(*FloatType::get())) {
        flat_inputs.emplace_back(p, c10::nullopt);
      }
      if (!p->type()->isSubtypeOf(*TensorType::get())) {
        continue;
      }
      if (const Node* chunk = usedInFusedChunk(p)) {
        int64_t dim = chunk->i(attr::dim);
        int64_t chunks = chunk->i(attr::chunks);
        chunk_desc.emplace_back(input_desc[input_index++], chunks, dim);
        for (const auto* o : chunk->outputs()) {
          flat_inputs.emplace_back(o, *chunk_desc.back().subTensorDesc());
        }
      } else {
        chunk_desc.emplace_back();
        flat_inputs.emplace_back(p, input_desc[input_index++]);
      }
    }
  }
  // Creates output, concat, and flattened output descriptions
  std::vector<TensorDesc> output_desc;
  std::vector<PartitionDesc> concat_desc;
  std::vector<std::pair<const Value*, const TensorDesc>> flat_outputs;
  for (const Value* o : graph->outputs()) {
    // Creates output description
    std::vector<int64_t> sizes = map_size;
    if (o->node()->kind() == prim::FusedConcat) {
      sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
    }
    auto scalar_type = o->type()->expectRef<TensorType>().scalarType();
    TORCH_INTERNAL_ASSERT(scalar_type);
    auto type = TensorType::createContiguous(*scalar_type, device, sizes);
    output_desc.emplace_back(type);
    const auto& desc = output_desc.back();
    // Creates concat and flattened output descriptions (relies on output desc)
    if (o->node()->kind() != prim::FusedConcat) {
      concat_desc.emplace_back();
      flat_outputs.emplace_back(o, desc);
    } else {
      const auto cat = o->node();
      concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
      for (const auto& c : cat->inputs()) {
        flat_outputs.emplace_back(c, *concat_desc.back().subTensorDesc());
      }
    }
  }
  const bool use_cuda = device.is_cuda();
  const std::string name = "kernel_" + c10::to_string(next_kernel_id++);
  std::string code =
      generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
  const FusedKernelConstructor& kernel_ctor =
      getConstructor(use_cuda ? DeviceType::CUDA : DeviceType::CPU);
  return kernel_ctor(
      device.index(),
      name,
      code,
      input_desc,
      output_desc,
      chunk_desc,
      concat_desc,
      spec.hasRandom());
}
} // namespace fuser
} // namespace jit
} // namespace torch