1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
|
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/python_compat.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
// Python module that hosts the RPC helper functions (_run_function,
// serialize, deserialize, _handle_exception) and the RemoteException type.
constexpr const char* kInternalModule = "torch.distributed.rpc.internal";
// A macro that grabs the GIL for the remainder of the enclosing scope (through
// the pybind11::gil_scoped_acquire named `ag`), profiling the acquisition time
// when the current RpcAgent has GIL profiling enabled. The measured wait is
// reported via RpcAgent::addGilWaitTime(); the average GIL acquisition time is
// recorded in RpcAgent's getMetrics().
//
// Note: the expansion declares `startTime`, `shouldProfileGIL` and `ag` in the
// caller's scope, so it can be used at most once per scope.
//
// steady_clock is used because it is monotonic. high_resolution_clock may be
// an alias of the adjustable system_clock, which can produce negative or
// inflated durations if the wall clock is changed mid-measurement.
#define PROFILE_GIL_SCOPED_ACQUIRE                                        \
  std::chrono::time_point<std::chrono::steady_clock> startTime;           \
  auto shouldProfileGIL =                                                 \
      RpcAgent::getCurrentRpcAgent()->isGILProfilingEnabled();            \
  if (shouldProfileGIL) {                                                 \
    startTime = std::chrono::steady_clock::now();                         \
  }                                                                       \
  pybind11::gil_scoped_acquire ag;                                        \
  if (shouldProfileGIL) {                                                 \
    auto dur = std::chrono::duration_cast<std::chrono::microseconds>(     \
        std::chrono::steady_clock::now() - startTime);                    \
    RpcAgent::getCurrentRpcAgent()->addGilWaitTime(dur);                  \
  } // NOLINT
// jit::Resolver handed to ScriptTypeParser so that RPC can turn type strings
// into TypePtrs. Only type resolution is supported; asking it to resolve a
// value is treated as an internal error.
struct PythonTypeResolver : public jit::Resolver {
  std::shared_ptr<jit::SugaredValue> resolveValue(
      const std::string& /* unused */,
      torch::jit::GraphFunction& /* unused */,
      const jit::SourceRange& /* unused */) override {
    TORCH_INTERNAL_ASSERT(
        false, "RPC Type resolver does not need to resolve value");
  }

