File: optimize_for_mobile.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (106 lines) | stat: -rw-r--r-- 3,807 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <sstream>
#include "torch/script.h"
#include "torch/csrc/jit/api/module.h"
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
#include "torch/csrc/jit/serialization/import.h"
#include "torch/csrc/jit/serialization/export.h"

C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(
    output,
    "",
    "Name of the output model to be saved.");
C10_DEFINE_string(backend, "", "The backend to be optimized");
C10_DEFINE_string(preserved_methods, "", "Methods to be preserved")

int main(int argc, char** argv) {
  c10::SetUsageMessage(
    "\nRun optimization pass for pytorch model. Example usage:\n"
    "./optimize_for_mobile"
    " --model=<model_file>"
    " [--output=<output_file_name>]"
    " [--backend=<cpu|vulkan|metal>]"
    " [--preserved_methods=<method_names>]"
  );

  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    std::cout << c10::UsageMessage() << std::endl;
    return 1;
  }

  CAFFE_ENFORCE(FLAGS_model != "", c10::UsageMessage());

  std::string output_model_name =
    FLAGS_model.substr(0, FLAGS_model.find(".")) + "_optimized.ptl";

  if (FLAGS_output != "") {
    output_model_name = FLAGS_output;
  }

  std::vector<std::string> preserved_methods;
  if(FLAGS_preserved_methods != ""){
    std::stringstream ss(FLAGS_preserved_methods);
    std::string m;
    while(std::getline(ss, m, ';')){
      if(m != ""){
        preserved_methods.emplace_back(std::move(m));
      }
    }
    std::cout<<"The following methods will be preserved:"<<std::endl;
    for(auto& str : preserved_methods){
      std::cout<<str<<std::endl;
    }
  }

  auto module = torch::jit::load(FLAGS_model);
  auto ops = torch::jit::export_opnames(module);
  std::cout << "\npt_operator_library(" << std::endl;
  std::cout << "\tname = \"old_op_library\"," << std::endl;
  std::cout << "\tops = [" << std::endl;
  for (auto const& op: ops) {
    std::cout << "\t\t\"" << op << "\"," << std::endl;
  }
  std::cout << "\t],\n)\n" << std::endl;

  torch::jit::Module optimized_module;
  if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
    optimized_module = torch::jit::optimizeForMobile(module);
  } else if (FLAGS_backend == "vulkan") {
    optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods);
  } else if (FLAGS_backend == "metal"){
    optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods);
  }else{
    CAFFE_ENFORCE(false, "Unknown backend: " + FLAGS_backend);
  }
  auto new_ops = torch::jit::export_opnames(optimized_module);
  std::cout << "\npt_operator_library(" << std::endl;
  std::cout << "\tname = \"new_op_library\"," << std::endl;
  std::cout << "\tops = [" << std::endl;
  for (auto const& op: new_ops) {
    std::cout << "\t\t\"" << op << "\"," << std::endl;
  }
  std::cout << "\t],\n)\n" << std::endl;
  optimized_module._save_for_mobile(output_model_name);
  std::cout << "The optimized model for lite interpreter was saved to " << output_model_name << std::endl;
  return 0;
}