#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/core/ScalarType.h>
namespace torch {
namespace jit {
namespace mobile {
namespace nnc {
// Specify the requirements on an input tensor.
// TODO: support input tensor with dynamic shape (PR #54982)
struct TORCH_API InputSpec {
  InputSpec() = default;

  // Deserialize the spec from an IValue.
  explicit InputSpec(const c10::IValue& value);

  // Serialize the spec into an IValue.
  C10_NODISCARD c10::IValue serialize() const;

  // Check whether the input tensor adheres to the spec.
  // Returns false on mismatch rather than throwing.
  C10_NODISCARD bool validate(const at::Tensor& input) const;

  // Expected sizes of the input tensor, one entry per dimension.
  std::vector<int64_t> sizes_;

  // Expected scalar type; stays Undefined until deserialized/assigned.
  c10::ScalarType dtype_{c10::ScalarType::Undefined};
};
// Specify the sizes/dtype/... of output tensor to preallocate the output.
// TODO: support the case where kernel allocates output tensors dynamically.
struct TORCH_API OutputSpec {
  OutputSpec() = default;

  // Deserialize the spec from an IValue.
  explicit OutputSpec(const c10::IValue& value);

  // Serialize the spec into an IValue.
  C10_NODISCARD c10::IValue serialize() const;

  // Allocate an output tensor in accordance with the spec.
  C10_NODISCARD at::Tensor allocate() const;

  // Sizes of the output tensor, one entry per dimension.
  std::vector<int64_t> sizes_;

  // Scalar type of the output; stays Undefined until deserialized/assigned.
  c10::ScalarType dtype_{c10::ScalarType::Undefined};

  // Quantization scale / zero point — presumably only engaged for quantized
  // outputs; confirm against allocate()'s out-of-line implementation.
  c10::optional<double> qscale_;
  c10::optional<int64_t> qzero_;
};
// Hold the temporary buffers / states needed during the execution.
struct TORCH_API ExecutionState {
  ExecutionState() = default;
  // Move-only: preallocations_ holds uniquely-owned memory (c10::DataPtr) and
  // arguments_ holds raw pointers into it, so copying would alias the buffers.
  ExecutionState(const ExecutionState&) = delete;
  ExecutionState(ExecutionState&&) = default;
  ExecutionState& operator=(const ExecutionState&) = delete;
  ExecutionState& operator=(ExecutionState&&) = default;

  // Preallocated buffers needed by the NNC kernel.
  std::vector<c10::DataPtr> preallocations_;

  // Raw argument pointers handed to the NNC kernel. The NNC kernel expects
  // the following arguments layout:
  //   input tensor 1
  //   ...
  //   input tensor INPUT_NUM
  //   output tensor 1
  //   ...
  //   output tensor OUTPUT_NUM
  //   parameter tensor 1
  //   ...
  //   parameter tensor PARAM_NUM
  //   temporary buffer 1
  //   ...
  //   temporary buffer BUFFER_NUM
  std::vector<void*> arguments_;
};
// Specify how to allocate temporary buffers at initialization.
struct TORCH_API MemoryPlan {
  MemoryPlan() = default;

  // Deserialize the plan from an IValue.
  explicit MemoryPlan(const c10::IValue& value);

  // Serialize the plan into an IValue.
  C10_NODISCARD c10::IValue serialize() const;

  // Allocate the temporary buffers described by buffer_sizes_ and install
  // them into `state` (defined out-of-line).
  void allocate(ExecutionState* state) const;

  // Sizes of the temporary buffers to preallocate — presumably in bytes;
  // confirm against allocate()'s implementation.
  std::vector<int64_t> buffer_sizes_;
};
// Location of a symbolic shape among dimensions of the inputs: the symbolic
// dimension lives at dimension `dim_idx_` of input tensor `input_idx_`.
struct TORCH_API SymbolicShapePosition {
  SymbolicShapePosition() = default;
  SymbolicShapePosition(int64_t input_idx, int64_t dim_idx)
      : input_idx_(input_idx), dim_idx_(dim_idx) {}

