1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::distributed::rpc {
// Singleton class that provides an interface for executing Python UDF remote
// calls and for deserializing the returned results by running Python
// functions from internal_rpc_utilities.
// The singleton object is first constructed when the RPC agent is
// constructed, at which point the Python functions in
// torch/distributed/internal_rpc_utils.py are imported only once.
class PYBIND11_EXPORT PythonRpcHandler {
 public:
  // Set of Python functions used for the RRef proxy helpers.
  struct RRefProxyFunctions {
    py::object rrefProxyCtor_;
    py::object rpcSync_;
    py::object rpcAsync_;
    py::object remote_;
  };

  // Set of Python functions used to retrieve the type of the object
  // referenced by a given RRef.
  struct RRefTypeFunctions {
    py::object onOwner_;
    py::object onUser_;
  };

  // Accessor for the singleton instance.
  static PythonRpcHandler& getInstance();

  // Run a pickled Python UDF and return the result as a py::object.
  py::object runPythonUdf(const py::object& pythonUdf);

  // Serialize a py::object into a SerializedPyObj.
  SerializedPyObj serialize(const py::object& obj);

  // Deserialize a SerializedPyObj back into a py::object.
  py::object deserialize(const SerializedPyObj& serializedObj);

  // Check whether obj is a RemoteException and, if so, throw it.
  void handleException(const py::object& obj);

  // Alternative to handleException() for callers already holding the GIL.
  void handleExceptionGILHeld(const py::object& obj);

  // Check whether obj is a RemoteException instance.
  bool isRemoteException(const py::object& obj);

  // Explicitly clean up py::objects to avoid segmentation faults when
  // py::objects with CPython are cleaned up later at program exit.
  // See similar issues reported at
  // https://github.com/pybind/pybind11/issues/1598
  // and https://github.com/pybind/pybind11/issues/1493
  // Our local tests also caught these segmentation faults when py::objects
  // are cleaned up at program exit. The explanation is: CPython cleans up
  // most critical utilities before cleaning up the PythonRpcHandler
  // singleton, so when the PythonRpcHandler singleton cleans up py::objects
  // and calls dec_ref(), it will crash.
  // The solution is to clean up py::objects earlier, when the RPC agent
  // join()s.
  // Note that py::objects cannot be cleaned up when the RPC agent is
  // destroyed either, as the RPC agent is a global variable and would have
  // the same issue as PythonRpcHandler.
  void cleanup();

  // Return the shared pointer to the JIT Python compilation unit.
  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();

  // Parse the string to recover the jit_type; this is used for RRef Python
  // pickling/unpickling type recovery. The type string inference rule is as
  // follows:
  // 1. First try to parse it as a primitive type,
  //    i.e. TensorType, IntType, PyObjectType, etc.
  // 2. If it is not a primitive type, query the python_cu to see whether it
  //    is a class type or interface type registered in Python.
  // We use a ScriptTypeParser instance with a custom PythonTypeResolver
  // to resolve types according to the above rules.
  TypePtr parseTypeFromStr(const std::string& typeStr);

  // Return a set of Python functions for RRef helpers.
  const RRefProxyFunctions& getRRefProxyFunctions() const;

  // Return a set of Python functions to retrieve the type of the object
  // referenced by a given RRef.
  const RRefTypeFunctions& getRRefTypeFunctions() const;

  // Singleton: neither copyable nor movable.
  PythonRpcHandler(const PythonRpcHandler&) = delete;
  PythonRpcHandler& operator=(const PythonRpcHandler&) = delete;
  PythonRpcHandler(PythonRpcHandler&&) = delete;
  PythonRpcHandler& operator=(PythonRpcHandler&&) = delete;

 private:
  // Defined out of line. NOTE(review): presumably performs the one-time
  // setup tracked by initialized_ and guarded by init_lock_ -- confirm
  // against the implementation file.
  void init();

  // Construction and destruction are private; instances are obtained through
  // getInstance().
  PythonRpcHandler();
  ~PythonRpcHandler() = default;

  // Ref to `torch.distributed.rpc.internal._run_function`.
  py::object pyRunFunction_;

  // Ref to `torch.distributed.rpc.internal.serialize`.
  py::object pySerialize_;

  // Ref to `torch.distributed.rpc.internal.deserialize`.
  py::object pyDeserialize_;

  // Ref to `torch.distributed.rpc.internal._handle_exception`.
  py::object pyHandleException_;

  // Python functions for RRef proxy.
  RRefProxyFunctions rrefProxyFunctions_;

  // Ref to 'torch.distributed.rpc.api._rref_typeof_on_'
  RRefTypeFunctions rrefTypeFunctions_;

  // Shared ptr to the Python compilation unit in JIT. It is constructed on
  // the Python side (see _python_cu = torch._C.CompilationUnit() in
  // jit/__init__.py) and imported in C++ (see get_python_cu() in
  // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
  // once for less cost and thread safety.
  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;

  // JIT type parser used to parse type_str back to TypePtr for RRef type
  // recovery when pickling and unpickling RRefs.
  std::shared_ptr<jit::ScriptTypeParser> typeParser_;

  // Indicates whether or not we have properly initialized the handler.
  // Defaulted to false in-class so the flag never holds an indeterminate
  // value, independent of how the out-of-line constructor is written.
  bool initialized_{false};

  // Lock to protect initialization.
  std::mutex init_lock_;
};
} // namespace torch::distributed::rpc
|