#pragma once
#include <atomic>
#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>
C10_DECLARE_bool(torch_jit_enable_new_executor);
namespace torch {
namespace jit {
struct GraphExecutorState;
struct Code;
enum ExecutorExecutionMode {
  SIMPLE,
  PROFILING,
};
struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
      : code(graph, std::move(function_name)), graph(std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};
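// Illustrative sketch (not part of this header), assuming `graph` is a valid
// std::shared_ptr<Graph> and `stack` is a Stack holding its inputs:
//
//   ExecutionPlan plan(graph, "forward");
//   if (plan) {
//     // plan.code is the Code compiled from plan.graph
//     InterpreterState(plan.code).run(stack);
//   }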
// Notice that these structs do not manage the lifetime of their members.
// They are only valid right after you call getDebugState() and should
// never be used again once another GraphExecutor function is called.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
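// Illustrative sketch (not part of this header) of the lifetime rule above;
// `executor` is assumed to be a live GraphExecutor:
//
//   GraphExecutorState state = executor.getDebugState();
//   size_t num_plans = state.execution_plans.size(); // inspect immediately
//   // Do not keep `state` around past the next call into `executor`.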
struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_get_optimize = false;
};
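// Illustrative sketch (not part of this header): the guard is RAII-style, so
// the previous global executor settings are restored when it goes out of
// scope; `executor` and `stack` are assumed as above.
//
//   {
//     EnableProfilingGuard enable_profiling;
//     executor.run(stack); // runs under the profiling executor settings
//   }
//   // prior settings restored here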
struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);
  GraphExecutor(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      ExecutorExecutionMode executor_mode);

  void run(Stack& inputs);
  c10::intrusive_ptr<Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch);
  // `remaining_bailout_depth` is the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`. If
  // `remaining_bailout_depth` is 0, the `GraphExecutor` won't perform any
  // profiling or specialization, which is equivalent to SIMPLE executor mode.
  // If `remaining_bailout_depth` is greater than 0, the `GraphExecutor` will
  // profile and specialize its input graph based on the profiled information.
  // Whenever a bailout check fails/triggers, a new `GraphExecutor` is created,
  // and its `remaining_bailout_depth` is reduced by 1.
  // If no bailout depth is passed, the depth is initialized from the current
  // global fusion strategy settings.
  const ExecutionPlan& getPlanFor(
      Stack& inputs,
      c10::optional<size_t> remaining_bailout_depth = c10::nullopt);
  GraphExecutorState getDebugState();
  void debugFlushCompilationCache();

  bool isOptimized() const;

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};
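// Illustrative sketch (not part of this header): building a GraphExecutor
// from a parsed graph and running it. The IR string, function name, and
// tensor shapes are assumptions made up for the example; parseIR comes from
// torch/csrc/jit/ir/irparser.h.
//
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%x : Tensor):
//       %y : Tensor = aten::relu(%x)
//       return (%y))IR", graph.get());
//
//   GraphExecutor executor(graph, "forward", ExecutorExecutionMode::PROFILING);
//   Stack stack{at::randn({2, 2})};
//   executor.run(stack); // synchronous execution
//   at::Tensor out = stack.back().toTensor();
//
//   // Asynchronous variant.
//   Stack async_stack{at::randn({2, 2})};
//   auto future = executor.runAsync(async_stack);
//   future->wait();
//
//   // Passing 0 as remaining_bailout_depth forces simple-executor behavior.
//   Stack plan_stack{at::randn({2, 2})};
//   const ExecutionPlan& plan = executor.getPlanFor(plan_stack, 0);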
TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);
// These passes need to run before a graph is valid to pass to the
// interpreter, regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
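// Illustrative sketch (not part of this header), assuming `graph` has just
// been constructed or mutated; the function name is made up for the example:
//
//   runRequiredPasses(graph);
//   Code code(graph, "forward"); // now safe to hand to the interpreter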
TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();
TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();
TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();
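// Illustrative sketch (not part of this header): the knobs above are global
// and are typically only toggled in tests or debugging sessions; `executor`
// and `stack` are assumed as above.
//
//   getExecutorMode() = true;   // use the profiling executor
//   getProfilingMode() = true;  // enable profile collection
//   getNumProfiledRuns() = 1;   // profile a single run before optimizing
//
//   executor.run(stack);
//   auto last = lastExecutedOptimizedGraph();
//   if (last) {
//     last->dump(); // print the graph that actually ran
//   }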
struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};
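// Illustrative sketch (not part of this header): temporarily disable graph
// optimizations for one run; `executor` and `stack` are assumed as above.
//
//   {
//     GraphOptimizerEnabledGuard no_opt(false);
//     executor.run(stack); // runs without graph-executor optimizations
//   }
//   // previous optimize setting restored here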
namespace detail {
GraphExecutor* getGradExecutor(Operation& op);
GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);
// For debugging, we expose a way to get the graph that was actually run last.
// Previous approaches allowed querying the GraphExecutor for the graph it
// would run under certain circumstances (graphFor), but this is fragile
// because we sometimes change how these decisions are made. This interface
// still allows our tests to look at optimized graphs, but with less plumbing.
} // namespace detail
} // namespace jit
} // namespace torch