1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
#pragma once
#include <c10/util/Flags.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
// Declarations of two boolean flags; their definitions (default values and
// help strings) are not in this header.
// NOTE(review): judging by the names, these presumably select how static and
// dynamic fusion attempts are scheduled — confirm against the definitions.
TORCH_DECLARE_bool(torch_jit_static_then_dynamic);
TORCH_DECLARE_bool(torch_jit_always_dynamic);
namespace torch::jit {
// Exported free function operating on `graph`, which is passed by non-const
// reference to the shared_ptr; nothing is returned.
// NOTE(review): presumably runs the pass pipeline used when graph
// optimization is turned off — confirm against the definition.
TORCH_API void runNooptPassPipeline(std::shared_ptr<Graph>& graph);
// Graph executor implementation built on GraphExecutorImplBase.
// It holds up to three optional execution plans (profiling, optimized and
// fallback) plus the fallback functions owned by this executor. All methods
// except isOptimized() are defined out of line, so the logic that chooses
// between the plans is not visible in this header.
struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
// Constructs an executor for `graph`; `function_name` is taken by value.
// NOTE(review): presumably both are forwarded to the base class — confirm.
ProfilingGraphExecutorImpl(
const std::shared_ptr<Graph>& graph,
std::string function_name);
// Returns a reference to the execution plan to use for the inputs on
// `stack`. `remaining_bailout_depth` is optional.
// NOTE(review): how an unset depth is treated is decided in the definition.
const ExecutionPlan& getPlanFor(
Stack& stack,
std::optional<size_t> remaining_bailout_depth) override;
// Returns a GraphExecutorState describing this executor, for debugging.
GraphExecutorState getDebugState() override;
~ProfilingGraphExecutorImpl() override = default;
// NOTE(review): presumably discards the cached plans so that the next call
// recompiles — confirm against the definition.
void debugFlushCompilationCache();
// True once optimized_plan_ holds a value.
bool isOptimized() const override {
return optimized_plan_.has_value();
}
private:
// Private counterpart of getPlanFor() with the same parameters.
// NOTE(review): presumably creates or returns the profiling/optimized plan —
// confirm against the definition.
const ExecutionPlan& getOptimizedPlanFor(
Stack& stack,
std::optional<size_t> remaining_bailout_depth);
// Optimization stages applied to `graph` through the passed reference.
// NOTE(review): the names suggest that the "insensitive" stage does not
// depend on profiling data while the other stages do — confirm.
void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph);
void runProfilingOptimizations(
std::shared_ptr<Graph>& graph,
size_t remaining_depth);
// NOTE(review): presumably walks block `b` and swaps fallback graphs for
// entries stored in fallback_functions_ — confirm.
void replaceFallbackGraphWithFallbackFunction(Block* b);
// Maps a remaining depth to a FusionBehavior.
// NOTE(review): presumably derived from fusion_strategy_ — confirm.
FusionBehavior getCurrentBehavior(size_t remaining_depth);
size_t getInstantiatedBailoutDepth();
void runNoGradOptimizations(
std::shared_ptr<Graph>& graph,
size_t remaining_bailout_depth);
void runFinalOptimizations(std::shared_ptr<Graph>& graph);
// NOTE(review): presumably releases the intermediate graphs kept during
// compilation; see is_graph_extra_memory_released_ below — confirm.
void clearTheGraphCompilationIntermediateGraphs();
// Profiling record owned by this executor.
std::unique_ptr<ProfilingRecord> pr_;
std::optional<ExecutionPlan>
profiling_plan_; // plan to run in order to profile the code
// Optimized plan; isOptimized() reports whether it has been set.
std::optional<ExecutionPlan> optimized_plan_;
FusionStrategy fusion_strategy_;
// this plan is used if getGraphExecutorOptimize is unset
std::optional<ExecutionPlan> fallback_plan_;
// Fallback functions are inserted for tensorexpr fusion groups
// and by specialize_autogradzero. Whenever, at runtime, input
// tensors don't match the profiled properties, fallback functions are
// called. They are the deoptimized version of the logic in fusion groups
// and/or autograd.
// The fallback functions are owned by a GraphExecutor instance.
// They only exist in the optimized graph, which is a private property
// of the GraphExecutor and is only shared with InterpreterState.
std::vector<std::unique_ptr<Function>> fallback_functions_;
// Optional bailout depth kept as executor state.
// NOTE(review): presumably set from the getPlanFor() argument — confirm.
std::optional<size_t> remaining_bailout_depth_;
// The time at which optimized_plan_ was created.
// NOTE(review): the clock and units are not visible in this header.
int32_t time_optimized_plan_created_ = 0;
// Has the extra memory used by the graph for profiling been released?
bool is_graph_extra_memory_released_ = false;
};
} // namespace torch::jit
|