File: python_rpc_handler.h

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (129 lines) | stat: -rw-r--r-- 4,954 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::distributed::rpc {

// Singleton class provides interface to execute python UDF remote call
// and deserialize the returned results by running python function
// in internal_rpc_utilities.
// The singleton object is constructed at first when RPC agent is
// constructed, where the python function in
// torch/distributed/internal_rpc_utils.py are imported only once.
class PYBIND11_EXPORT PythonRpcHandler {
 public:
  struct RRefProxyFunctions {
    py::object rrefProxyCtor_;
    py::object rpcSync_;
    py::object rpcAsync_;
    py::object remote_;
  };

  struct RRefTypeFunctions {
    py::object onOwner_;
    py::object onUser_;
  };

  static PythonRpcHandler& getInstance();

  // Run a pickled Python UDF and return the result py::object
  py::object runPythonUdf(const py::object& pythonUdf);

  // Serialized a py::object into a string
  SerializedPyObj serialize(const py::object& obj);

  // Deserialize a string into a py::object
  py::object deserialize(const SerializedPyObj& serializedObj);

  // Check if obj is RemoteException, then throw it
  void handleException(const py::object& obj);
  // Alternative if the caller is already holding the GIL.
  void handleExceptionGILHeld(const py::object& obj);
  // Check if obj is an RemoteException instance.
  bool isRemoteException(const py::object& obj);

  // Explicitly clean up py::objects to avoid segment faults when
  // py::objects with CPython are cleaned up later at program exit
  // See similar issues reported https://github.com/pybind/pybind11/issues/1598
  // and https://github.com/pybind/pybind11/issues/1493
  // Our local tests also caught this segment faults if py::objects are cleaned
  // up at program exit. The explanation is: CPython cleans up most critical
  // utilities before cleaning up PythonRpcHandler singleton, so when
  // PythonRpcHandler singleton cleans up py::objects and call dec_ref(), it
  // will crash.
  // The solution is to clean up py::objects earlier when Rpc agent join().
  // Be note that py::objects can not be cleaned up when Rpc agent is destroyed
  // as well, as Rpc agent is global variable and it will have same issue as
  // PythonRpcHandler.
  void cleanup();

  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit();

  // Parse the string to recover the jit_type, this is used for RRef python
  // pickling/unpickling type recovery. The type string inference rule is as
  // follows:
  // 1. first try to parse if this is primitive types.
  //    i.e. TensorType, IntType, PyObjectType, etc.
  // 2. if not primitive type, we query the python_cu to see if it is a
  //    class type or interface type registered in python
  // We use a ScriptTypeParser instance with custom PythonTypeResolver
  // to resolve types according to the above rules.
  TypePtr parseTypeFromStr(const std::string& typeStr);

  // Return a set of Python functions for RRef helpers.
  const RRefProxyFunctions& getRRefProxyFunctions() const;

  // Return a set of Python functions to retrieve the type of the object
  // referenced by a given RRef.
  const RRefTypeFunctions& getRRefTypeFunctions() const;

  PythonRpcHandler(const PythonRpcHandler&) = delete;
  PythonRpcHandler& operator=(const PythonRpcHandler&) = delete;
  PythonRpcHandler(PythonRpcHandler&&) = delete;
  PythonRpcHandler& operator=(PythonRpcHandler&&) = delete;

 private:
  void init();
  PythonRpcHandler();
  ~PythonRpcHandler() = default;

  // Ref to `torch.distributed.rpc.internal._run_function`.
  py::object pyRunFunction_;

  // Ref to `torch.distributed.rpc.internal.serialize`.
  py::object pySerialize_;

  // Ref to `torch.distributed.rpc.internal.deserialize`.
  py::object pyDeserialize_;

  // Ref to 'torch.distributed.rpc.internal._handle_exception'
  py::object pyHandleException_;

  // Python functions for RRef proxy
  RRefProxyFunctions rrefProxyFunctions_;

  // Ref to 'torch.distributed.rpc.api._rref_typeof_on_'
  RRefTypeFunctions rrefTypeFunctions_;

  // Shared ptr to python compilation unit in jit, it is constructed in python
  // side (see _python_cu = torch._C.CompilationUnit() in jit/__init__.py)
  // and imported in C++ (see get_python_cu() in
  // csrc/jit/python/pybind_utils.h). We import the compilation unit here only
  // once for less cost and thread safety.
  std::shared_ptr<torch::jit::CompilationUnit> jitCompilationUnit_;

  // jit type parser to parse type_str back to TypePtr for RRef type
  // recovery when pickling and unpickling RRef
  std::shared_ptr<jit::ScriptTypeParser> typeParser_;

  // Indicates whether or not we have properly initialized the handler.
  bool initialized_;

  // Lock to protect initialization.
  std::mutex init_lock_;
};

} // namespace torch::distributed::rpc