File: net_async_task_future.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (111 lines) | stat: -rw-r--r-- 2,853 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include "caffe2/core/net_async_task_future.h"

#include "c10/util/Logging.h"
#include "caffe2/core/common.h"

namespace caffe2 {

AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}

// Creates a future that completes once every future in `futures` has
// completed. If any parent failed, this future is completed as failed; with
// several parents the error messages of all failed parents are joined
// with ", ".
//
// `futures` must contain at least one element (enforced below). A callback
// capturing `this` is registered on every parent, so this object has to stay
// alive for as long as a parent may still invoke its callbacks.
AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
    : completed_(false), failed_(false) {
  if (futures.size() <= 1) {
    // Single parent: no counter needed, simply mirror the parent's outcome.
    CAFFE_ENFORCE_EQ(futures.size(), (size_t)1);
    AsyncTaskFuture* only_parent = futures.back();
    only_parent->SetCallback([this](const AsyncTaskFuture* done) {
      if (done->IsFailed()) {
        SetCompleted(done->ErrorMessage().c_str());
      } else {
        SetCompleted();
      }
    });
    return;
  }

  // Several parents: count them down and collect their errors as they finish.
  parent_counter_ = std::make_unique<ParentCounter>(futures.size());

  const auto on_parent_done = [this](const AsyncTaskFuture* done) {
    if (done->IsFailed()) {
      // err_mutex serializes updates of the aggregated error message, since
      // parents may finish on different threads.
      std::lock_guard<std::mutex> guard(parent_counter_->err_mutex);
      if (!parent_counter_->parent_failed) {
        parent_counter_->parent_failed = true;
        parent_counter_->err_msg = done->ErrorMessage();
      } else {
        parent_counter_->err_msg += ", " + done->ErrorMessage();
      }
    }

    const int remaining = --parent_counter_->parent_count;
    if (remaining != 0) {
      return;
    }

    // Last parent to finish. Per the original author's note it is thread
    // safe to read parent_counter_ here without taking err_mutex.
    if (parent_counter_->parent_failed) {
      SetCompleted(parent_counter_->err_msg.c_str());
    } else {
      SetCompleted();
    }
  };

  for (AsyncTaskFuture* parent : futures) {
    parent->SetCallback(on_parent_done);
  }
}

// Reports whether the future is currently marked completed (set by
// SetCompleted(), cleared by ResetState()). The flag is read without taking
// mutex_.
bool AsyncTaskFuture::IsCompleted() const {
  const bool done = completed_;
  return done;
}

// Reports whether the future was completed with an error message. The flag
// is read without taking mutex_.
bool AsyncTaskFuture::IsFailed() const {
  const bool has_failed = failed_;
  return has_failed;
}

// Returns a copy of the stored error message; empty unless SetCompleted()
// was given a message. mutex_ is deliberately not taken here: completion
// callbacks call this on the completing future while SetCompleted() still
// holds that mutex.
std::string AsyncTaskFuture::ErrorMessage() const {
  std::string message(err_msg_);
  return message;
}

void AsyncTaskFuture::Wait() const {
  std::unique_lock<std::mutex> lock(mutex_);
  while (!completed_) {
    cv_completed_.wait(lock);
  }
}

// Registers `callback` to be invoked with this future when it completes.
// The callback is always stored (registered callbacks survive ResetState());
// if the future is already completed it is additionally invoked right away,
// on the calling thread and while mutex_ is held.
void AsyncTaskFuture::SetCallback(
    std::function<void(const AsyncTaskFuture*)> callback) {
  std::lock_guard<std::mutex> guard(mutex_);

  callbacks_.push_back(callback);
  if (!completed_) {
    return;
  }
  callback(this);
}

// Marks the future as completed, runs every registered callback and wakes
// all threads blocked in Wait().
//
// err_msg: when non-null the future is completed as failed and the text is
//          stored as its error message; when null it completes successfully.
//
// Enforces (via CAFFE_ENFORCE) that the future is not already completed.
// Callbacks run on the calling thread while mutex_ is held, so they must not
// call back into a locking member of this same future.
void AsyncTaskFuture::SetCompleted(const char* err_msg) {
  std::unique_lock<std::mutex> lock(mutex_);

  CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");

  // Publish the failure details BEFORE raising completed_. IsCompleted(),
  // IsFailed() and ErrorMessage() read their fields without taking mutex_,
  // so a thread polling IsCompleted() must never observe "completed" while
  // failed_ / err_msg_ are still being written.
  // NOTE(review): this ordering only protects lock-free readers if
  // completed_ is an atomic - confirm against the class declaration.
  if (err_msg) {
    err_msg_ = err_msg;
    failed_ = true;
  }
  completed_ = true;

  for (auto& callback : callbacks_) {
    callback(this);
  }

  cv_completed_.notify_all();
}

// Returns a completed future to its initial state so it can be reused:
// the parent counter (if any) is reset and the completed flag, failed flag
// and error message are cleared. Registered callbacks are intentionally
// kept so that the task graph structure is preserved.
void AsyncTaskFuture::ResetState() {
  std::lock_guard<std::mutex> guard(mutex_);

  if (parent_counter_ != nullptr) {
    parent_counter_->Reset();
  }

  completed_ = false;
  failed_ = false;
  err_msg_.clear();
}

// Nothing to release by hand: the destructor body is empty and all cleanup
// is left to the members' own destructors. Defaulted out of line so the
// definition stays in this translation unit.
AsyncTaskFuture::~AsyncTaskFuture() = default;

} // namespace caffe2