1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#ifndef C10_MOBILE
#include <ATen/autocast_mode.h>
#include <torch/csrc/jit/passes/autocast.h>
#endif
namespace torch {
namespace jit {
namespace {
// Synthesize a shape-erased FunctionSchema directly from the function's
// graph: one argument per graph input (named by its debug name when
// available, otherwise "argument_<i>") and one anonymous return per
// graph output.
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
  Graph& graph = *function.graph();

  std::vector<c10::Argument> arguments;
  const size_t input_count = function.num_inputs();
  for (const auto idx : c10::irange(input_count)) {
    const Value* input = graph.inputs().at(idx);
    // Prefer the user-visible debug name; otherwise synthesize a
    // positional placeholder name.
    std::string arg_name = input->hasDebugName()
        ? input->debugNameBase()
        : ("argument_" + c10::to_string(idx));
    arguments.emplace_back(std::move(arg_name), unshapedType(input->type()));
  }

  std::vector<c10::Argument> results;
  for (const auto idx : c10::irange(graph.outputs().size())) {
    results.emplace_back("", unshapedType(graph.outputs()[idx]->type()));
  }

  return {function.name(), "", std::move(arguments), std::move(results)};
}
// Non-throwing downcast helper shared by the const and non-const public
// entry points: yields a T* when the runtime tag says this is a graph
// function, and nullptr otherwise.
template <typename T, typename F>
T* tryToGraphFunctionImpl(F& function) noexcept {
  return function.isGraphFunction() ? static_cast<T*>(&function) : nullptr;
}
// Checked downcast helper: like tryToGraphFunctionImpl, but a failed
// downcast is an internal invariant violation rather than a recoverable
// condition.
template <typename T, typename F>
T& toGraphFunctionImpl(F& function) {
  auto* graph_function = tryToGraphFunctionImpl<T>(function);
  if (graph_function != nullptr) {
    return *graph_function;
  }
  // Reaching this point means the caller assumed a GraphFunction where
  // none exists — a bug in the calling context, not in user input.
  TORCH_INTERNAL_ASSERT(
      false,
      "Failed to downcast a Function to a GraphFunction. "
      "This probably indicates that the JIT calling context needs a "
      "special case on tryToGraphFunction() instead.");
}
} // namespace
// Sentinel creator installed while a function is being defined (see
// GraphFunction::ensure_defined). If it is invoked, the function was
// re-entered during its own definition, which is reported as a
// recursive method call.
void placeholderCreator(GraphFunction&) {
  throw RecursiveMethodCallError();
}
// Execute the function synchronously. Inputs are consumed from `stack`
// and outputs are pushed back onto it by the graph executor.
void GraphFunction::run(Stack& stack) {
  get_executor().run(stack);
}
// Asynchronous variant of run(): hands `stack` to the graph executor and
// returns a Future that completes when execution finishes. `taskLauncher`
// controls where the executor schedules its work.
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return get_executor().runAsync(stack, std::move(taskLauncher));
}
// Run the deferred creator (if any) to materialize this function's body.
// While the creator executes, function_creator_ is swapped to
// placeholderCreator so that a re-entrant attempt to define the same
// function throws RecursiveMethodCallError instead of recursing forever.
// NOTE(review): if creator(*this) throws, function_creator_ is left as
// placeholderCreator rather than restored — presumably intentional so a
// retry fails fast; confirm.
void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}
// Return this function's schema, lazily building and caching a default
// (shape-erased) one from the graph on first access.
// NOTE(review): schema_ is mutated from a const method without
// synchronization — presumably first access is serialized by callers;
// confirm before relying on concurrent use.
const c10::FunctionSchema& GraphFunction::getSchema() const {
  if (!schema_) {
    schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
  }
  return *schema_;
}
// Pick the executor specialization bucket matching the current autocast
// (AMP) thread-local state.
GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
  if (force_no_amp_) {
    // This function has opted out of AMP entirely.
    return SpecializationKey::AutocastOff;
  }
#ifdef C10_MOBILE
  // Mobile builds have no autocast APIs, so AMP specialization is
  // always off.
  return SpecializationKey::AutocastOff;
#else
  const bool cpu_amp = at::autocast::is_cpu_enabled();
  const bool gpu_amp = at::autocast::is_enabled();
  if (cpu_amp && gpu_amp) {
    return SpecializationKey::CpuGpuAutocastOn;
  }
  if (gpu_amp) {
    return SpecializationKey::GpuAutocastOn;
  }
  if (cpu_amp) {
    return SpecializationKey::CpuAutocastOn;
  }
  return SpecializationKey::AutocastOff;
#endif
}
// Standard pre-optimization pipeline run on a graph before execution:
// inline calls, clean up trivial patterns, propagate constants over
// immutable types, optionally insert autocast casts, and pool duplicate
// constants. The pass ordering below is deliberate.
void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast) {
  Inline(*graph);
  // Peephole Optimize cleans up many "is None" checks and creates constant
  // propagation opportunities.
  PeepholeOptimize(graph, true);
  // AliasDb construction can be slow, so run constant propagation just on
  // immutable types to clean up constant Ifs & other easy wins.
  ConstantPropagationImmutableTypes(graph);
#ifndef C10_MOBILE
  // Inject casts for automatic mixed precision.
  //
  // TODO: Ideally, this pass could run earlier, before inlining
  // or any other optimizations. That setup is preferable because:
  // 1. The AMP pass would be self-contained and function independently
  //    of any other optimizations.
  // 2. AMP transformations would benefit from follow-up passes' cleanup.
  //
  if (!disable_autocast) {
    Autocast(graph);
  }
#endif
  ConstantPooling(graph);
}
// Non-throwing downcast: returns nullptr when `function` is not a
// GraphFunction.
GraphFunction* tryToGraphFunction(Function& function) noexcept {
  return tryToGraphFunctionImpl<GraphFunction>(function);
}
// Checked downcast: fails with an internal assert when `function` is not
// a GraphFunction. Callers that can handle failure should use
// tryToGraphFunction() instead.
GraphFunction& toGraphFunction(Function& function) {
  return toGraphFunctionImpl<GraphFunction>(function);
}
// Const overload of the checked downcast above.
const GraphFunction& toGraphFunction(const Function& function) {
  return toGraphFunctionImpl<const GraphFunction>(function);
}
} // namespace jit
} // namespace torch
|