File: context.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (296 lines) | stat: -rw-r--r-- 9,807 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
#include <torch/csrc/distributed/autograd/context/context.h>

#include <functional>

#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::AccumulateGrad;

// Creates the autograd context identified by `contextId`. The device guard
// implementation stored in impl_ is built for CUDA when at::hasCUDA() returns
// true, and for CPU otherwise.
DistAutogradContext::DistAutogradContext(int64_t contextId)
    : contextId_(contextId),
      impl_(c10::impl::VirtualGuardImpl{
          at::hasCUDA() ? c10::DeviceType::CUDA : c10::DeviceType::CPU}) {}

// Returns the id this context was constructed with.
int64_t DistAutogradContext::contextId() const {
  return contextId_;
}

// Returns a by-value copy of the worker ids recorded in this context. The
// copy is made while lock_ is held, so the caller receives a snapshot rather
// than a reference into the member.
std::unordered_set<rpc::worker_id_t> DistAutogradContext::getKnownWorkerIds()
    const {
  std::lock_guard<std::mutex> scopedLock(lock_);
  auto snapshot = knownWorkerIds_;
  return snapshot;
}

// Adds `workerId` to the set of worker ids recorded in this context. The set
// is modified while lock_ is held; adding an id that is already present
// leaves the set unchanged.
void DistAutogradContext::addKnownWorkerId(const rpc::worker_id_t workerId) {
  std::lock_guard<std::mutex> scopedLock(lock_);
  knownWorkerIds_.emplace(workerId);
}

// Registers `func` as the send function for `autograd_message_id` in this
// context.
//
// func:                the send function to store; asserted to be non-null.
// autograd_message_id: key under which `func` is stored; asserted not to be
//                      registered already.
void DistAutogradContext::addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);

  // lock_ is held across both the uniqueness check and the insertion.
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      sendAutogradFunctions_.find(autograd_message_id) ==
      sendAutogradFunctions_.end());
  sendAutogradFunctions_.emplace(autograd_message_id, func);
}

// Registers `func` as the recv function for `autograd_message_id` in this
// context.
//
// func:                the recv function to store; asserted to be non-null.
// autograd_message_id: key under which `func` is stored; asserted not to be
//                      registered already.
void DistAutogradContext::addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);

  // lock_ is held across both the uniqueness check and the insertion.
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      recvAutogradFunctions_.find(autograd_message_id) ==
      recvAutogradFunctions_.end());
  recvAutogradFunctions_.emplace(autograd_message_id, func);
}

// Returns a by-value copy of the send functions registered in this context,
// keyed by autograd message id. The copy is made while lock_ is held.
std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
DistAutogradContext::sendFunctions() const {
  std::lock_guard<std::mutex> scopedLock(lock_);
  auto snapshot = sendAutogradFunctions_;
  return snapshot;
}

// Returns a by-value copy of the recv functions registered in this context,
// keyed by autograd message id. The copy is made while lock_ is held.
std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
DistAutogradContext::recvFunctions() const {
  std::lock_guard<std::mutex> scopedLock(lock_);
  auto snapshot = recvAutogradFunctions_;
  return snapshot;
}

