#pragma once
#include <c10/util/Optional.h>
#include <memory>
#include <vector>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/source_range.h>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
#endif
C10_DECLARE_bool(torch_jit_disable_warning_prints);
C10_DECLARE_bool(torch_jit_enable_rethrow_caught_exception);
namespace at {
class Tensor;
TORCH_API void launch(std::function<void()> func);
} // namespace at
namespace c10 {
struct IValue;
struct OperatorName;
} // namespace c10
namespace torch {
namespace jit {
// The interpreter run Graphs with Tensor inputs and Tensor outputs
// a separate component in the autograd handles unwrapping and wrapping
// variable objects for use in the interpreter.
namespace interpreter {
struct CodeImpl;
}
struct Node;
struct GraphExecutor;
struct InterpreterStateImpl;
struct Graph;
struct Node;
struct Instruction;
using Stack = std::vector<c10::IValue>;
using c10::ivalue::Future;
using TaskLauncher = std::function<void(std::function<void()>)>;
// Compiled, executable form of a Graph. Wraps the interpreter's internal
// instruction stream (interpreter::CodeImpl) behind a shared_ptr so that
// copies of `Code` share one compiled program.
struct TORCH_API Code {
// A default-constructed Code holds no program; operator bool() is false.
Code() = default;
explicit Code(interpreter::CodeImpl* pImpl);
// remaining_bailout_depth is irrelevant in a `Code` object unless the `Code`
// is directly created by `GraphExecutor` in which case it's likely to contain
// `prim::BailOut`s to control the maximum depth of bailout chains
explicit Code(
const std::shared_ptr<Graph>& graph,
std::string function_name,
size_t remaining_bailout_depth = 0);
~Code();
// Executors for gradient / differentiable-subgraph operations embedded in
// this program (used by the autograd integration layer).
const std::vector<GraphExecutor*>& grad_executors();
const std::vector<GraphExecutor*>& diff_graph_op_executors();
// True iff this Code wraps a compiled program.
explicit operator bool() const {
return pImpl != nullptr;
}
size_t num_inputs() const;
size_t num_outputs() const;
size_t num_bailouts() const;
// Read-only views into the compiled program's tables.
const std::vector<c10::IValue>& constant_table() const;
const std::vector<c10::TypePtr>& type_table() const;
const std::vector<Instruction>& instructions() const;
// Operator name -> number of arguments explicitly specified at the call
// site (presumably for ops with defaulted trailing args — TODO confirm
// against the implementation).
const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
const;
// One source Node per instruction, aligned with instructions(), for error
// reporting and debugging.
const std::vector<Node*>& instructions_source() const;
// Request that the bailout at `index` be taken (see bailout comment above).
void request_bailout(size_t index);
// Number of virtual registers the program uses.
size_t register_size() const;
private:
// Shared so copies of Code alias the same compiled program.
std::shared_ptr<interpreter::CodeImpl> pImpl;
friend struct InterpreterStateImpl;
friend std::ostream& operator<<(std::ostream& out, const Code& code);
};
// Code variant used when compiling graphs for the mobile (lite) interpreter.
// The boolean flags toggle bytecode-emission behaviors; defaults select the
// current emission behavior (flag semantics live in the implementation —
// NOTE(review): confirm against MobileCode's .cpp before relying on them).
struct TORCH_API MobileCode : Code {
explicit MobileCode(
const std::shared_ptr<Graph>& graph,
std::string function_name,
bool emit_default_input_instructions = true,
bool support_default_args_before_out = true,
bool emit_promoted_ops = true,
size_t remaining_bailout_depth = 0);
~MobileCode();
};
// Executes a Code program over a Stack of IValues, either synchronously
// (run) or asynchronously (runAsync, completion signaled via a Future).
// Async work is scheduled through the supplied TaskLauncher.
struct InterpreterState {
TORCH_API InterpreterState(
const Code& code,
TaskLauncher taskLauncher = at::launch);
// Runs to completion; `stack` carries the inputs in and the outputs out
// (presumably inputs are popped and outputs pushed — TODO confirm).
TORCH_API void run(Stack& stack);
// Begins execution; the returned Future completes (or errors) when the
// program finishes.
TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
c10::intrusive_ptr<Future> getFuture();
TORCH_API ~InterpreterState();
private:
// Used by InterpreterStateImpl (a friend) to rewrap an existing impl.
InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
// Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl;
// but intrusive_ptr requires full definition of InterpreterStateImpl,
// which we need to hide in the header.
c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl;
friend struct InterpreterStateImpl;
};
// Created by wait()
// Exception used as interpreter control flow: thrown to suspend execution
// on a not-yet-completed Future, carrying that Future so the interpreter
// can resume once it is ready.
struct Suspend : public std::exception {
const char* what() const noexcept override {
return "Suspend";
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
explicit Suspend(c10::intrusive_ptr<Future> future_)
: future(std::move(future_)) {}
// The future whose completion allows execution to resume.
c10::intrusive_ptr<Future> future;
};
// InterpreterContinuation propagates dist_autograd_context_id
// through (and only through) the forward pass manually, other
// thread local settings are propagated with ThreadLocalState
struct InterpreterContinuation {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
InterpreterContinuation(
const InterpreterState& state_,
Stack stack_,
int64_t dist_autograd_context_id = 0,
c10::optional<at::ThreadLocalState> tls_state = c10::nullopt)
: state(state_),
stack(std::move(stack_)),
tls_state_(std::move(tls_state)) {
#ifdef USE_DISTRIBUTED
// Only stored when distributed support is compiled in; otherwise the
// argument is accepted and ignored.
dist_autograd_context_id_ = dist_autograd_context_id;
#endif
}
// Resumes `state` with `stack` (body defined out of line).
void operator()();
private:
InterpreterState state;
Stack stack;
// Thread-local state to restore around execution, if captured.
c10::optional<at::ThreadLocalState> tls_state_ = c10::nullopt;
#ifdef USE_DISTRIBUTED
int64_t dist_autograd_context_id_;
#endif
};
// what is the tensors type, including state from the current execution context
// that modifies how the tensor behaves. For instance if no_grad is enabled
// this will cause the TensorType to have requires_grad=False.
TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext(
const at::Tensor& t);
// current (TLS) TorchScript interpreter callstack
TORCH_API std::vector<StackEntry> currentCallstack();
// Module names for the current interpreter frames (presumably parallel to
// currentCallstack() — TODO confirm against the implementation).
TORCH_API std::vector<std::string> currentModuleHierarchy();
} // namespace jit
} // namespace torch
C10_CLANG_DIAGNOSTIC_POP()