File: blobs_queue.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (176 lines) | stat: -rw-r--r-- 5,911 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#include "caffe2/queue/blobs_queue.h"

#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>

#include "caffe2/core/blob_stats.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/timer.h"
#include "caffe2/core/workspace.h"

#include <c10/util/irange.h>

namespace caffe2 {

// Argument values handed to the CAFFE_SDT user tracepoints in this file.
// The *_OP values tag the "start" probes with the kind of operation
// (non-blocking vs. blocking). SDT_TIMEOUT / SDT_ABORT / SDT_CANCEL are
// sentinels passed to the "end" probes when the operation did not complete.
C10_UNUSED static constexpr int SDT_NONBLOCKING_OP = 0;
C10_UNUSED static constexpr int SDT_BLOCKING_OP = 1;
C10_UNUSED static constexpr uint64_t SDT_TIMEOUT = static_cast<uint64_t>(-1);
C10_UNUSED static constexpr uint64_t SDT_ABORT = static_cast<uint64_t>(-2);
C10_UNUSED static constexpr uint64_t SDT_CANCEL = static_cast<uint64_t>(-3);

// Constructs a queue with `capacity` slots, each holding `numBlobs` blobs.
//
// Every slot blob is created through `ws->CreateBlob()` under the name
// "<queueName>_<slotIndex>_<blobIndex>".
//
// Parameters:
//   ws                - workspace in which the slot blobs are created.
//   queueName         - queue name; also the prefix of every slot blob name
//                       and the name handed to the stats object.
//   capacity          - number of slots in the queue.
//   numBlobs          - number of blobs per slot.
//   enforceUniqueName - when true, construction fails (CAFFE_ENFORCE) if a
//                       blob with a generated name already exists in `ws`.
//   fieldNames        - optional per-blob labels for the dequeued-bytes stat;
//                       when non-empty its size must equal `numBlobs`.
BlobsQueue::BlobsQueue(
    Workspace* ws,
    const std::string& queueName,
    size_t capacity,
    size_t numBlobs,
    bool enforceUniqueName,
    const std::vector<std::string>& fieldNames)
    : numBlobs_(numBlobs), name_(queueName), stats_(queueName) {
  if (!fieldNames.empty()) {
    CAFFE_ENFORCE_EQ(
        fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
    stats_.queue_dequeued_bytes.setDetails(fieldNames);
  }
  queue_.reserve(capacity);
  for (size_t i = 0; i < capacity; ++i) {
    // Build the slot directly inside queue_ instead of filling a local vector
    // and copying it in afterwards; this saves one allocation + copy per slot.
    queue_.emplace_back();
    auto& slot = queue_.back();
    slot.reserve(numBlobs);
    // The "<queueName>_<i>_" part is the same for every blob of this slot.
    const auto slotPrefix = queueName + "_" + to_string(i) + "_";
    for (size_t j = 0; j < numBlobs; ++j) {
      const auto blobName = slotPrefix + to_string(j);
      if (enforceUniqueName) {
        CAFFE_ENFORCE(
            !ws->GetBlob(blobName),
            "Queue internal blob already exists: ",
            blobName);
      }
      slot.push_back(ws->CreateBlob(blobName));
    }
  }
  TORCH_DCHECK_EQ(queue_.size(), capacity);
}

// Dequeues one record, blocking until a record is available, the queue is
// closing, or (only when `timeout_secs` > 0) the timeout has elapsed.
//
// On success the contents of the slot at the read position are swapped into
// `inputs` blob by blob, the read position advances, and all waiters on the
// condition variable are notified.
//
// Parameters:
//   inputs       - destination blobs; must contain at least as many entries
//                  as a queue slot (enforced).
//   timeout_secs - maximum wait in seconds; a value <= 0 waits with no
//                  timeout.
// Returns true if a record was dequeued, false if nothing was readable once
// the wait ended (timeout or close).
bool BlobsQueue::blockingRead(
    const std::vector<Blob*>& inputs,
    float timeout_secs) {
  Timer readTimer;
  // Hold a shared reference to this queue for the duration of the call.
  auto keeper = this->shared_from_this();
  C10_UNUSED const auto& name = name_.c_str();
  CAFFE_SDT(queue_read_start, name, (void*)this, SDT_BLOCKING_OP);
  std::unique_lock<std::mutex> lock(mutex_);
  const auto canRead = [this]() {
    CAFFE_ENFORCE_LE(reader_, writer_);
    return reader_ != writer_;
  };
  // Wake-up condition shared by the timed and the untimed wait. closing_ is
  // tested first, so the invariant check in canRead() is skipped on close.
  const auto wakeUp = [this, &canRead]() { return closing_ || canRead(); };
  // Decrease queue balance before reading to indicate queue read pressure
  // is being increased (-ve queue balance indicates more reads than writes)
  CAFFE_EVENT(stats_, queue_balance, -1);
  const bool hasTimeout = timeout_secs > 0;
  if (hasTimeout) {
    const std::chrono::milliseconds timeout_ms(int(timeout_secs * 1000));
    cv_.wait_for(lock, timeout_ms, wakeUp);
  } else {
    cv_.wait(lock, wakeUp);
  }
  if (!canRead()) {
    // Nothing to read: either the timed wait expired or the queue is closing.
    const bool timedOut = hasTimeout && !closing_;
    if (timedOut) {
      LOG(ERROR) << "DequeueBlobs timed out in " << timeout_secs << " secs";
      CAFFE_SDT(queue_read_end, name, (void*)this, SDT_TIMEOUT);
    } else {
      CAFFE_SDT(queue_read_end, name, (void*)this, SDT_CANCEL);
    }
    return false;
  }
  DCHECK(canRead());
  using std::swap;
  auto& result = queue_[reader_ % queue_.size()];
  CAFFE_ENFORCE(inputs.size() >= result.size());
  for (const auto i : c10::irange(result.size())) {
    // Record the size of the blob being handed out, then exchange contents
    // with the caller's blob.
    auto bytes = BlobStat::sizeBytes(*result[i]);
    CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
    swap(*(inputs[i]), *(result[i]));
  }
  // Reported before reader_ advances: writer_ - reader_ still counts the
  // record that was just handed out.
  CAFFE_SDT(queue_read_end, name, (void*)this, writer_ - reader_);
  CAFFE_EVENT(stats_, queue_dequeued_records);
  ++reader_;
  cv_.notify_all();
  CAFFE_EVENT(stats_, read_time_ns, readTimer.NanoSeconds());
  return true;
}

// Non-blocking enqueue: writes `inputs` into the queue only if a slot is
// free at the moment the lock is acquired.
//
// Returns true if the record was written, false (after emitting the
// SDT_ABORT tracepoint) if the queue was full.
bool BlobsQueue::tryWrite(const std::vector<Blob*>& inputs) {
  Timer writeTimer;
  // Hold a shared reference to this queue for the duration of the call.
  auto keeper = this->shared_from_this();
  C10_UNUSED const auto& name = name_.c_str();
  CAFFE_SDT(queue_write_start, name, (void*)this, SDT_NONBLOCKING_OP);
  std::unique_lock<std::mutex> lock(mutex_);
  const bool hasRoom = canWrite();
  if (hasRoom) {
    // Increase queue balance before writing to indicate queue write pressure
    // is being increased (+ve queue balance indicates more writes than reads)
    CAFFE_EVENT(stats_, queue_balance, 1);
    DCHECK(canWrite());
    doWrite(inputs);
    CAFFE_EVENT(stats_, write_time_ns, writeTimer.NanoSeconds());
    return true;
  }
  // Queue is full; report the aborted write and leave the queue untouched.
  CAFFE_SDT(queue_write_end, name, (void*)this, SDT_ABORT);
  return false;
}

// Blocking enqueue: waits until a slot is free or the queue is closing, then
// writes `inputs` if a slot is free.
//
// Returns true if the record was written, false (after emitting the
// SDT_ABORT tracepoint) if the wait ended with no free slot, which happens
// when the queue is closing while full.
bool BlobsQueue::blockingWrite(const std::vector<Blob*>& inputs) {
  Timer writeTimer;
  // Hold a shared reference to this queue for the duration of the call.
  auto keeper = this->shared_from_this();
  C10_UNUSED const auto& name = name_.c_str();
  CAFFE_SDT(queue_write_start, name, (void*)this, SDT_BLOCKING_OP);
  std::unique_lock<std::mutex> lock(mutex_);
  // Increase queue balance before writing to indicate queue write pressure is
  // being increased (+ve queue balance indicates more writes than reads)
  CAFFE_EVENT(stats_, queue_balance, 1);
  const auto wakeUp = [this]() { return closing_ || canWrite(); };
  cv_.wait(lock, wakeUp);
  const bool hasRoom = canWrite();
  if (hasRoom) {
    DCHECK(canWrite());
    doWrite(inputs);
    CAFFE_EVENT(stats_, write_time_ns, writeTimer.NanoSeconds());
    return true;
  }
  // Woken up by close() while the queue is still full.
  CAFFE_SDT(queue_write_end, name, (void*)this, SDT_ABORT);
  return false;
}

// Marks the queue as closing and wakes every thread waiting on the condition
// variable, so blocked blockingRead()/blockingWrite() calls re-evaluate their
// wait predicates (which test closing_).
void BlobsQueue::close() {
  closing_ = true;

  {
    // Notify while holding the mutex that the waiters use.
    std::lock_guard<std::mutex> lock(mutex_);
    cv_.notify_all();
  }
}

// Reports whether at least one slot is free for writing.
//
// Also enforces the ring invariants reader <= writer <= reader + size; a
// violated invariant fails the corresponding CAFFE_ENFORCE.
//
// Returns true when the queue is not full, false when writer has reached
// reader + size.
bool BlobsQueue::canWrite() {
  // writer is always within [reader, reader + size]
  // we can write if writer is within [reader, reader + size)
  CAFFE_ENFORCE_LE(reader_, writer_);
  CAFFE_ENFORCE_LE(writer_, static_cast<int64_t>(reader_ + queue_.size()));
  // Compare in the signed domain, matching the enforce above, so no
  // signed/unsigned comparison (and no lint suppression) is needed.
  return writer_ != static_cast<int64_t>(reader_ + queue_.size());
}

// Moves one record into the slot at the write position by swapping the
// contents of each slot blob with the matching blob in `inputs`, then
// advances the write position and notifies all waiters.
//
// `inputs` must contain at least as many entries as a queue slot (enforced).
// Both callers in this file invoke this with mutex_ held and after canWrite()
// returned true; this function does not re-check either condition.
void BlobsQueue::doWrite(const std::vector<Blob*>& inputs) {
  using std::swap;
  auto& result = queue_[writer_ % queue_.size()];
  CAFFE_ENFORCE(inputs.size() >= result.size());
  for (const auto i : c10::irange(result.size())) {
    swap(*(inputs[i]), *(result[i]));
  }
  C10_UNUSED const auto& name = name_.c_str();
  // Reported before writer_ advances: the number of free slots including the
  // one just filled.
  CAFFE_SDT(
      queue_write_end, name, (void*)this, reader_ + queue_.size() - writer_);
  ++writer_;
  cv_.notify_all();
}

} // namespace caffe2