// Accumulates `grad` into the gradient this context stores for `variable`.
//
// variable:          tensor whose gradient is being accumulated; asserted to
//                    require grad.
// grad:              incoming gradient; asserted to be defined.
// num_expected_refs: forwarded (plus one, see below) to
//                    AccumulateGrad::accumulateGrad.
//
// lock_ is held from the lookup of the previously stored gradient until the
// updated gradient has been stored and its ready-event recorded.
void DistAutogradContext::accumulateGrad(
    const torch::autograd::Variable& variable,
    const torch::Tensor& grad,
    size_t num_expected_refs) {
  TORCH_INTERNAL_ASSERT(grad.defined());
  TORCH_INTERNAL_ASSERT(variable.requires_grad());

  std::lock_guard<std::mutex> guard(lock_);
  auto it = accumulatedGrads_.find(variable);
  at::Tensor old_grad;
  if (it != accumulatedGrads_.end()) {
    // Accumulate multiple grads on the same variable.
    old_grad = it->value();
  }

  // Gradients are computed using the forward streams. Local autograd
  // engine uses AccumulateGrad function to retrieve and apply forward
  // stream during the backward computation. In distributed autograd,
  // we directly call AccumulateGrad::accumulateGrad, and skip the
  // CUDA stream restoration from autograd function. Hence, we manually
  // call it here to get the streams correct.
  auto forward_stream =
      torch::autograd::impl::grad_accumulator(variable)->stream(
          grad.device().type());
  c10::OptionalStreamGuard stream_guard(forward_stream);

  // No higher order gradients supported in distributed autograd.
  AutoGradMode grad_mode(false);

  at::Tensor new_grad = AccumulateGrad::callHooks(variable, grad);

  // TODO: Need to bump 'num_expected_refs' here when we support post_hooks for
  // distributed autograd as part of
  // https://github.com/pytorch/pytorch/issues/33482
  AccumulateGrad::accumulateGrad(
      variable,
      old_grad,
      new_grad,
      // Add +1 here since we can't std::move(grad) when call
      // AccumulateGrad::callHooks, since it is a const ref, and that incurs a
      // refcount bump for the new_grad.
      num_expected_refs + 1,
      // `variable` is captured by reference; this relies on the callback
      // being invoked before accumulateGrad() returns.
      [this, &variable](at::Tensor&& grad_update) {
        auto device = grad_update.device();
        // Store with insert_or_assign, not insert: an entry for `variable`
        // may already exist (old_grad above), and insert() would keep the
        // stale tensor and silently drop `grad_update`.
        // NOTE(review): presumably this callback fires with an existing entry
        // only when the accumulation could not be done in place - confirm
        // against AccumulateGrad::accumulateGrad.
        accumulatedGrads_.insert_or_assign(variable, std::move(grad_update));
        recordGradEvent(device);
      });
}

// Returns the GraphTask currently attached to this context. Asserts that one
// is set; the read happens while lock_ is held.
std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
    retrieveGraphTask() {
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(graphTask_);
  return graphTask_;
}

// Attaches `graphTask` to this context. Asserts that no GraphTask is
// currently attached; call resetGraphTask() first to replace one.
void DistAutogradContext::setGraphTask(
    std::shared_ptr<torch::autograd::GraphTask> graphTask) {
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      !graphTask_,
      "Cannot set GraphTask multiple times for the same autograd context");
  graphTask_ = std::move(graphTask);
}

// Detaches the GraphTask from this context (if any) by dropping the context's
// reference to it while lock_ is held.
void DistAutogradContext::resetGraphTask() {
  std::lock_guard<std::mutex> scopedLock(lock_);
  graphTask_.reset();
}

// Tracks `jitFuture` as an outstanding RPC of this context and attaches an
// error callback to it.
//
// If the future completes with an error and a GraphTask is attached at that
// time, the error is propagated to the GraphTask's result future (at most
// once, guarded by future_completed_). If no GraphTask is attached, the error
// is only logged.
//
// NOTE(review): the callback captures `this`; presumably the context is kept
// alive until its outstanding RPCs have completed - confirm with the owner of
// the context's lifetime.
void DistAutogradContext::addOutstandingRpc(
    const c10::intrusive_ptr<rpc::JitFuture>& jitFuture) {
  jitFuture->addCallback([this](rpc::JitFuture& future) {
    if (future.hasError()) {
      // If we have an error, let the local autograd engine know about it.
      std::unique_lock<std::mutex> lock(lock_);
      // Take our own reference while the lock is held. The code after
      // unlock() below must not touch the graphTask_ member: it could be
      // cleared concurrently by resetGraphTask(), which would make the
      // unsynchronized read a data race and the dereference a null access.
      auto graphTask = graphTask_;
      if (graphTask) {
        graphTask->set_exception_without_signal(nullptr);
        lock.unlock();
        if (!graphTask->future_completed_.exchange(true)) {
          graphTask->future_result_->setErrorIfNeeded(future.exception_ptr());
        }
      } else {
        LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
                     << future.tryRetrieveErrorMessage();
      }
    }
  });
  // Register the future only after the callback is attached.
  std::lock_guard<std::mutex> guard(lock_);
  outStandingRpcs_.push_back(jitFuture);
}

