#pragma once
#include <atomic>
#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>
C10_DECLARE_bool(torch_jit_enable_new_executor);
namespace torch {
namespace jit {
struct GraphExecutorState;
struct Code;
struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(
      std::shared_ptr<Graph> graph,
      std::string function_name,
      size_t remaining_bailout_depth = 0)
      : code(graph, std::move(function_name), remaining_bailout_depth),
        graph(std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};
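// A minimal construction sketch (illustrative only; assumes `g` is an
// existing std::shared_ptr<Graph> and `stack` a Stack populated with the
// graph's inputs):
//
//   ExecutionPlan plan(g, "forward");
//   if (plan) {
//     InterpreterState(plan.code).run(stack); // execute the compiled code
//   }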
// Notice that these structs don't manage the lifetime of their members.
// They are only valid right after you call getDebugState() and should never
// be used again once another GraphExecutor function is called.
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
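// A minimal inspection sketch (illustrative only; assumes `executor` is a
// live GraphExecutor). Per the lifetime note above, consume the state
// immediately and do not keep it across further executor calls:
//
//   GraphExecutorState state = executor.getDebugState();
//   size_t num_plans = state.execution_plans.size();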
struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_profiling_mode = false;
};
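// A minimal usage sketch (illustrative only; assumes `executor` and `stack`
// exist): the guard enables the profiling executor for its scope and
// restores the previous modes on exit.
//
//   {
//     EnableProfilingGuard guard;
//     executor.run(stack); // runs with executor/profiling mode enabled
//   } // previous modes restored here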
struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(std::shared_ptr<Graph> graph, std::string function_name);

  void run(Stack& inputs);
  c10::intrusive_ptr<Future> runAsync(Stack& stack);

  // `remaining_bailout_depth` is the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`.
  // If `remaining_bailout_depth` is 0, the `GraphExecutor` won't perform any
  // profiling or specialization, which is equivalent to SIMPLE_EXECUTOR mode.
  // If `remaining_bailout_depth` is greater than 0, the `GraphExecutor` will
  // profile and specialize its input graph based on the profiled information.
  // Whenever a bailout check fails or is triggered, a new `GraphExecutor` is
  // created; the new `GraphExecutor`'s `remaining_bailout_depth` is reduced
  // by 1.
  ExecutionPlan getPlanFor(Stack& inputs, size_t remaining_bailout_depth);

  explicit operator bool() const {
    return pImpl != nullptr;
  }

  void reset() {
    pImpl.reset();
  }

  std::shared_ptr<Graph> graph() const;
  GraphExecutorState getDebugState();

  static size_t getDefaultNumBailOuts();

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};
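// A minimal end-to-end sketch (illustrative only; assumes `graph` is a
// std::shared_ptr<Graph> whose inputs match the values pushed onto the
// stack):
//
//   GraphExecutor executor(graph, "forward");
//   Stack stack;
//   stack.emplace_back(at::ones({2, 2}));
//   executor.run(stack); // outputs are left on the stack
//   auto output = stack.back().toTensor();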
TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);
// These passes need to run before the graph is valid to pass to the
// interpreter, regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
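// A minimal sketch (illustrative only; assumes `g` is a std::shared_ptr<Graph>
// and `stack` holds its inputs): run the required passes before handing the
// graph to the interpreter directly.
//
//   runRequiredPasses(g);
//   Code code(g, "forward");
//   InterpreterState(code).run(stack);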
TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();
TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API std::atomic<size_t>& getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();
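// A minimal sketch of toggling the executor/profiling state (illustrative
// only). These getters return process-wide atomics, so prefer the RAII
// EnableProfilingGuard above in real code:
//
//   bool prev_executor = getExecutorMode().load();
//   bool prev_profiling = getProfilingMode().load();
//   getExecutorMode().store(true);
//   getProfilingMode().store(true);
//   // ... run graphs under the profiling executor ...
//   getExecutorMode().store(prev_executor);
//   getProfilingMode().store(prev_profiling);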
struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};
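// A minimal usage sketch (illustrative only; assumes `executor` and `stack`
// exist): temporarily disable graph optimization, e.g. when debugging an
// unoptimized graph.
//
//   {
//     GraphOptimizerEnabledGuard no_opt(/*state=*/false);
//     executor.run(stack); // runs without executor optimizations
//   } // previous optimization setting restored here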
namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

// For debugging we expose a way to get the last actually-run graph
// (lastExecutedOptimizedGraph above). Previous approaches allowed querying
// the GraphExecutor for the graph it would run in certain circumstances
// (graphFor), but this is fragile because we sometimes change how these
// decisions are made. This interface still allows our tests to look at
// optimized graphs, but with less plumbing.
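// A minimal debugging sketch (illustrative only), mirroring what tests do:
//
//   executor.run(stack);
//   std::shared_ptr<Graph> last = lastExecutedOptimizedGraph();
//   last->dump(); // print the graph that actually ran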
} // namespace detail
} // namespace jit
} // namespace torch