File: function.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (123 lines) | stat: -rw-r--r-- 3,508 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <torch/csrc/jit/mobile/function.h>

#include <utility>

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class_detail.h>

namespace torch {
namespace jit {

char const* toString(OpCode op);
namespace mobile {
// Creates a function named `name` that owns a freshly allocated, empty Code
// object; the append_* / set_* methods populate it afterwards.
// `name` is taken by value (sink parameter), so it is moved into name_ rather
// than copied a second time.
Function::Function(c10::QualifiedName name)
    : name_(std::move(name)), code_(std::make_shared<Code>()) {}

// Returns (by reference) the fully-qualified name this function was
// constructed with.
const c10::QualifiedName& Function::qualname() const {
  return this->name_;
}

const std::string& Function::name() const {
  return name_.name();
}

// Appends one bytecode instruction (opcode `op`, operands `X` and `N`) to this
// function's code.
// Throws via TORCH_CHECK when `op` is CREATE_OBJECT, or when
// isOpSupportedInMobile(op) rejects it.
void Function::append_instruction(OpCode op, int X, int N) {
  // CREATE_OBJECT is tested first so it gets its dedicated message and
  // workaround hint instead of the generic "not supported" error.
  const bool creates_object = (op == CREATE_OBJECT);
  TORCH_CHECK(
      !creates_object,
      "CREATE_OBJECT is not supported in mobile module. ",
      "Workaround: instead of using arbitrary class type (class Foo()), ",
      "define a pytorch class (class Foo(torch.nn.Module)).");

  const bool supported = isOpSupportedInMobile(op);
  TORCH_CHECK(supported, toString(op), " is not supported in mobile module.");

  auto& instructions = code_->instructions_;
  instructions.emplace_back(op, X, N);
}

bool Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    int64_t model_version) {
  // Keep the original opname in code_
  code_->op_names_.emplace_back(name, overload_name);
  auto opname = code_->op_names_.back();

  auto opname_c10 = opname;
  std::function<void(Stack&)> fn;

  auto jit_op = findOperatorFor(opname);
  if (jit_op) {
    fn = [jit_op](Stack& stack) { jit_op->getOperation()(&stack); };
  } else {
    auto op = c10::Dispatcher::singleton().findSchema(opname_c10);
    if (op.has_value()) {
      fn = [op](Stack& stack) { op->callBoxed(&stack); };
    } else {
      return false;
    }
  }

  if (model_version == 0x3L &&
      model_version < caffe2::serialize::kProducedBytecodeVersion &&
      opname == c10::OperatorName("aten::_convolution", "")) {
    // A default-value argument will be added in
    // https://github.com/pytorch/pytorch/pull/40737. This wrapper is used to
    // handle backward compatibility, where there is no default bool value in
    // old models.
    fn = [fn](Stack& stack) {
      stack.push_back(true);
      fn(stack);
    };
  }

  code_->operators_.emplace_back(fn);
  return true;
}

// Makes the per-instruction module debug info table hold exactly `size`
// entries, every one of them (including any that already existed) reset to
// the placeholder "<no module info>".
void Function::set_module_debug_info_list_size(size_t size) {
  pc_to_module_debug_info_.assign(size, "<no module info>");
}

// Stores `module_info` as the debug info for instruction index `pc`.
// The table must already be sized (set_module_debug_info_list_size); an
// out-of-range `pc` fails a TORCH_CHECK.
void Function::set_module_info(const std::string& module_info, size_t pc) {
  const bool in_range = pc < pc_to_module_debug_info_.size();
  TORCH_CHECK(in_range, "Module debug info index out of boundary.");
  std::string& slot = pc_to_module_debug_info_[pc];
  slot = module_info;
}

// Appends a copy of `constant` to this function's constant table.
void Function::append_constant(const c10::IValue& constant) {
  auto& constants = code_->constants_;
  constants.emplace_back(constant);
}

// Appends `type` to this function's type table.
void Function::append_type(const at::TypePtr& type) {
  auto& types = code_->types_;
  types.emplace_back(type);
}

void Function::set_register_size(size_t size) {
  code_->register_size_ = size;
}

// Returns (by value) the module debug info recorded for instruction index
// `pc`. An out-of-range `pc` fails a TORCH_CHECK.
std::string Function::get_module_debug_info(size_t pc) const {
  const bool in_range = pc < pc_to_module_debug_info_.size();
  TORCH_CHECK(in_range, "Module debug info index out of boundary.");
  const std::string& info = pc_to_module_debug_info_[pc];
  return info;
}

// Executes this function's code on `stack` using a fresh InterpreterState and
// forwards the boolean that InterpreterState::run reports.
bool Function::run(Stack& stack) const {
  return InterpreterState(code_).run(stack);
}

// Executes this function's code on `stack` with a fresh InterpreterState and
// returns (a copy of) the first value left on the stack afterwards.
// The boolean result of InterpreterState::run is not inspected.
// Throws via TORCH_CHECK if execution leaves the stack empty.
c10::IValue Function::operator()(Stack& stack) {
  InterpreterState interp_state(code_);
  interp_state.run(stack);
  // front() on an empty vector is undefined behavior; fail loudly instead.
  TORCH_CHECK(
      !stack.empty(),
      "Mobile function ",
      name(),
      " left no value on the stack.");
  return stack.front();
}

} // namespace mobile
} // namespace jit
} // namespace torch