void DistAutogradContext::clearOutstandingRpcs() {
  std::unique_lock<std::mutex> lock(lock_);
  outStandingRpcs_.clear();
}

// Records a "gradient ready" event for `device` on the stream returned by
// impl_.getStream(device). Only CUDA devices are tracked; for any other
// device this is a no-op. One event per device is kept in gradReadyEvents_
// and is re-recorded on subsequent calls.
//
// This function does not acquire lock_ itself; the callers in this file
// invoke it while already holding lock_.
void DistAutogradContext::recordGradEvent(c10::Device device) {
  if (!device.is_cuda()) {
    return;
  }

  auto iter = gradReadyEvents_.find(device);
  if (iter != gradReadyEvents_.end()) {
    // An event already exists for this device: record it again.
    iter->second.record(impl_.getStream(device));
    return;
  }

  c10::Event event(device.type());
  // Select the stream with `device` itself, exactly as the re-record path
  // above does. The event was constructed from a device type only, so asking
  // it for its device before the first record() would not necessarily name
  // `device`.
  // NOTE(review): presumably a never-recorded c10::Event reports no concrete
  // device index - confirm against c10::Event.
  event.record(impl_.getStream(device));
  gradReadyEvents_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(device),
      std::forward_as_tuple(std::move(event)));
}

// Takes ownership of every outstanding RPC future tracked by this context and
// returns a single future that summarizes them:
//   - it is marked completed once every RPC has finished without error (or
//     immediately when there were none);
//   - it is set to the error of the first RPC observed to fail.
// lock_ is held only while the list of outstanding RPCs is moved out.
c10::intrusive_ptr<c10::ivalue::Future> DistAutogradContext::
    clearAndWaitForOutstandingRpcsAsync() {
  // Move the tracked futures into a local while holding the lock.
  auto pendingRpcs = [this] {
    std::lock_guard<std::mutex> scopedLock(lock_);
    return std::move(outStandingRpcs_);
  }();

  // Bookkeeping shared by all per-RPC callbacks.
  struct PendingState {
    explicit PendingState(int32_t count)
        : future(
              c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get())),
          remaining(count) {}
    // Aggregate future handed back to the caller.
    c10::intrusive_ptr<c10::ivalue::Future> future;
    // Number of RPCs that have not yet completed successfully.
    std::atomic<int32_t> remaining;
    // Set once an error has been forwarded to `future`.
    std::atomic<bool> alreadySentError{false};
  };
  auto shared = std::make_shared<PendingState>(pendingRpcs.size());

  // Nothing to wait for: complete right away.
  if (pendingRpcs.empty()) {
    shared->future->markCompleted(c10::IValue());
    return shared->future;
  }

  for (auto& pendingRpc : pendingRpcs) {
    pendingRpc->addCallback([shared](rpc::JitFuture& future) {
      if (!future.hasError()) {
        // Successful completion: the last one to finish completes the
        // aggregate future.
        if (shared->remaining.fetch_sub(1) == 1) {
          shared->future->markCompleted(c10::IValue());
        }
        return;
      }

      // Error path. `remaining` is deliberately left untouched, so the
      // counter can only reach zero when no RPC failed and the success path
      // above needs no extra error check. The compare-and-swap ensures that
      // only the first error is forwarded.
      bool noErrorSentYet = false;
      if (shared->alreadySentError.compare_exchange_strong(
              noErrorSentYet, true)) {
        shared->future->setError(future.exception_ptr());
      }
    });
  }
  return shared->future;
}

