1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
#pragma once
#include <chrono>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/queue/blobs_queue.h"
namespace caffe2 {
namespace db {
namespace {
// Extracts a string payload from a dequeued blob.
//
// Accepts either a blob holding a `string` directly, or a blob holding a
// Tensor of strings (in which case the first element is returned). Any other
// blob type is rejected with an exception. The returned reference points into
// the blob, so it is only valid while the blob is alive and unchanged.
const std::string& GetStringFromBlob(Blob* blob) {
  if (blob->IsType<string>()) {
    return blob->Get<string>();
  }
  if (blob->IsType<Tensor>()) {
    const auto& tensor = blob->Get<Tensor>();
    // NOTE(review): only the first element is read, and the tensor is not
    // checked for being non-empty before dereferencing — confirm producers
    // always enqueue at least one string.
    return *tensor.data<string>();
  }
  CAFFE_THROW("Unsupported Blob type");
}
} // namespace
class BlobsQueueDBCursor : public Cursor {
public:
explicit BlobsQueueDBCursor(
std::shared_ptr<BlobsQueue> queue,
int key_blob_index,
int value_blob_index,
float timeout_secs)
: queue_(queue),
key_blob_index_(key_blob_index),
value_blob_index_(value_blob_index),
timeout_secs_(timeout_secs),
inited_(false),
valid_(false) {
LOG(INFO) << "BlobsQueueDBCursor constructed";
CAFFE_ENFORCE(queue_ != nullptr, "queue is null");
CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0");
}
virtual ~BlobsQueueDBCursor() {}
void Seek(const string& /* unused */) override {
CAFFE_THROW("Seek is not supported.");
}
bool SupportsSeek() override {
return false;
}
void SeekToFirst() override {
// not applicable
}
void Next() override {
unique_ptr<Blob> blob = make_unique<Blob>();
vector<Blob*> blob_vector{blob.get()};
auto success = queue_->blockingRead(blob_vector, timeout_secs_);
if (!success) {
LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed";
valid_ = false;
return;
}
if (key_blob_index_ >= 0) {
key_ = GetStringFromBlob(blob_vector[key_blob_index_]);
}
value_ = GetStringFromBlob(blob_vector[value_blob_index_]);
valid_ = true;
}
string key() override {
if (!inited_) {
Next();
inited_ = true;
}
return key_;
}
string value() override {
if (!inited_) {
Next();
inited_ = true;
}
return value_;
}
bool Valid() override {
return valid_;
}
private:
std::shared_ptr<BlobsQueue> queue_;
int key_blob_index_;
int value_blob_index_;
float timeout_secs_;
bool inited_;
string key_;
string value_;
bool valid_;
};
class BlobsQueueDB : public DB {
public:
BlobsQueueDB(
const string& source,
Mode mode,
std::shared_ptr<BlobsQueue> queue,
int key_blob_index = -1,
int value_blob_index = 0,
float timeout_secs = 0.0)
: DB(source, mode),
queue_(queue),
key_blob_index_(key_blob_index),
value_blob_index_(value_blob_index),
timeout_secs_(timeout_secs) {
LOG(INFO) << "BlobsQueueDB constructed";
}
virtual ~BlobsQueueDB() {
Close();
}
void Close() override {}
unique_ptr<Cursor> NewCursor() override {
return make_unique<BlobsQueueDBCursor>(
queue_, key_blob_index_, value_blob_index_, timeout_secs_);
}
unique_ptr<Transaction> NewTransaction() override {
CAFFE_THROW("Not implemented.");
}
private:
std::shared_ptr<BlobsQueue> queue_;
int key_blob_index_;
int value_blob_index_;
float timeout_secs_;
};
} // namespace db
} // namespace caffe2
|