#pragma once
#include <c10/util/Optional.h>
#include <cstdint>
#include <exception>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/frontend/source_range.h>
namespace at {
class Tensor;
}
namespace c10 {
struct IValue;
struct OperatorName;
} // namespace c10
namespace torch {
namespace jit {
// The interpreter runs Graphs with Tensor inputs and Tensor outputs.
// A separate component in the autograd handles unwrapping and wrapping
// variable objects for use in the interpreter.
struct Node;
struct GraphExecutor;
struct CodeImpl;
struct InterpreterStateImpl;
struct Graph;
struct Instruction;
using Stack = std::vector<c10::IValue>;
using c10::ivalue::Future;
struct TORCH_API Code {
  Code() : pImpl(nullptr) {}
  // remaining_bailout_depth is irrelevant in a `Code` object unless the `Code`
  // is directly created by `GraphExecutor` in which case it's likely to contain
  // `prim::BailOut`s to control the maximum depth of bailout chains
  explicit Code(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      size_t remaining_bailout_depth = 0);
  ~Code();
  const std::vector<GraphExecutor*>& grad_executors();
  explicit operator bool() const {
    return pImpl != nullptr;
  }
  size_t num_inputs() const;
  size_t num_outputs() const;
  size_t num_bailouts() const;
  const std::vector<c10::IValue>& constant_table() const;
  const std::vector<c10::TypePtr>& type_table() const;
  const std::vector<Instruction>& instructions() const;
  const std::vector<Node*>& instructions_source() const;
  void request_bailout(size_t index);
  size_t register_size() const;
 private:
  std::shared_ptr<CodeImpl> pImpl;
  friend struct InterpreterStateImpl;
  friend std::ostream& operator<<(std::ostream& out, const Code& code);
};
struct InterpreterState {
  TORCH_API InterpreterState(const Code& code);
  TORCH_API void run(Stack& stack);
  c10::intrusive_ptr<Future> runAsync(Stack& stack);
  c10::intrusive_ptr<Future> getFuture();
  TORCH_API ~InterpreterState();
 private:
  InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
  // Ideally we should use c10::intrusive_ptr<InterpreterStateImpl> for pImpl;
  // but intrusive_ptr requires full definition of InterpreterStateImpl,
  // which we need to hide in the header.
  c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl;
  friend struct InterpreterStateImpl;
};
// Created by wait()
struct Suspend : public std::exception {
  const char* what() const noexcept override {
    return "Suspend";
  }
  explicit Suspend(c10::intrusive_ptr<Future> future_)
      : future(std::move(future_)) {}
  c10::intrusive_ptr<Future> future;
};
// InterpreterContinuation propagates dist_autograd_context_id
// through (and only through) the forward pass manually. Other
// thread-local settings are propagated with ThreadLocalState.
struct InterpreterContinuation {
  InterpreterContinuation(
      InterpreterState state_,
      Stack stack_,
      int64_t dist_autograd_context_id = 0,
      c10::optional<at::ThreadLocalState> tls_state = c10::nullopt)
      : state(state_),
        stack(std::move(stack_)),
        tls_state_(std::move(tls_state)) {
#ifdef USE_DISTRIBUTED
    dist_autograd_context_id_ = dist_autograd_context_id;
#endif
  }
  void operator()();
 private:
  InterpreterState state;
  Stack stack;
  c10::optional<at::ThreadLocalState> tls_state_ = c10::nullopt;
#ifdef USE_DISTRIBUTED
  int64_t dist_autograd_context_id_;
#endif
};
// Returns the tensor's type, including state from the current execution
// context that modifies how the tensor behaves. For instance, if no_grad is
// enabled, the returned TensorType will have requires_grad=False.
TORCH_API at::TensorTypePtr tensorTypeInCurrentExecutionContext(
const at::Tensor& t);
// Returns the current thread-local (TLS) TorchScript interpreter callstack.
TORCH_API std::vector<StackEntry> currentCallstack();
} // namespace jit
} // namespace torch