// Returns the send function registered under `autograd_message_id`.
// Fails the TORCH_CHECK below if no send function is registered for that id.
// The lookup happens while lock_ is held.
std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
    int64_t autograd_message_id) {
  std::lock_guard<std::mutex> guard(lock_);
  auto it = sendAutogradFunctions_.find(autograd_message_id);
  TORCH_CHECK(
      it != sendAutogradFunctions_.end(),
      "Could not find send function for autograd message id: ",
      autograd_message_id);
  return it->second;
}

// Returns the dictionary of gradients accumulated in this context, keyed by
// variable. Before returning, every recorded gradient-ready event is made to
// block the stream that impl_ reports for the event's device. All of this
// runs while lock_ is held.
const c10::Dict<torch::Tensor, torch::Tensor> DistAutogradContext::
    getGradients() const {
  std::lock_guard<std::mutex> scopedLock(lock_);
  // Gradient computations must be finished before the gradients are used, so
  // make the current streams wait on the recorded events first.
  for (const auto& deviceAndEvent : gradReadyEvents_) {
    const auto& readyEvent = deviceAndEvent.second;
    const auto eventDevice = readyEvent.device();
    readyEvent.block(impl_.getStream(eventDevice));
  }
  return accumulatedGrads_;
}

// Runs `cb` on the gradient stored for `variable`.
//
// variable: the variable whose gradient is passed to the callback; a gradient
//           for it is asserted to exist in this context.
// cb:       invoked with the gradient tensor, outside of lock_. When it
//           returns true, the tensor is written back into the context and a
//           gradient-ready event is recorded for its device.
void DistAutogradContext::runGradCallbackForVariable(
    const torch::autograd::Variable& variable,
    GradCallback&& cb) {
  torch::Tensor storedGrad;
  {
    // Fetch the current gradient under the lock.
    std::lock_guard<std::mutex> readLock(lock_);
    auto it = accumulatedGrads_.find(variable);
    TORCH_INTERNAL_ASSERT(
        it != accumulatedGrads_.end(),
        "The grad for the variable should exist in dist_autograd context.");
    storedGrad = it->value();
  }

  // The callback runs without the lock held; nothing to do unless it asks
  // for the gradient to be stored back.
  if (!cb(storedGrad)) {
    return;
  }

  std::lock_guard<std::mutex> writeLock(lock_);
  // Read the device before the tensor is moved into the map.
  auto gradDevice = storedGrad.device();
  accumulatedGrads_.insert_or_assign(variable, std::move(storedGrad));
  recordGradEvent(gradDevice);
}

namespace {
// Per-thread context pointer. ThreadLocalDistAutogradContext installs a new
// value in its constructor and restores the previous one in its destructor;
// ThreadLocalDistAutogradContext::getContextPtr() reads it.
thread_local ContextPtr tl_context_ptr;
} // namespace

// Installs `new_context` as the calling thread's context pointer. The pointer
// that was installed before is moved into prev_context_ptr_ so the destructor
// can put it back.
ThreadLocalDistAutogradContext::ThreadLocalDistAutogradContext(
    ContextPtr&& new_context)
    : prev_context_ptr_(std::move(tl_context_ptr)) {
  tl_context_ptr = std::move(new_context);
}

// Restores the context pointer that was saved by the constructor, replacing
// whatever is currently installed for the calling thread.
ThreadLocalDistAutogradContext::~ThreadLocalDistAutogradContext() {
  tl_context_ptr = std::move(prev_context_ptr_);
}

// static
// Returns a copy of the context pointer currently installed for the calling
// thread.
ContextPtr ThreadLocalDistAutogradContext::getContextPtr() {
  return tl_context_ptr;
}

} // namespace autograd
} // namespace distributed
} // namespace torch