1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <ATen/DynamicLibrary.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/TensorShape.h>
#include <c10/util/CallOnce.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
namespace torch::jit::fuser::cuda {
// File-local guard-mode flag, initialised to true. It is handed out by
// reference from the deprecated getCudaFusionGuardMode() accessor.
static std::atomic<bool> cuda_fusion_guard_mode(true);
// Deprecated query. The TorchScript nvfuser integration has been removed, so
// after a one-time deprecation warning this always reports "disabled".
bool isEnabled() {
  TORCH_WARN_ONCE("torch::jit::fuser::cuda::isEnabled() is deprecated");
  constexpr bool kNvfuserEnabled = false;
  return kNvfuserEnabled;
}
bool setEnabled(bool is_enabled) {
TORCH_WARN_ONCE("torch::jit::fuser::cuda::setEnabled() is deprecated");
TORCH_INTERNAL_ASSERT(
!is_enabled,
"nvfuser support in torchscript is removed and cannot be enabled!");
return false;
}
// Deprecated query for whether nvfuser could be enabled in this build. The
// TorchScript integration has been removed, so this always returns false.
// The deprecation message names this function. It previously referred to a
// non-existent `nvfuserCanBeEnabled()`, unlike every sibling warning here.
bool canBeEnabled() {
  TORCH_WARN_ONCE("torch::jit::fuser::cuda::canBeEnabled() is deprecated");
  return false;
}
// Deprecated getter. Singleton fusion is unavailable now that the TorchScript
// nvfuser integration is gone, so this always reports false.
bool getSingletonFusion() {
  TORCH_WARN_ONCE(
      "torch::jit::fuser::cuda::getSingletonFusion() is deprecated");
  constexpr bool kSingletonFusion = false;
  return kSingletonFusion;
}
bool setSingletonFusion(bool value) {
TORCH_WARN_ONCE(
"torch::jit::fuser::cuda::setSingletonFusion() is deprecated");
TORCH_INTERNAL_ASSERT(
!value,
"nvfuser support in torchscript is removed and singleton fusion cannot be enabled!");
return false;
}
// Deprecated getter. Horizontal fusion is unavailable now that the
// TorchScript nvfuser integration is gone, so this always reports false.
bool getHorizontalFusion() {
  TORCH_WARN_ONCE(
      "torch::jit::fuser::cuda::getHorizontalFusion() is deprecated");
  constexpr bool kHorizontalFusion = false;
  return kHorizontalFusion;
}
bool setHorizontalFusion(bool value) {
TORCH_WARN_ONCE(
"torch::jit::fuser::cuda::setHorizontalFusion() is deprecated");
TORCH_INTERNAL_ASSERT(
!value,
"nvfuser support in torchscript is removed and horizontal fusion cannot be enabled!");
return false;
}
// Deprecated accessor. Emits a one-time deprecation warning, then returns a
// mutable reference to the file-local `cuda_fusion_guard_mode` flag.
std::atomic<bool>& getCudaFusionGuardMode() {
  TORCH_WARN_ONCE(
      "torch::jit::fuser::cuda::getCudaFusionGuardMode() is deprecated");
  std::atomic<bool>& guard_mode = cuda_fusion_guard_mode;
  return guard_mode;
}
// Returns a pointer to the single CudaFuserInterface instance, constructed on
// first use (function-local static). The callers in this file treat each of
// its fn_* slots as optional and test it against null before calling it.
CudaFuserInterface* getFuserInterface() {
  static CudaFuserInterface instance;
  CudaFuserInterface* const handle = &instance;
  return handle;
}
// Deprecated: compiles `fusion_node` through the registered fn_compile_n hook.
// Fails with "Running the CUDA fuser requires a CUDA build." when no such hook
// has been registered.
void compileFusionGroup(Node* fusion_node) {
  TORCH_WARN_ONCE(
      "torch::jit::fuser::cuda::compileFusionGroup() is deprecated");
  auto* const iface = getFuserInterface();
  TORCH_CHECK(
      iface->fn_compile_n != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  iface->fn_compile_n(fusion_node);
}
// Deprecated: runs `fusion_node` on `stack` through the registered fn_run_n_s
// hook. Fails with "Running the CUDA fuser requires a CUDA build." when no
// such hook has been registered.
void runFusionGroup(const Node* fusion_node, Stack& stack) {
  TORCH_WARN_ONCE("torch::jit::fuser::cuda::runFusionGroup() is deprecated");
  auto* const iface = getFuserInterface();
  TORCH_CHECK(
      iface->fn_run_n_s != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  iface->fn_run_n_s(fusion_node, stack);
}
// Formerly ran the nvfuser fusion pass over `graph`. nvfuser support in
// TorchScript has been removed: isEnabled() in this file unconditionally
// returns false, so the old dispatch to fn_fuse_graph could never run.
//
// Return early without calling isEnabled(). That call only served to emit
// "torch::jit::fuser::cuda::isEnabled() is deprecated" on the user's behalf,
// even when the user never called isEnabled() themselves. `graph` is left
// untouched.
void fuseGraph(std::shared_ptr<Graph>& graph) {
  (void)graph;
}
// Deprecated: asks the registered fn_can_fuse_n hook whether `node` is
// fusible. Reports false when no hook has been registered.
bool canFuseNode(const Node* node) {
  TORCH_WARN_ONCE("torch::jit::fuser::cuda::canFuseNode() is deprecated");
  auto* const iface = getFuserInterface();
  if (iface->fn_can_fuse_n == nullptr) {
    return false;
  }
  return iface->fn_can_fuse_n(node);
}
// Deprecated: forwards `pr` to the registered fn_insert_profile_inodes hook.
// Silently does nothing when no hook has been registered.
void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr) {
  TORCH_WARN_ONCE(
      "torch::jit::fuser::cuda::InsertProfileNodesForCUDAFuser() is deprecated");
  auto* const iface = getFuserInterface();
  if (!iface->fn_insert_profile_inodes) {
    return;
  }
  iface->fn_insert_profile_inodes(pr);
}
// Deprecated: asks the registered fn_profile_n hook whether `node` should be
// profiled. Reports false when no hook has been registered.
bool profileNode(const Node* node) {
  TORCH_WARN_ONCE("torch::jit::fuser::cuda::profileNode() is deprecated");
  auto* const iface = getFuserInterface();
  if (iface->fn_profile_n == nullptr) {
    return false;
  }
  return iface->fn_profile_n(node);
}
// Deprecated: forwards (`symbol_str`, `flip`) to the registered fn_skip_n
// hook and returns its answer. Reports false when no hook has been
// registered.
bool skipNode(const std::string& symbol_str, bool flip) {
  TORCH_WARN_ONCE("torch::jit::fuser::cuda::skipNode() is deprecated");
  auto* const iface = getFuserInterface();
  if (iface->fn_skip_n == nullptr) {
    return false;
  }
  return iface->fn_skip_n(symbol_str, flip);
}
} // namespace torch::jit::fuser::cuda
|