#pragma once

#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/script.h>

namespace torch {
namespace jit {
namespace mobile {
class MobileModelRunner {
  std::shared_ptr<torch::jit::mobile::Module> module_;

 public:
  explicit MobileModelRunner(std::string const& file_path) {
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(file_path));
  }

  MobileModelRunner(
      std::string const& file_path,
      uint64_t module_load_options) {
    // This _load_for_mobile() overload requires an extra-files map; its
    // contents are not used after loading.
    std::unordered_map<std::string, std::string> extra_files;
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(
            file_path,
            at::Device(at::DeviceType::CPU, 0),
            extra_files,
            module_load_options));
  }
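
  // Takes the stream by value; std::stringstream is move-only, so callers
  // pass it with std::move().
  MobileModelRunner(std::stringstream oss) {
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(oss, at::Device(at::DeviceType::CPU, 0)));
  }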

  /**
   * Returns true if the set of operators passed in contains at least one
   * Metal GPU operator, and false otherwise.
   */
  static bool set_has_metal_gpu_operators(
      std::set<std::string> const& op_list);

  /**
   * Fetches the set of root operators listed in the file
   * "extra/mobile_info.json" inside the .ptl archive at file_path.
   *
   * An exception is thrown if:
   *
   * 1. The file at file_path does not exist, or
   * 2. The file extra/mobile_info.json does not exist in the archive, or
   * 3. The contents of extra/mobile_info.json are not valid JSON, or
   * 4. The JSON is malformed in some way and the operator list cannot be
   *    extracted correctly.
   */
  static std::set<std::string> get_operators_from_mobile_info_json(
      std::string const& file_path);

  static std::vector<std::vector<at::IValue>> ivalue_to_bundled_inputs(
      const c10::IValue& bundled_inputs);

  static std::unordered_map<std::string, std::string>
  ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs);

  /**
   * Fetches all the bundled inputs of the loaded mobile model.
   *
   * A bundled input is itself of type std::vector<at::IValue>, and the
   * elements of this vector are the arguments that the model's "forward"
   * method accepts; i.e., each at::IValue is a single argument to
   * "forward".
   *
   * Each element of the outer vector is one bundled input. For models with
   * bundled inputs, the outermost vector has size > 0.
   */
  std::vector<std::vector<at::IValue>> get_all_bundled_inputs();

  /**
   * Fetches all the bundled inputs for all functions of the loaded mobile
   * model.
   *
   * The map is keyed by function name (e.g. "forward"); each value holds the
   * bundled inputs for that function.
   *
   * A bundled input is itself of type std::vector<at::IValue>, and the
   * elements of this vector are the arguments that the corresponding method
   * of the model accepts; i.e., each at::IValue in the entry for "forward" is
   * a single argument to the model's "forward" method.
   *
   * Each element of the outer vector in a value is one bundled input. For
   * models with bundled inputs, the outermost vector has size > 0.
   */
  std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
  get_many_functions_bundled_inputs();

  /**
   * Returns true if the model defines a
   * get_bundled_inputs_functions_and_info() method (new-style bundled inputs).
   */
  bool has_new_style_bundled_inputs() const {
    return module_->find_method("get_bundled_inputs_functions_and_info") !=
        c10::nullopt;
  }

  /**
   * For each tensor in the bundled inputs, calls the user-provided function
   * 'func'.
   */
  void for_each_tensor_in_bundled_inputs(
      std::function<void(const ::at::Tensor&)> const& func);

  /**
   * Gets the root operators called directly by this model's bytecode.
   */
  std::set<std::string> get_root_operators() {
    return torch::jit::mobile::_export_operator_list(*module_);
  }

  /**
   * Runs the model's "forward" method on each of the provided bundled
   * inputs. Returns a std::vector<at::IValue> in which each element is the
   * value returned by one call to forward().
   */
  std::vector<at::IValue> run_with_inputs(
      std::vector<std::vector<at::IValue>> const& bundled_inputs);

  /**
   * Runs the method named "function_name" on each of the provided bundled
   * inputs. Returns a std::vector<at::IValue> in which each element is the
   * value returned by one call to that method.
   */
  std::vector<at::IValue> run_with_inputs(
      const std::string& function_name,
      std::vector<std::vector<at::IValue>> const& bundled_inputs) const;

  /**
   * Attempts to run each function in the given list, skipping any that the
   * model does not define. Every function must take no arguments.
   */
  void run_argless_functions(const std::vector<std::string>& functions);
};
} // namespace mobile
} // namespace jit
} // namespace torch
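
// ---------------------------------------------------------------------------
// Illustrative sketch (hypothetical, not part of the PyTorch API): one way a
// check like set_has_metal_gpu_operators() could be written. ASSUMPTION:
// Metal GPU operators are recognizable by a namespace prefix such as
// "metal::" or "metal_prepack::"; the real prefix list may differ.
// ---------------------------------------------------------------------------
#include <array>
#include <set>
#include <string>

namespace mobile_model_runner_sketches {

// Returns true if any operator name starts with one of the assumed Metal
// namespace prefixes, and false otherwise (including for an empty set).
inline bool sketch_has_metal_gpu_operators(
    std::set<std::string> const& op_list) {
  // Assumed prefixes, for illustration only.
  static const std::array<const char*, 2> kMetalPrefixes = {
      "metal::", "metal_prepack::"};
  for (std::string const& op : op_list) {
    for (const char* prefix : kMetalPrefixes) {
      // rfind(prefix, 0) == 0 is true only when op begins with prefix.
      if (op.rfind(prefix, 0) == 0) {
        return true;
      }
    }
  }
  return false;
}

} // namespace mobile_model_runner_sketches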
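
// ---------------------------------------------------------------------------
// Illustrative sketch (hypothetical): the data layout described for
// get_all_bundled_inputs() and run_with_inputs(). Each bundled input is a
// vector of arguments, the outer vector holds one bundled input per element,
// and running the model yields one return value per bundled input.
// ASSUMPTION: `double` stands in for at::IValue and `forward` stands in for
// the model's "forward" method, so this builds without libtorch.
// ---------------------------------------------------------------------------
#include <functional>
#include <vector>

namespace mobile_model_runner_sketches {

using SketchValue = double;                    // stand-in for at::IValue
using SketchArgs = std::vector<SketchValue>;   // one bundled input
using SketchInputs = std::vector<SketchArgs>;  // all bundled inputs

inline std::vector<SketchValue> sketch_run_with_inputs(
    const std::function<SketchValue(const SketchArgs&)>& forward,
    const SketchInputs& bundled_inputs) {
  std::vector<SketchValue> results;
  results.reserve(bundled_inputs.size());
  for (const SketchArgs& args : bundled_inputs) {
    // Each inner vector supplies every argument for a single call.
    results.push_back(forward(args));
  }
  return results;
}

} // namespace mobile_model_runner_sketches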
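
// ---------------------------------------------------------------------------
// Illustrative sketch (hypothetical): the behaviour described for
// run_argless_functions(), where each named function runs only if the model
// defines it and missing names are skipped. ASSUMPTION: a std::map of
// zero-argument callables stands in for the model's method table.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <vector>

namespace mobile_model_runner_sketches {

using SketchMethodTable = std::map<std::string, std::function<void()>>;

// Runs every listed function that exists in `methods` and returns how many
// were actually run.
inline std::size_t sketch_run_argless_functions(
    const SketchMethodTable& methods,
    const std::vector<std::string>& functions) {
  std::size_t ran = 0;
  for (const std::string& name : functions) {
    auto it = methods.find(name);
    if (it != methods.end()) {  // skip names the model does not define
      it->second();             // every function takes no arguments
      ++ran;
    }
  }
  return ran;
}

} // namespace mobile_model_runner_sketches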