File: MobileModelRunner.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (151 lines) | stat: -rw-r--r-- 5,156 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#pragma once

#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/script.h>

namespace torch {
namespace jit {
namespace mobile {

class MobileModelRunner {
  // Owning handle to the loaded lite-interpreter module. shared_ptr so the
  // runner is cheap to copy while all copies refer to the same loaded model.
  std::shared_ptr<torch::jit::mobile::Module> module_;

 public:
  /**
   * Loads the mobile (.ptl) model at file_path with default load options.
   *
   * Any exception thrown by torch::jit::_load_for_mobile (e.g. missing or
   * malformed file) propagates to the caller.
   */
  explicit MobileModelRunner(std::string const& file_path)
      : module_(std::make_shared<torch::jit::mobile::Module>(
            torch::jit::_load_for_mobile(file_path))) {}

  /**
   * Loads the mobile model at file_path on CPU, forwarding
   * module_load_options to the loader.
   *
   * @param file_path path to the .ptl archive on disk.
   * @param module_load_options bitmask of loader options understood by
   *        torch::jit::_load_for_mobile (e.g. operator-check flags).
   */
  MobileModelRunner(
      std::string const& file_path,
      uint64_t module_load_options) {
    // _load_for_mobile requires an extra-files map argument to reach the
    // overload that accepts load options; we don't request any extra files.
    std::unordered_map<std::string, std::string> extra_files;
    module_ = std::make_shared<torch::jit::mobile::Module>(
        torch::jit::_load_for_mobile(
            file_path,
            at::Device(at::DeviceType::CPU, 0),
            extra_files,
            module_load_options));
  }

  /**
   * Loads a mobile model from an in-memory stream on CPU.
   *
   * Taken by value since std::stringstream is move-only; callers pass a
   * temporary or std::move their stream in. explicit for consistency with
   * the file-path constructor and to forbid accidental implicit conversion
   * from a stringstream.
   */
  explicit MobileModelRunner(std::stringstream oss)
      : module_(std::make_shared<torch::jit::mobile::Module>(
            torch::jit::_load_for_mobile(
                oss, at::Device(at::DeviceType::CPU, 0)))) {}

  /**
   * Returns true if the list of operators passed in has a Metal GPU operator,
   * and false otherwise.
   *
   */
  static bool set_has_metal_gpu_operators(std::set<std::string> const& op_list);

  /**
   * Fetches the set of root operators in the file "extra/mobile_info.json"
   * within the .ptl archive at location file_path.
   *
   * An exception is thrown if:
   *
   * 1. The file at file_path does not exist, or
   * 2. The contents of extra/mobile_info.json is not a JSON, or
   * 3. The file extra/mobile_info.json does not exist, or
   * 4. The JSON is malformed in some way and the operator list can not be
   * extracted correctly.
   *
   */
  static std::set<std::string> get_operators_from_mobile_info_json(
      std::string const& file_path);

  /**
   * Unpacks a bundled-inputs IValue into per-call argument lists; each inner
   * vector is one complete argument list for a model method invocation.
   */
  static std::vector<std::vector<at::IValue>> ivalue_to_bundled_inputs(
      const c10::IValue& bundled_inputs);

  /**
   * Unpacks a bundled-inputs IValue into a string-to-string map
   * (e.g. function name to serialized input info).
   */
  static std::unordered_map<std::string, std::string>
  ivalue_to_bundled_inputs_map(const c10::IValue& bundled_inputs);

  /**
   * Fetches all the bundled inputs of the loaded mobile model.
   *
   * A bundled input itself is of type std::vector<at::IValue> and the
   * elements of this vector<> are the arguments that the "forward"
   * method of the model accepts. i.e. each of the at::IValue is a
   * single argument to the model's "forward" method.
   *
   * The outer vector holds a bundled input. For models with bundled
   * inputs, the outer most vector will have size > 0.
   */
  std::vector<std::vector<at::IValue>> get_all_bundled_inputs();

  /**
   * Fetches all the bundled inputs for all functions of the loaded mobile
   * model.
   *
   * The mapping is from 'function_names' eg 'forward' to bundled inputs for
   * that function
   *
   * A bundled input itself is of type std::vector<at::IValue> and the
   * elements of this vector<> are the arguments that the corresponding
   * method of the model accepts. i.e. each of the at::IValue in the entry
   * for forward is a single argument to the model's "forward" method.
   *
   * The outer vector of each value holds a bundled input. For models with
   * bundled inputs, the outer most vector will have size > 0.
   */
  std::unordered_map<std::string, std::vector<std::vector<at::IValue>>>
  get_many_functions_bundled_inputs();

  /**
   * Returns true if a model possesses get_bundled_inputs_functions_and_info()
   */
  bool has_new_style_bundled_inputs() const {
    return module_->find_method("get_bundled_inputs_functions_and_info") !=
        c10::nullopt;
  }

  /**
   * For each tensor in bundled inputs, call the user-provided function 'func'.
   */
  void for_each_tensor_in_bundled_inputs(
      std::function<void(const ::at::Tensor&)> const& func);

  /**
   * Get the root operators directly called by this model's Bytecode.
   */
  std::set<std::string> get_root_operators() {
    return torch::jit::mobile::_export_operator_list(*module_);
  }

  /**
   * Runs the model against all of the provided inputs using the model's
   * "forward" method. Returns an std::vector<at::IValue>, where each element
   * of the returned vector is one of the return values from calling forward().
   */
  std::vector<at::IValue> run_with_inputs(
      std::vector<std::vector<at::IValue>> const& bundled_inputs);

  /**
   * Runs the model against all of the provided inputs for all the specified
   * function. Returns an std::vector<at::IValue>, where each element
   * of the returned vector is one of the return values from calling the
   * method named "function_name" on this model.
   */
  std::vector<at::IValue> run_with_inputs(
      const std::string& function_name,
      std::vector<std::vector<at::IValue>> const& bundled_inputs) const;

  /**
   * Attempts to run all functions in the passed in list if they exist. All
   * funcs should require no args
   */
  void run_argless_functions(const std::vector<std::string>& functions);
};

} // namespace mobile
} // namespace jit
} // namespace torch