  // "PyObject" maps to the builtin PyObjectType; every other name is looked
  // up in the Python JIT compilation unit owned by PythonRpcHandler.
  TypePtr resolveType(
      const std::string& typeName,
      const jit::SourceRange& /* unused */) override {
    if (typeName != "PyObject") {
      return PythonRpcHandler::getInstance().jitCompilationUnit()->get_type(
          typeName);
    }
    return PyObjectType::get();
  }
};
// Looks up attribute `name` on `module` and returns it, raising (via
// TORCH_CHECK) if the attribute is not a Python function.
py::object getFunction(const py::object& module, const char* name) {
  py::object candidate = module.attr(name);
  const bool isFunction = py::isinstance<py::function>(candidate);
  TORCH_CHECK(isFunction, "attribute ", name, " is not a function");
  return candidate;
}
// Drops the reference held by `obj` and leaves it holding a null PyObject*,
// so py::object's destructor will not decref the same PyObject a second time.
// release() hands back the raw handle and nulls obj's pointer; the handle is
// then decref'd once. Callers hold the GIL while calling this.
// See Note [Destructing py::object] in python_ivalue.h
void cleanupPyObj(py::object& obj) {
  obj.release().dec_ref();
}
} // namespace
void PythonRpcHandler::init() {
std::lock_guard<std::mutex> guard(init_lock_);
if (!initialized_) {
PROFILE_GIL_SCOPED_ACQUIRE;
py::object rpcInternal = py::module::import(kInternalModule);
py::object rpcApi = py::module::import("torch.distributed.rpc.api");
py::object rrefProxy =
py::module::import("torch.distributed.rpc.rref_proxy");
pyRunFunction_ = getFunction(rpcInternal, "_run_function");
pySerialize_ = getFunction(rpcInternal, "serialize");
pyDeserialize_ = getFunction(rpcInternal, "deserialize");
pyHandleException_ = getFunction(rpcInternal, "_handle_exception");
rrefTypeFunctions_.onOwner_ = getFunction(rpcApi, "_rref_typeof_on_owner");
rrefTypeFunctions_.onUser_ = getFunction(rpcApi, "_rref_typeof_on_user");
rrefProxyFunctions_.rpcSync_ = getFunction(rpcApi, "rpc_sync");
rrefProxyFunctions_.rpcAsync_ = getFunction(rpcApi, "rpc_async");
rrefProxyFunctions_.remote_ = getFunction(rpcApi, "remote");
rrefProxyFunctions_.rrefProxyCtor_ = getFunction(rrefProxy, "RRefProxy");
jitCompilationUnit_ = torch::jit::get_python_cu();
typeParser_ = std::make_shared<jit::ScriptTypeParser>(
std::make_shared<PythonTypeResolver>());
initialized_ = true;
}
}
PythonRpcHandler::PythonRpcHandler() : initialized_(false) {}
void PythonRpcHandler::cleanup() {
std::lock_guard<std::mutex> guard(init_lock_);
PROFILE_GIL_SCOPED_ACQUIRE;
cleanupPyObj(pyRunFunction_);
cleanupPyObj(pySerialize_);
cleanupPyObj(pyDeserialize_);
cleanupPyObj(pyHandleException_);
cleanupPyObj(rrefProxyFunctions_.rpcSync_);
cleanupPyObj(rrefProxyFunctions_.rpcAsync_);
cleanupPyObj(rrefProxyFunctions_.remote_);
cleanupPyObj(rrefProxyFunctions_.rrefProxyCtor_);
jitCompilationUnit_ = nullptr;
typeParser_ = nullptr;
initialized_ = false;
}
// Returns the process-wide handler, (re)initializing it if needed.
//
// Callers must NOT hold the GIL. init() takes init_lock_ and then acquires
// the GIL. Suppose a thread entered here already holding the GIL. It could
// block on init_lock_ while another thread, holding init_lock_ inside init(),
// waits for that same GIL. Neither thread could make progress. Requiring
// callers to release the GIL first rules this out.
PythonRpcHandler& PythonRpcHandler::getInstance() {
  TORCH_INTERNAL_ASSERT(!PyGILState_Check());
  // Intentionally leaked (never deleted) to avoid a module destructor race.
  static PythonRpcHandler* const instance = new PythonRpcHandler();
  instance->init();
  return *instance;
}
std::shared_ptr<torch::jit::CompilationUnit> PythonRpcHandler::
jitCompilationUnit() {
return jitCompilationUnit_;
}
// Runs a Python UDF through the internal `_run_function` helper with the GIL
// held, and returns whatever that helper returns.
py::object PythonRpcHandler::runPythonUdf(const py::object& pythonUdf) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  // Throw a descriptive error message if pyRunFunction_ is already cleaned up.
  // cleanup() nulls the underlying PyObject* (via cleanupPyObj) rather than
  // setting it to None. is_none() alone therefore lets a released function
  // through, and calling it would go through a null PyObject*. Check both.
  TORCH_INTERNAL_ASSERT(
      pyRunFunction_ && !pyRunFunction_.is_none(),
      "Cannot run python UDF since pyRunFunction_ is None. Check if python RPC "
      "handler is already cleaned up.");
  return pyRunFunction_(pythonUdf);
}
// Serializes `obj` with the internal Python `serialize` helper (GIL held).
// The helper's result is indexed as a tuple: element 0 is the byte payload
// and element 1 is the list of tensors.
SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  py::tuple serialized = pySerialize_(obj);
  auto payload = serialized[0].cast<std::string>();
  auto tensors = serialized[1].cast<std::vector<torch::Tensor>>();
  return SerializedPyObj(std::move(payload), std::move(tensors));
}
// Rebuilds a Python object from `serializedObj`. The payload is passed as
// Python bytes and the tensors are passed as-is to the internal `deserialize`
// helper, with the GIL held.
// NB: pyDeserialize_ can return an AttributeError if the deserialize() Python
// function fails. Functions consuming the result need to handle such an error
// properly.
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  const py::bytes payloadBytes(serializedObj.payload_);
  const auto& tensors = serializedObj.tensors_;
  return pyDeserialize_(payloadBytes, tensors);
}
// Acquires the GIL and forwards `obj` to the internal `_handle_exception`
// Python helper. Its return value is discarded.
void PythonRpcHandler::handleException(const py::object& obj) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  const auto& exceptionHandler = pyHandleException_;
  exceptionHandler(obj);
}
// Same as handleException(), but for callers that already hold the GIL. It
// verifies that precondition with TORCH_CHECK instead of acquiring the GIL.
void PythonRpcHandler::handleExceptionGILHeld(const py::object& obj) {
  TORCH_CHECK(PyGILState_Check(), "GIL should be held");
  (void)pyHandleException_(obj);
}
// Returns true iff `obj`'s type is RemoteException, defined in the internal
// RPC module. It compares the type's __module__ and __qualname__ (GIL held).
bool PythonRpcHandler::isRemoteException(const py::object& obj) {
  PROFILE_GIL_SCOPED_ACQUIRE;
  const auto pyType = obj.get_type();
  const auto definingModule = pyType.attr("__module__").cast<std::string>();
  const auto qualifiedName = pyType.attr("__qualname__").cast<std::string>();
  const bool fromInternalModule = definingModule == kInternalModule;
  const bool isRemoteExceptionType = qualifiedName == "RemoteException";
  return fromInternalModule && isRemoteExceptionType;
}
TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
return typeParser_->parseType(type_str);
}
const PythonRpcHandler::RRefProxyFunctions& PythonRpcHandler::
getRRefProxyFunctions() const {
return rrefProxyFunctions_;
}
const PythonRpcHandler::RRefTypeFunctions& PythonRpcHandler::
getRRefTypeFunctions() const {
return rrefTypeFunctions_;
}
} // namespace rpc
} // namespace distributed
} // namespace torch
|