#pragma once
#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <cstdint>
#include <memory>
#include <vector>
namespace torch {
namespace jit {
// Special device index used to mean "the CPU" (as opposed to a
// non-negative CUDA device ordinal).
constexpr int kCPUDevice = -1;
// Assigns a "key" to the given fusion_group, which the caller can later
// use to run the fusion (via runFusion() below).
TORCH_API int64_t registerFusion(const Node* fusion_group);
// Runs the fusion corresponding to the given key on the inputs
// found on the stack. Outputs are placed on the same stack.
// If the fusion cannot be run, a fallback path is attempted in which
// PyTorch's interpreter runs the graph instead.
TORCH_API void runFusion(const int64_t key, Stack& stack);
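//
// Illustrative usage sketch (not part of this header's API; assumes
// `fusion_node` is a `const Node*` pointing at a fusion group whose two
// inputs are float tensors of matching shape):
//
//   int64_t key = registerFusion(fusion_node);
//   Stack stack;
//   stack.emplace_back(at::randn({8, 8}));
//   stack.emplace_back(at::randn({8, 8}));
//   runFusion(key, stack);  // outputs are placed on the same stack
//   at::Tensor out = stack.back().toTensor();
//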
// Returns true if fusion is allowed on the respective device, false otherwise.
TORCH_API bool canFuseOnCPU();
TORCH_API bool canFuseOnGPU();
// Sets whether fusion on the CPU is allowed (disabled by default due to
// flakiness)
TORCH_API void overrideCanFuseOnCPU(bool value);
// Sets whether fusion on the CPU must use LLVM codegen rather than the
// simple IR evaluator (SimpleIREvaluator)
TORCH_API void overrideMustUseLLVMOnCPU(bool value);
// Sets whether fusion on the GPU is allowed (enabled by default)
TORCH_API void overrideCanFuseOnGPU(bool value);
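//
// Illustrative sketch: temporarily enabling CPU fusion (e.g. in a test)
// and restoring the previous setting afterwards:
//
//   const bool prev = canFuseOnCPU();
//   overrideCanFuseOnCPU(true);
//   // ... exercise code expected to reach the fuser on CPU ...
//   overrideCanFuseOnCPU(prev);
//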
// Treats the given graph as a fusion group and launches it on the
// specified device with the given inputs.
// Returns the outputs.
TORCH_API std::vector<at::Tensor> debugLaunchGraph(
Graph& graph,
at::ArrayRef<at::Tensor> inputs);
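//
// Illustrative sketch (assumes `graph` holds the body of a fusion group
// taking two same-shaped float tensors):
//
//   std::vector<at::Tensor> outputs =
//       debugLaunchGraph(graph, {at::randn({4, 4}), at::randn({4, 4})});
//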
// Treats the given graph as a fusion group and returns the generated code.
TORCH_API std::string debugGetFusedKernelCode(
Graph& graph,
at::ArrayRef<at::Tensor> inputs);
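//
// Illustrative sketch: dumping the generated kernel source for inspection
// when debugging codegen (std::cout requires <iostream>):
//
//   std::string code =
//       debugGetFusedKernelCode(graph, {at::randn({4, 4}), at::randn({4, 4})});
//   std::cout << code << "\n";
//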
// Returns the number of fused kernels that have been compiled so far.
TORCH_API size_t nCompiledKernels();
} // namespace jit
} // namespace torch