File: e2e_test_base.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (172 lines) | stat: -rw-r--r-- 5,578 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <gtest/gtest.h>

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch {
namespace distributed {
namespace rpc {

using torch::distributed::autograd::DistAutogradContainer;
using torch::distributed::autograd::DistAutogradContext;

DistAutogradContainer* getDistAutogradContainer();

// Shared gtest fixture for distributed autograd / RPC end-to-end tests.
// Subclasses supply a concrete RpcAgent by overriding buildRpcAgent(); this
// fixture wires up the rendezvous store, the agent, and the distributed
// autograd container, and provides helpers that issue remote ops and drive a
// small forward/backward training loop over RPC.
class TestE2EBase : public ::testing::Test {
 protected:
  void SetUp() override {
    // Setup distributed autograd.
    autogradContainer = getDistAutogradContainer();

    // Setup server store.
    // Port 0 lets the OS choose a free port. This process acts as the
    // TCPStore server and blocks (up to the 10s timeout) until all
    // numWorkers peers have connected, since waitWorkers is true.
    c10d::TCPStoreOptions opts{
        /* port */ 0,
        /* isServer */ true,
        numWorkers,
        /* waitWorkers */ true,
        /* timeout */ std::chrono::seconds(10)};

    store = c10::make_intrusive<c10d::TCPStore>(serverAddress, opts);

    // Subclass hook: must populate rpcAgent.
    buildRpcAgent();

    rpcAgentPostProcessing();
  }

  // Installs the freshly built agent as the process-wide current agent,
  // gives it a minimal type resolver, and starts it.
  void rpcAgentPostProcessing() {
    RpcAgent::setCurrentRpcAgent(rpcAgent);
    // Minimal resolver sufficient for these tests: any qualified name
    // containing "Dict" resolves to Dict[str, str]; everything else
    // resolves to a tensor type.
    std::shared_ptr<TypeResolver> typeResolver =
        std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
          // For Dict that is used for device map.
          auto pos = qn.name().find("Dict");
          if (pos != std::string::npos) {
            return c10::StrongTypePtr(
                nullptr,
                c10::DictType::create(
                    c10::StringType::get(), c10::StringType::get()));
          }
          return c10::StrongTypePtr(
              nullptr, c10::TensorType::create(at::Tensor()));
        });
    rpcAgent->setTypeResolver(typeResolver);
    rpcAgent->start();
  }

  void TearDown() override {
    // Drain outstanding work before shutting the agent down, then clear
    // the global current-agent pointer so later tests start clean.
    rpcAgent->join();
    rpcAgent->shutdown();
    RpcAgent::setCurrentRpcAgent(nullptr);
  }

  // Issues a remote call of `op` on (t1, t2) and returns the OwnerRRef that
  // will hold the result. The returned rref's value is only guaranteed to be
  // set once the creation future completes (see the callback below).
  c10::intrusive_ptr<OwnerRRef> createRemoteRRef(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    auto& ctx = RRefContext::getInstance();
    auto ownerRRef = ctx.createOwnerRRef(c10::TensorType::create(t1));
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);

    // The trailing 1 is the scalar `alpha` argument of aten::add.
    // NOTE(review): rrefId() is passed for both id arguments — presumably
    // rrefId and forkId coincide for a self-owned rref; verify against the
    // ScriptRemoteCall signature.
    ScriptRemoteCall scriptRemoteCall(
        op, {t1, t2, 1}, ownerRRef->rrefId(), ownerRRef->rrefId());
    // NOTE(review): the trailing `false` is presumably the
    // forceGradRecording flag of sendMessageWithAutograd — confirm.
    auto jitFuture = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptRemoteCall).toMessage(),
        false);

    ownerRRef->registerOwnerCreationFuture(jitFuture);

    // Builtin operators does not return py::object, and hence does not require
    // GIL for destructing the potentially deleted OwerRRef.
    // Capture only the rref id (by value) so the callback does not extend
    // the OwnerRRef's lifetime.
    jitFuture->addCallback(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& jitFuture) {
          callback::finishCreatingOwnerRRef(jitFuture, ownerRRefId);
        });
    return ownerRRef;
  }

  // Synchronously runs `op` (expected: aten::add) on (t1, t2) over RPC and
  // returns the resulting tensor. Throws if the remote call failed.
  at::Tensor remoteAdd(
      at::Tensor t1,
      at::Tensor t2,
      std::shared_ptr<torch::jit::Operator> op) {
    ScriptCall scriptCall(op, {t1, t2, /* alpha */ 1});

    // Send the RPC and return result.
    auto response = autograd::sendMessageWithAutograd(
        *rpcAgent,
        rpcAgent->getWorkerInfo("worker"),
        std::move(scriptCall).toMessage());
    response->waitAndThrow();

    // The reply is wrapped in an autograd envelope; unwrap it and pull the
    // tensor out of the inner ScriptResp.
    MessageType messageType = MessageType::FORWARD_AUTOGRAD_RESP;
    auto wrappedResponse = deserializeResponse(
        std::move(*response->value().toCustomClass<Message>()), messageType);
    return static_cast<ScriptResp&>(*wrappedResponse).value().toTensor();
  }

  // Subclass hook: construct and assign `rpcAgent` (called from SetUp before
  // rpcAgentPostProcessing()).
  virtual void buildRpcAgent() = 0;

  // RAII guard: opens a fresh distributed autograd context on construction
  // and releases it (by id) on destruction.
  class AutogradContextGuard {
   public:
    explicit AutogradContextGuard()
        : context(DistAutogradContainer::getInstance().newContext()) {}

    ~AutogradContextGuard() {
      DistAutogradContainer::getInstance().releaseContext(context->contextId());
    }

   private:
    std::shared_ptr<DistAutogradContext> context;
  };

  // Runs numIters iterations: several chained remote adds for the forward
  // pass (all inside one autograd context), one rref-based remote op, then a
  // distributed backward pass on the summed result.
  void runTrainingLoop() {
    auto options = at::TensorOptions().requires_grad(true);
    auto t1 = torch::ones({3, 3}, options);
    auto t2 = torch::ones({3, 3}, options);

    // Resolve the aten::add.Tensor overload used by both helpers above.
    c10::OperatorName full_name("aten::add", "Tensor");
    auto matchedOp = torch::jit::findOperatorFor(full_name);
    ASSERT_TRUE(matchedOp);

    for (size_t i = 0; i < numIters; i++) {
      // Create the autograd context guard.
      AutogradContextGuard guard;

      // Multiple RPCs within one autograd context for the forward pass.
      auto result = remoteAdd(t1, t2, matchedOp);
      for (size_t j = 0; j < 5; j++) {
        result = remoteAdd(t1, result, matchedOp);
      }

      auto rref = createRemoteRRef(t1, result, matchedOp);
      result = rref->getValue().toTensor();

      // Run backward pass now.
      autograd::DistEngine::getInstance().execute(
          DistAutogradContainer::currentContextId(),
          {torch::sum(result)},
          /* retainGraph */ false);
    }
  }

  DistAutogradContainer* autogradContainer; // set in SetUp; not owned
  std::shared_ptr<RpcAgent> rpcAgent; // built by subclass in buildRpcAgent()
  static const size_t numIters; // defined in the test .cpp
  static const size_t numWorkers; // defined in the test .cpp
  c10::intrusive_ptr<c10d::Store> store; // TCPStore server for rendezvous
  static const char* serverAddress; // defined in the test .cpp
};

} // namespace rpc
} // namespace distributed
} // namespace torch