1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
namespace distributed {
namespace rpc {
// Prefix that marks a serialized name as a builtin operator: fromIValues()
// tests the trailing name for it, and toIValues() substitutes it for
// ATEN_PREFIX_ (e.g. "aten::add" -> "torch.ops.aten.add").
const std::string ScriptCall::BUILTIN_OP_NAMESPACE_("torch.ops.aten.");
// Prefix that toIValues() requires at the start of a builtin operator's
// schema name before rewriting it.
const std::string ScriptCall::ATEN_PREFIX_("aten::");
// Builds a call that targets a builtin operator. `op` is moved into op_ and
// isAsyncExecution_ is fixed to false.
// NOTE(review): `stack` is an rvalue reference, but it is named in the
// initializer below and therefore treated as an lvalue, so stack_ is
// copy-constructed and the caller's vector is left untouched. Confirm whether
// any caller depends on that before switching to std::move(stack).
ScriptCall::ScriptCall(
std::shared_ptr<Operator> op,
std::vector<at::IValue>&& stack)
: op_(std::move(op)), stack_(stack), isAsyncExecution_(false) {}
// Builds a call that targets a function identified by `qualifiedName`, with
// the given arguments and async-execution flag.
// NOTE(review): `stack` is an rvalue reference, but it is named in the
// initializer below and therefore treated as an lvalue, so stack_ is
// copy-constructed and the caller's vector is left untouched. Confirm whether
// any caller depends on that before switching to std::move(stack).
ScriptCall::ScriptCall(
const c10::QualifiedName& qualifiedName,
std::vector<at::IValue>&& stack,
const bool isAsyncExecution)
: qualifiedName_(qualifiedName),
stack_(stack),
isAsyncExecution_(isAsyncExecution) {}
// True when this call carries a builtin operator (op_ is set).
bool ScriptCall::hasOp() const {
  return static_cast<bool>(op_);
}
// Returns the builtin operator held by this call.
// Dereferences op_ without checking it; call hasOp() first.
std::shared_ptr<Operator> ScriptCall::op() const {
  const std::shared_ptr<Operator>& builtinOp = *op_;
  return builtinOp;
}
// True when this call carries a qualified function name (qualifiedName_ is
// set).
bool ScriptCall::hasQualifiedName() const {
  return static_cast<bool>(qualifiedName_);
}
// Returns the qualified function name held by this call.
// Dereferences qualifiedName_ without checking it; call hasQualifiedName()
// first.
const c10::QualifiedName& ScriptCall::qualifiedName() const {
  const c10::QualifiedName& functionName = *qualifiedName_;
  return functionName;
}
const std::vector<at::IValue>& ScriptCall::stack() const {
return stack_;
}
std::vector<at::IValue>& ScriptCall::stackRef() {
return stack_;
}
// Appends this call's serialized form to `ivalues`.
//
// Layout of the appended values:
//   builtin operator : <args...>, <schema string>, "torch.ops.aten.<name>"
//   named function   : <args...>, <isAsyncExecution>, <qualified name>
//
// Fails a TORCH_CHECK when the operator name is not a single-level "aten::"
// name, and an internal assert when neither an operator nor a qualified name
// is set. The arguments have already been appended by the time either fires.
void ScriptCall::toIValues(std::vector<at::IValue>& ivalues) const {
  // The arguments always come first, in stack order.
  ivalues.insert(ivalues.end(), stack_.begin(), stack_.end());

  if (hasOp()) {
    TORCH_CHECK(
        !hasQualifiedName(),
        "It is builtin operator call, qualifiedName_ should not be set.");
    // TODO: replace this with a real overload_name when FunctionSchema supports
    // that.
    const auto& opSchema = (*op_)->schema();
    ivalues.emplace_back(toString(opSchema));

    // The trailing element is the operator name in its "torch.ops.aten." form.
    auto builtinName = opSchema.name();
    const bool hasSingleNamespace =
        builtinName.find("::") == builtinName.rfind("::");
    TORCH_CHECK(
        hasSingleNamespace && builtinName.rfind(ATEN_PREFIX_) == 0,
        "Unexpected operator name ",
        builtinName);
    // aten::add -> torch.ops.aten.add
    builtinName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_);
    ivalues.emplace_back(std::move(builtinName));
    return;
  }

  if (!hasQualifiedName()) {
    TORCH_INTERNAL_ASSERT(
        false,
        "Either builtin operator or TorchScript function name should be set.");
  }

