#include <memory>

#include <ATen/nnapi/nnapi_bind.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>

namespace torch {
namespace jit {

// Implementation of the Android NNAPI backend delegate.
// The Android Neural Networks API (NNAPI) is an Android C API designed
// for running computationally intensive machine-learning operations on
// Android devices. The API is available on all Android devices running
// Android 8.1 (API level 27) or higher.
// This implementation mirrors the Python reference in
// caffe2/torch/backends/_nnapi/prepare.py (NnapiModule.forward()).
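//
// A minimal usage sketch from the C++ side (the file name and input shape
// below are hypothetical; it assumes a module that was already lowered to
// the "nnapi" backend and saved for the lite interpreter):
//
//   auto module = torch::jit::_load_for_mobile("model_nnapi.ptl");
//   std::vector<c10::IValue> inputs{torch::ones({1, 3, 224, 224})};
//   auto output = module.forward(inputs); // routes into NnapiBackend below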
class NnapiBackend : public PyTorchBackendInterface {
 public:
  // Constructor.
  explicit NnapiBackend() = default;
  ~NnapiBackend() override = default;

  bool is_available() override {
    return true;
  }
  c10::impl::GenericDict compile(
      c10::IValue processed,
      c10::impl::GenericDict method_compile_spec) override {
    // Wrap processed in a dictionary: {"forward": processed}
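    // `processed` is the dict produced by the Python preprocess step (see
    // prepare.py); execute()/init() below read its "ser_model", "weights",
    // "shape_compute_module", "inp_mem_fmts", and "out_mem_fmts" entries.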
    auto dict = processed.toGenericDict();
    c10::Dict<c10::IValue, c10::IValue> handles(
        c10::StringType::get(), c10::AnyType::get());
    handles.insert("forward", dict);
    return c10::impl::toGenericDict(handles);
  }
  c10::impl::GenericList execute(
      c10::IValue handle,
      c10::impl::GenericList inputs) override {
    // Convert inputs to Tensors
    c10::List<at::Tensor> tensorInp;
    for (c10::IValue element : inputs) {
      tensorInp.push_back(element.toTensor());
    }

    // Lazily call init(): the NNAPI compilation needs real example inputs,
    // which are not available until the first execute() call.
    if (comp_ == nullptr) {
      init(handle, tensorInp);
    }
    TORCH_CHECK(comp_ != nullptr);

    c10::List<at::Tensor> outputs;
    for (at::Tensor out : out_templates_) {
      outputs.push_back(at::empty_like(out));
    }
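    // out_templates_ (computed in init()) carries the output shapes and
    // dtypes; comp_->run() below writes into these preallocated buffers.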

    // Adjust input memory formats
    auto dict = handle.toGenericDict();
    auto inp_mem_fmts = dict.at("inp_mem_fmts").toIntList();
    TORCH_CHECK(tensorInp.size() == inp_mem_fmts.size());
    std::vector<at::Tensor> fixed_inputs;
    for (auto i = 0U; i < tensorInp.size(); i++) {
      int fmt = inp_mem_fmts[i];
      // These constants match the values in DimOrder in serializer.py
      // 0: NCHW, 1: NHWC
      // TODO: See if it's possible to use those directly.
      if (fmt == 0) {
        fixed_inputs.push_back(tensorInp.get(i).contiguous());
      } else if (fmt == 1) {
        // Permute NCHW -> NHWC before handing the buffer to NNAPI
        fixed_inputs.push_back(
            tensorInp.get(i).permute({0, 2, 3, 1}).contiguous());
      } else {
        TORCH_CHECK(false, "Invalid mem_fmt");
      }
    }

    comp_->run(fixed_inputs, outputs.vec());

    // Adjust output memory formats
    auto out_mem_fmts = dict.at("out_mem_fmts").toIntList();
    TORCH_CHECK(outputs.size() == out_mem_fmts.size());
    for (auto i = 0U; i < outputs.size(); i++) {
      int fmt = out_mem_fmts[i];
      // These constants match the values in DimOrder in serializer.py
      // 0: NCHW, 1: NHWC
      // TODO: See if it's possible to use those directly.
      if (fmt == 1) {
        // NNAPI produced NHWC; permute back to the NCHW the caller expects
        outputs.set(i, outputs.get(i).permute({0, 3, 1, 2}));
      } else {
        TORCH_CHECK(fmt == 0, "Invalid mem_fmt");
      }
    }
    return c10::impl::toList(outputs);
  }

 private:
  // The following fields are set by init() during the first execution and
  // therefore cannot be passed through the handles dictionary.
  std::unique_ptr<torch::nnapi::bind::NnapiCompilation> comp_;
  c10::List<at::Tensor> out_templates_;

  // Runs once per model initialization. Cannot be moved to compile(),
  // because init() requires actual inputs.
  void init(c10::IValue handle, c10::List<at::Tensor> inputs) {
    TORCH_CHECK(comp_ == nullptr);
    auto dict = handle.toGenericDict();

    // Get the serialized model
    auto ser_model = dict.at("ser_model").toTensor();

    // Load the shape computation module, which is stored as a
    // string-serialized lite-interpreter module
    std::stringstream ss;
    auto shape_ptr = dict.at("shape_compute_module").toString();
    ss.str(*shape_ptr);
    auto shape_compute_module = _load_for_mobile(ss);
    out_templates_ =
        shape_compute_module.run_method("prepare", ser_model, inputs)
            .toTensorList();
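    // The "prepare" method above returns template tensors whose shapes and
    // dtypes match the model outputs for these inputs (see prepare.py,
    // which generates the shape-compute module).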

    // Create and initialize the NnapiCompilation object
    comp_ = std::make_unique<torch::nnapi::bind::NnapiCompilation>();
    auto weights = dict.at("weights").toTensorVector();
    comp_->init(ser_model, weights);
  }
};
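
// The registration below makes the backend discoverable by name, so modules
// lowered to "nnapi" (e.g., via torch._C._jit_to_backend("nnapi", ...))
// dispatch to NnapiBackend at runtime.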
namespace {
constexpr auto backend_name = "nnapi";
static auto cls = torch::jit::backend<NnapiBackend>(backend_name);
} // namespace
} // namespace jit
} // namespace torch