File: python_rpc_handler.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (133 lines) | stat: -rw-r--r-- 5,004 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#pragma once

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/utils/pybind.h>

#include <memory>
#include <mutex>
#include <string>

namespace torch {
namespace distributed {
namespace rpc {

// Singleton that gives C++ access to the Python-side RPC helpers: it runs
// Python UDFs for remote calls and (de)serializes py::objects by calling
// Python functions.
// The singleton object is first constructed when the RPC agent is
// constructed, so that the Python helper functions are imported only once.
// NOTE(review): the member comments below name `torch.distributed.rpc.internal`
// and `torch.distributed.rpc.api` as the modules these helpers come from;
// confirm against the implementation file.
class PYBIND11_EXPORT PythonRpcHandler {
 public:
  // Python callables used for RRef proxies.
  // NOTE(review): presumably the proxy constructor plus the rpc_sync /
  // rpc_async / remote entry points -- inferred from the field names, confirm
  // against where these are populated.
  struct RRefProxyFunctions {
    py::object rrefProxyCtor_;
    py::object rpcSync_;
    py::object rpcAsync_;
    py::object remote_;
  };

  // Python callables used to retrieve the type of the object referenced by an
  // RRef.
  // NOTE(review): presumably one variant for the owner and one for a user of
  // the RRef -- inferred from the field names.
  struct RRefTypeFunctions {
    py::object onOwner_;
    py::object onUser_;
  };

  // Returns the singleton instance.
  static PythonRpcHandler& getInstance();

  // Run a pickled Python UDF and return the result as a py::object.
  py::object runPythonUdf(const py::object& pythonUdf);

  // Serialize a py::object into a SerializedPyObj.
  SerializedPyObj serialize(const py::object& obj);

  // Deserialize a SerializedPyObj back into a py::object.
  py::object deserialize(const SerializedPyObj& serializedObj);

  // Check whether obj is a RemoteException and, if so, throw it.
  // NOTE(review): presumably acquires the GIL itself, given the GILHeld
  // variant below -- confirm in the implementation.
  void handleException(const py::object& obj);
  // Alternative to handleException() for callers already holding the GIL.
  void handleExceptionGILHeld(const py::object& obj);
  // Check whether obj is a RemoteException instance.
  bool isRemoteException(const py::object& obj);

  // Explicitly clean up py::objects to avoid segmentation faults when
  // py::objects are cleaned up by CPython later, at program exit.
  // See similar issues reported at
  // https://github.com/pybind/pybind11/issues/1598
  // and https://github.com/pybind/pybind11/issues/1493
  // Our local tests also hit this segmentation fault when py::objects were
  // cleaned up at program exit. The explanation is: CPython cleans up most
  // critical utilities before cleaning up the PythonRpcHandler singleton, so
  // when the PythonRpcHandler singleton cleans up its py::objects and calls
  // dec_ref(), it crashes.
  // The solution is to clean up py::objects earlier, when the RPC agent
  // join()s.
  // Note that the py::objects cannot be cleaned up when the RPC agent is
  // destroyed either: the RPC agent is a global variable, so it would have the
  // same issue as PythonRpcHandler.
  void cleanup();

  // Accessor for the JIT compilation unit (see jitCompilationUnit_ below).
  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();

  // Parse the string to recover the jit_type; this is used for RRef Python
  // pickling/unpickling type recovery. The type string inference rule is as
  // follows:
  // 1. First try to parse it as a primitive type,
  //    i.e. TensorType, IntType, PyObjectType, etc.
  // 2. If it is not a primitive type, query the python_cu to see whether it is
  //    a class type or interface type registered in Python.
  // We use a ScriptTypeParser instance with a custom PythonTypeResolver
  // to resolve types according to the above rules.
  TypePtr parseTypeFromStr(const std::string& typeStr);

  // Return the set of Python functions for RRef helpers.
  const RRefProxyFunctions& getRRefProxyFunctions() const;

  // Return the set of Python functions used to retrieve the type of the
  // object referenced by a given RRef.
  const RRefTypeFunctions& getRRefTypeFunctions() const;

  // The singleton is neither copyable nor movable.
  PythonRpcHandler(const PythonRpcHandler&) = delete;
  PythonRpcHandler& operator=(const PythonRpcHandler&) = delete;
  PythonRpcHandler(PythonRpcHandler&&) = delete;
  PythonRpcHandler& operator=(PythonRpcHandler&&) = delete;

 private:
  // NOTE(review): presumably performs the one-time setup tracked by
  // initialized_ and guarded by init_lock_ -- confirm in the implementation.
  void init();
  // Construction and destruction are private; instances are only reachable
  // through getInstance().
  PythonRpcHandler();
  ~PythonRpcHandler() = default;

  // Ref to `torch.distributed.rpc.internal._run_function`.
  py::object pyRunFunction_;

  // Ref to `torch.distributed.rpc.internal.serialize`.
  py::object pySerialize_;

  // Ref to `torch.distributed.rpc.internal.deserialize`.
  py::object pyDeserialize_;

  // Ref to `torch.distributed.rpc.internal._handle_exception`.
  py::object pyHandleException_;

  // Python functions for RRef proxy.
  RRefProxyFunctions rrefProxyFunctions_;

  // Refs to the `torch.distributed.rpc.api._rref_typeof_on_*` functions.
  RRefTypeFunctions rrefTypeFunctions_;

  // Shared ptr to the Python compilation unit in jit. It is constructed on
  // the Python side (see _python_cu = torch._C.CompilationUnit() in
  // jit/__init__.py) and imported in C++ (see get_python_cu() in
  // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
  // once, for lower cost and thread safety.
  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;

  // jit type parser used to parse a type_str back to a TypePtr for RRef type
  // recovery when pickling and unpickling an RRef.
  std::shared_ptr<jit::ScriptTypeParser> typeParser_;

  // Indicates whether or not we have properly initialized the handler.
  // Default-initialized here so the flag never holds an indeterminate value,
  // regardless of which constructor initializer list is used.
  bool initialized_{false};

  // Lock to protect initialization.
  std::mutex init_lock_;
};

} // namespace rpc
} // namespace distributed
} // namespace torch