  // Index of the input tensor carrying the symbolic dimension.
  // Brace-initialized: the defaulted constructor previously left both fields
  // indeterminate, making any read of a default-constructed instance UB.
  int64_t input_idx_{0};
  // Index of the dimension within that input tensor.
  int64_t dim_idx_{0};
};
// Represents a compiled NNC function which has a 1-1 correspondence with a
// `Method` (e.g. `forward`). It's similar as torch::jit::mobile::Function.
class TORCH_API Function {
public:
explicit Function() = default;
// Deserialize from an IValue that is generated by the 'serialize()' method.
explicit Function(const c10::IValue& value);
// Serialize into an IValue.
c10::IValue serialize() const;
// Execute the compiled NNC function.
c10::impl::GenericList run(const c10::impl::GenericList& inputs) const;
// The name of the function as specified in the model code.
c10::QualifiedName name() const {
return name_;
}
void set_name(const c10::QualifiedName& name) {
name_ = name;
}
// The unique id of the generated NNC kernel corresponding to the function.
const std::string& nnc_kernel_id() const {
return nnc_kernel_id_;
}
void set_nnc_kernel_id(const std::string& name) {
nnc_kernel_id_ = name;
}
// The parameters (e.g. weights / bias tensors) to be passed to the generated
// NNC kernel.
const c10::impl::GenericList& parameters() const {
return parameters_;
}
void set_parameters(const c10::impl::GenericList& parameters) {
parameters_ = parameters;
}
const std::vector<InputSpec>& input_specs() const {
return input_specs_;
}
void set_input_specs(const std::vector<InputSpec>& input_specs) {
input_specs_ = input_specs;
}
const std::vector<OutputSpec>& output_specs() const {
return output_specs_;
}
void set_output_specs(const std::vector<OutputSpec>& output_specs) {
output_specs_ = output_specs;
}
const MemoryPlan& memory_plan() const {
return memory_plan_;
}
void set_memory_plan(const MemoryPlan& memory_plan) {
memory_plan_ = memory_plan;
}
const std::vector<SymbolicShapePosition>& sym_shape_positions() const {
return sym_shape_positions_;
}
void set_sym_shape_positions(
const std::vector<SymbolicShapePosition>& sym_shape_pos) {
sym_shape_positions_ = sym_shape_pos;
}
private:
void init_execution_state() const;
c10::QualifiedName name_;
std::string nnc_kernel_id_;
c10::impl::GenericList parameters_{at::AnyType::get()};
std::vector<InputSpec> input_specs_;
std::vector<OutputSpec> output_specs_;
std::vector<SymbolicShapePosition> sym_shape_positions_;
MemoryPlan memory_plan_;
mutable std::unique_ptr<ExecutionState> execution_state_;
};
// CompilationUnit consists of a set of compiled NNC functions. It has a 1-1
// correspondence with a `Module`.
// It's similar as torch::jit::mobile::CompilationUnit.
class TORCH_API CompilationUnit {
 public:
  CompilationUnit() = default;
  // Move-only: functions_ holds unique_ptr owners, so copying is disallowed.
  CompilationUnit(const CompilationUnit&) = delete;
  CompilationUnit(CompilationUnit&&) = default;
  CompilationUnit& operator=(const CompilationUnit&) = delete;
  CompilationUnit& operator=(CompilationUnit&&) = default;

  // Deserialize from an IValue that is generated by the 'serialize()' method.
  explicit CompilationUnit(const c10::IValue& value);

  // Serialize all registered functions into an IValue. The IValue will be save
  // into the compiled TorchScript model file ahead-of-time on the host, and
  // will be deserialized at runtime on the target device.
  C10_NODISCARD c10::IValue serialize() const;

  // Execute a registered function.
  // NOTE(review): behavior when `function_name` is not registered is decided
  // by the out-of-line implementation — likely a throw; confirm there.
  C10_NODISCARD c10::impl::GenericList run(
      const c10::QualifiedName& function_name,
      const c10::impl::GenericList& inputs) const;

  // Register a function to the compilation unit. Takes ownership of `fn`.
  void register_function(std::unique_ptr<Function> fn);

 private:
  // Look up a registered function; const but returns a mutable Function*.
  C10_NODISCARD Function* find_function(const c10::QualifiedName& qn) const;

  // Registered functions keyed by qualified name. Relies on a std::hash
  // specialization for c10::QualifiedName provided by c10.
  std::unordered_map<c10::QualifiedName, std::unique_ptr<Function>> functions_;
};
} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch