File: script_call.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (158 lines) | stat: -rw-r--r-- 5,051 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/jit/serialization/pickle.h>

namespace torch::distributed::rpc {

// Namespace under which builtin aten operators are named on the wire
// ("aten::add" is sent as "torch.ops.aten.add"); a qualified name starting with
// this prefix is decoded as a builtin operator rather than a TorchScript
// function.
const std::string ScriptCall::BUILTIN_OP_NAMESPACE_{"torch.ops.aten."};
// Prefix every builtin operator's schema name must start with.
const std::string ScriptCall::ATEN_PREFIX_{"aten::"};

// Builds a call to a builtin operator. isAsyncExecution_ is always false for
// operator calls.
ScriptCall::ScriptCall(
    std::shared_ptr<Operator> op,
    // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
    std::vector<at::IValue>&& stack)
    : op_(std::move(op)),
      // NOTE(review): `stack` is copied, not moved. The NOLINT above suggests
      // this is deliberate (a caller may keep using its vector after passing
      // it with std::move) -- confirm all callers before switching to a move.
      stack_(stack),
      isAsyncExecution_(false) {}

// Builds a call to the TorchScript function named by `qualifiedName`, recording
// whether it was requested to run with async execution.
ScriptCall::ScriptCall(
    const c10::QualifiedName& qualifiedName,
    // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
    std::vector<at::IValue>&& stack,
    const bool isAsyncExecution)
    : qualifiedName_(qualifiedName),
      // NOTE(review): `stack` is copied, not moved. The NOLINT above suggests
      // this is deliberate (a caller may keep using its vector after passing
      // it with std::move) -- confirm all callers before switching to a move.
      stack_(stack),
      isAsyncExecution_(isAsyncExecution) {}

// True when this call targets a builtin operator (op_ is populated).
bool ScriptCall::hasOp() const {
  return op_.has_value();
}

// The builtin operator this call targets. Goes through the optional's
// value(), so it throws if no operator was set; check hasOp() first.
std::shared_ptr<Operator> ScriptCall::op() const {
  const auto& maybeOp = op_;
  return maybeOp.value();
}

// True when this call targets a TorchScript function (qualifiedName_ is
// populated).
bool ScriptCall::hasQualifiedName() const {
  return qualifiedName_.has_value();
}

// Qualified name of the targeted TorchScript function. Goes through the
// optional's value(), so it throws if no name was set; check
// hasQualifiedName() first.
const c10::QualifiedName& ScriptCall::qualifiedName() const {
  const auto& maybeName = qualifiedName_;
  return maybeName.value();
}

const std::vector<at::IValue>& ScriptCall::stack() const {
  return stack_;
}

std::vector<at::IValue>& ScriptCall::stackRef() {
  return stack_;
}

// Flattens this call into `ivalues` for pickling. The call's argument stack is
// appended first, followed by two trailing elements:
//   builtin operator:     <schema string>,         "torch.ops.aten.<name>"
//   TorchScript function: <isAsyncExecution bool>, <qualified name string>
// fromIValues() reads these trailing elements back off the end. Elements are
// appended; any existing contents of `ivalues` are kept.
//
// Throws if both or neither of an operator and a qualified name are set, or if
// the operator's name is not of the form "aten::<name>".
void ScriptCall::toIValues(std::vector<at::IValue>& ivalues) const {
  // Exactly one of op_ / qualifiedName_ must be set. Checked up front so a
  // malformed call fails before `ivalues` is modified.
  TORCH_INTERNAL_ASSERT(
      hasOp() || hasQualifiedName(),
      "Either builtin operator or TorchScript function name should be set.");
  TORCH_CHECK(
      !(hasOp() && hasQualifiedName()),
      "It is builtin operator call, qualifiedName_ should not be set.");

  // Room for the stack plus the two trailing elements.
  ivalues.reserve(ivalues.size() + stack_.size() + 2);
  ivalues.insert(ivalues.end(), stack_.begin(), stack_.end());

  if (hasOp()) {
    const auto& schema = (*op_)->schema();
    // TODO: replace this with a real overload_name when FunctionSchema supports
    // that.
    ivalues.emplace_back(toString(schema));

    // Copied on purpose: the name is rewritten in place below.
    std::string opName = schema.name();
    // Require a single "::" separator, belonging to the "aten::" prefix.
    TORCH_CHECK(
        opName.find("::") == opName.rfind("::") &&
            opName.compare(0, ATEN_PREFIX_.size(), ATEN_PREFIX_) == 0,
        "Unexpected operator name ",
        opName);
    // aten::add -> torch.ops.aten.add
    opName.replace(0, ATEN_PREFIX_.size(), BUILTIN_OP_NAMESPACE_);
    ivalues.emplace_back(std::move(opName));
  } else {
    // TorchScript function call (qualifiedName_ is set; op_ is not).
    ivalues.emplace_back(isAsyncExecution());
    ivalues.emplace_back(qualifiedName_->qualifiedName());
  }
}

// Inverse of toIValues(): rebuilds a ScriptCall from a flattened IValue list.
// Trailing elements are popped off `ivalues`:
//   last element   - the qualified name (always a string);
//   second-to-last - the schema string for a builtin operator, or the
//                    isAsyncExecution flag for a TorchScript function.
// The elements that remain are passed to the new call as its argument stack.
std::unique_ptr<ScriptCall> ScriptCall::fromIValues(
    std::vector<at::IValue>& ivalues) {
  TORCH_INTERNAL_ASSERT(
      ivalues.size() > 1,
      "At least 2 IValues are required to build a ScriptCall.");

  // Copy the name out before popping: pop_back destroys the IValue that owns
  // the string.
  std::string qualifiedName = ivalues.back().toStringRef();
  ivalues.pop_back();

  // A name under the builtin operator namespace denotes a builtin operator;
  // anything else is treated as a TorchScript function name. Prefix test via
  // compare(), which only inspects the leading characters.
  const bool isBuiltinOp =
      qualifiedName.compare(
          0, BUILTIN_OP_NAMESPACE_.size(), BUILTIN_OP_NAMESPACE_) == 0;

  if (isBuiltinOp) {
    auto op = matchOperator(ivalues.back().toStringRef());
    // Remove the schema string; the rest is the argument stack.
    ivalues.pop_back();
    return std::make_unique<ScriptCall>(std::move(op), std::move(ivalues));
  }

  const bool isAsyncExecution = ivalues.back().toBool();
  // Remove the isAsyncExecution flag; the rest is the argument stack.
  ivalues.pop_back();
  return std::make_unique<ScriptCall>(
      c10::QualifiedName(qualifiedName),
      std::move(ivalues),
      isAsyncExecution);
}

// Serializes this call into a SCRIPT_CALL message. The IValues produced by
// toIValues() are wrapped in a tuple and pickled; jit::pickle also fills a
// tensor table, which is shipped in the Message alongside the payload.
c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && {
  std::vector<IValue> elements;
  toIValues(elements);
  const IValue tuple = c10::ivalue::Tuple::create(std::move(elements));

  std::vector<torch::Tensor> tensors;
  auto pickled = jit::pickle(tuple, &tensors);

  return c10::make_intrusive<Message>(
      std::move(pickled), std::move(tensors), MessageType::SCRIPT_CALL);
}

// Deserializes a SCRIPT_CALL message. The payload is unpickled using the
// current RPC agent's type resolver and the message's tensors, and the
// resulting tuple's elements are handed to fromIValues().
std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
  const auto* bytes = static_cast<const char*>(message.payload().data());
  const auto numBytes = message.payload().size();
  auto unpickled = jit::unpickle(
      bytes,
      numBytes,
      *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      message.tensors());

  // fromIValues pops elements off its argument, so give it a mutable copy.
  auto elements = unpickled.toTupleRef().elements().vec();
  return fromIValues(elements);
}

// Resolves a schema string back to a registered operator. Every operator
// registered under the schema's symbol is scanned, and the first one whose
// printed schema (toString) equals `str_schema` exactly is returned.
//
// Throws (TORCH_CHECK) if no registered operator matches.
std::shared_ptr<Operator> ScriptCall::matchOperator(
    const std::string& str_schema) {
  // TODO: This is a temporary solution. We should pass enough information to
  // allow the schema to be matched deterministically to exactly one operator.

  // Parse the schema only to extract its symbol for the candidate lookup.
  const auto schema = torch::jit::parseSchema(str_schema);
  const auto symbol = at::Symbol::fromQualString(schema.name());

  // Bind by const reference: iterating by value would copy every candidate's
  // shared_ptr (an atomic refcount increment/decrement per iteration).
  for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
    if (toString(op->schema()) == str_schema) {
      return op;
    }
  }

  TORCH_CHECK(false, "Cannot find matching operator for schema ", str_schema);
}

} // namespace torch::distributed::rpc