  // Named function call: async flag first, then the qualified name.
  ivalues.emplace_back(isAsyncExecution());
  TORCH_CHECK(
      !hasOp(),
      "It is TorchScript function call, operator should not be set.");
  ivalues.emplace_back((*qualifiedName_).qualifiedName());
}
// Rebuilds a ScriptCall from the flat value list produced by toIValues().
//
// The last element of `ivalues` is always a name string, for both builtin
// operators and TorchScript functions. If that name starts with
// BUILTIN_OP_NAMESPACE_ the element before it is the operator's schema string;
// otherwise the name is treated as a TorchScript function name and the element
// before it is the isAsyncExecution flag. Both trailing elements are popped
// from `ivalues`, and what remains is handed to the ScriptCall constructor as
// the argument stack.
//
// Fails a TORCH_CHECK when fewer than two values are present, or (via
// matchOperator) when no registered operator matches the schema string.
std::unique_ptr<ScriptCall> ScriptCall::fromIValues(
    std::vector<at::IValue>& ivalues) {
  TORCH_CHECK(
      ivalues.size() >= 2,
      "Malformed ScriptCall: expected at least 2 serialized values, but got ",
      ivalues.size());

  // Take a copy of the name: pop_back() destroys the IValue that owns the
  // string, so a reference obtained from toStringRef() may dangle once the
  // element has been removed.
  const std::string qualifiedName = ivalues.back().toStringRef();
  ivalues.pop_back();

  if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
    // Builtin operator: resolve the schema string while its owning element is
    // still alive, then drop it from the list.
    auto op = matchOperator(ivalues.back().toStringRef());
    ivalues.pop_back();
    return std::make_unique<ScriptCall>(std::move(op), std::move(ivalues));
  }

  // TorchScript function: the remaining trailing element is the async flag.
  const bool isAsyncExecution = ivalues.back().toBool();
  ivalues.pop_back();
  return std::make_unique<ScriptCall>(
      c10::QualifiedName(qualifiedName),
      std::move(ivalues),
      isAsyncExecution);
}
// Serializes this call into a SCRIPT_CALL message: the values produced by
// toIValues() are wrapped in a tuple and pickled, and the tensors collected
// by the pickler are passed to the message alongside the payload.
c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && {
  std::vector<IValue> flattened;
  toIValues(flattened);

  std::vector<torch::Tensor> tensors;
  auto packed = c10::ivalue::Tuple::create(std::move(flattened));
  auto bytes = jit::pickle(std::move(packed), &tensors);

  return c10::make_intrusive<Message>(
      std::move(bytes), std::move(tensors), MessageType::SCRIPT_CALL);
}
// Deserializes a ScriptCall from `message`: unpickles the payload using the
// current RPC agent's type resolver and the message's tensors, then hands the
// resulting tuple elements to fromIValues().
std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
  const auto& rawPayload = message.payload();
  const auto* bytes = static_cast<const char*>(rawPayload.data());
  const auto byteCount = rawPayload.size();

  const auto resolver = RpcAgent::getCurrentRpcAgent()->getTypeResolver();
  auto unpickled =
      jit::unpickle(bytes, byteCount, *resolver, message.tensors());

  // Work on a copy of the tuple elements; fromIValues() pops from the list.
  auto elements = unpickled.toTupleRef().elements().vec();
  return fromIValues(elements);
}
// Finds the registered operator whose printed schema is exactly `str_schema`.
//
// The schema string is parsed to recover the operator's symbol, and every
// operator registered under that symbol is compared by its printed schema.
// Fails a TORCH_CHECK when none of them matches.
std::shared_ptr<Operator> ScriptCall::matchOperator(
    const std::string& str_schema) {
  // TODO: This is a temporary solution. We should pass enough information to
  // allow deterministically matched to one operator.

  // extract symbol from the schema
  auto schema = torch::jit::parseSchema(str_schema);
  auto symbol = at::Symbol::fromQualString(schema.name());

  // Bind by const reference so candidates that do not match are not copied
  // (each shared_ptr copy is an atomic refcount round-trip).
  for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) {
    if (toString(op->schema()) == str_schema) {
      return op;
    }
  }

  TORCH_CHECK(false, "Cannot find matching operator for schema ", str_schema);
}
} // namespace rpc
} // namespace distributed
} // namespace torch
|