1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
|
#pragma once
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
namespace torch {
namespace jit {
// GraphExecutorImplBase subclass that keeps three separate, lazily-populated
// (c10::optional) execution plans: a profiling plan, an optimized plan and a
// fallback plan. It also owns the fallback functions referenced from the
// optimized graph.
// NOTE(review): how getPlanFor() chooses between the plans is defined in the
// implementation file, which is not visible from this header.
struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
// Constructs the executor from `graph` and `function_name`.
// `function_name` is taken by value.
ProfilingGraphExecutorImpl(
const std::shared_ptr<Graph>& graph,
std::string function_name);
// Returns the ExecutionPlan to run for the inputs currently on `stack`.
// NOTE(review): `remaining_bailout_depth` presumably limits how many more
// times the executor may bail out and re-profile -- confirm against the
// implementation.
ExecutionPlan getPlanFor(Stack& stack, size_t remaining_bailout_depth)
override;
// Returns a GraphExecutorState describing this executor, for debugging.
GraphExecutorState getDebugState() override;
// Defaulted destructor; overrides the base class's virtual destructor.
~ProfilingGraphExecutorImpl() override = default;
private:
// Both optimization helpers take the graph by non-const reference.
// NOTE(review): judging by the names, the first runs the passes that do not
// need profiling information and the second the passes that do -- confirm.
void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph);
void runProfilingOptimizations(std::shared_ptr<Graph>& graph);
// NOTE(review): by its name, replaces fallback graphs found in block `b`
// with fallback functions (see fallback_functions_ below) -- confirm.
void replaceFallbackGraphWithFallbackFunction(Block* b);
// Profiling record exclusively owned by this executor.
std::unique_ptr<ProfilingRecord> pr_;
c10::optional<ExecutionPlan>
profiling_plan_; // plan that is run in order to profile the code
// NOTE(review): presumably the plan produced once profiling information is
// available -- confirm against the implementation.
c10::optional<ExecutionPlan> optimized_plan_;
// This plan is used if getGraphExecutorOptimize is unset.
c10::optional<ExecutionPlan> fallback_plan_;
// Fallback functions are inserted for tensorexpr fusion groups
// and by specialize_autogradzero. Whenever, at runtime, input
// tensors don't match the profiled properties, fallback functions are
// called. They are the deoptimized version of the logic in fusion groups
// and/or autograd.
// The fallback functions are owned by a GraphExecutor instance.
// They only exist in the optimized graph, which is a private property
// of the GraphExecutor and is only shared with InterpreterState.
std::vector<std::unique_ptr<Function>> fallback_functions_;
// NOTE(review): presumably records the `remaining_bailout_depth` passed to
// getPlanFor(); empty until first set -- confirm against the implementation.
c10::optional<size_t> remaining_bailout_depth_;
};
} // namespace jit
} // namespace torch
|