#include <sstream>
#include <string>
#include <ATen/core/jit_type.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/script.h>
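
// Command-line flags describing the model to compile, its example inputs,
// and where to write the compiled artifacts.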
C10_DEFINE_string(model, "", "The TorchScript model to optimize.");
C10_DEFINE_string(model_name, "", "The name of the model.");
C10_DEFINE_string(model_version, "", "The version of the model.");
C10_DEFINE_string(
    input_dims,
    "",
    "The dimensions of the input TensorCPUs, as comma-separated numbers. "
    "If multiple inputs are needed, use a semicolon to separate "
    "the dimensions of the different tensors.");
C10_DEFINE_string(
    input_types,
    "float",
    "The dtypes of the input TensorCPUs. "
    "If multiple inputs are needed, use a semicolon to separate "
    "the dtypes of the different tensors. "
    "Supported dtypes: float, int64, uint8.");
C10_DEFINE_string(
    input_memory_formats,
    "",
    "The memory formats of the input TensorCPUs. "
    "If multiple inputs are needed, use a semicolon to separate. "
    "Supported values: contiguous, channels_last.");
C10_DEFINE_string(
    dynamic_dims,
    "",
    "Comma-separated dimensions of the input tensors that can be dynamic.");
C10_DEFINE_string(method_name, "forward", "The name of the method to compile.");
C10_DEFINE_string(
    output_llvm,
    "",
    "Name of the output LLVM assembly file to be saved.");
C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");
namespace {
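// Splits `string` on `separator`; empty pieces are dropped unless
// `ignore_empty` is false.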
std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (getline(ss, item, separator)) {
    if (!ignore_empty || !item.empty()) {
      pieces.push_back(std::move(item));
    }
  }
  return pieces;
}
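
// Packs the command-line flags into the per-method compile spec dict
// ({method name -> {option -> value}}) handed to the NNC backend.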
c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  method_spec.insert("sizes", FLAGS_input_dims);
  method_spec.insert("types", FLAGS_input_types);
  method_spec.insert("memory_formats", FLAGS_input_memory_formats);
  method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
  method_spec.insert("asmfile", FLAGS_output_llvm);
  method_spec.insert("model_name", FLAGS_model_name);
  method_spec.insert("model_version", FLAGS_model_version);
  compile_spec.insert(FLAGS_method_name, method_spec);
  return compile_spec;
}
} // namespace
int main(int argc, char** argv) {
  c10::SetUsageMessage(
      "Run the NNC AOT compiler on a TorchScript model. Example usage:\n"
      "build/bin/aot_model_compiler"
      " --model=<model file>"
      " --model_name=<model name>"
      " --model_version=<model version>"
      " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
      " --input_types=<input dtypes like 'float;float'>"
      " --input_memory_formats=<input memory formats like"
      " 'channels_last;contiguous'>"
      " [--method_name=<method name>]"
      " [--output_llvm=<llvm assembly output file path>]"
      " [--output_model=<output model file path>]");
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    std::cout << c10::UsageMessage() << std::endl;
    return 1;
  }
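
  // The model path, model name/version, and input dimensions are required;
  // the per-input flags must agree on the number of inputs.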
  CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
  const auto dims_size = split(';', FLAGS_input_dims).size();
  CAFFE_ENFORCE(
      dims_size == split(';', FLAGS_input_types).size(),
      "The number of input_dims and input_types should be the same");
  const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
  CAFFE_ENFORCE(
      mem_formats_size == 0 || mem_formats_size == dims_size,
      "The number of input_memory_formats should be 0 (default contiguous) "
      "or the same as the number of input_dims");
  if (FLAGS_output_llvm.empty()) {
    FLAGS_output_llvm =
        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
  }
  std::string output_model_name = FLAGS_output_model;
  if (output_model_name.empty()) {
    output_model_name =
        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.pt";
  }
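
  // Load the TorchScript module and freeze a clone of it; freezing requires
  // the module to be in eval mode.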
  auto m = torch::jit::load(FLAGS_model);
  m.eval();
  auto frozen_m = torch::jit::freeze_module(m.clone());
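
  // Lower the frozen module through the "nnc" backend. The backend's
  // preprocess step runs the NNC AOT compiler and writes the LLVM assembly
  // to the path given in the compile spec.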
  auto compile_spec = createCompileSpec();
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", frozen_m, compile_spec, any_dict_ty);
  compiled_module._save_for_mobile(output_model_name);
  std::cout << "The compiled model was saved to " << output_model_name
            << std::endl;
  return 0;
}