File: HashStore.cpp

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (173 lines) | stat: -rw-r--r-- 4,563 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#include <torch/csrc/distributed/c10d/HashStore.hpp>

#include <unistd.h>
#include <cstdint>

#include <chrono>

#include <c10/util/Exception.h>

namespace c10d {

// Stores `data` under `key`, overwriting any previous value, then wakes every
// thread blocked on cv_ (get/wait/multiGet) so it can re-evaluate its predicate.
void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  std::unique_lock<std::mutex> guard(m_);
  auto& slot = map_[key];
  slot = data;
  cv_.notify_all();
}

// Compare-and-swap on a single key. The swap happens when the current state
// matches `expectedValue`, where an empty `expectedValue` matches a missing
// key. On success, stores `desiredValue`, wakes waiters and returns it.
// On failure, returns the current value, or `expectedValue` itself when the
// key does not exist.
std::vector<uint8_t> HashStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  std::unique_lock<std::mutex> guard(m_);
  const auto found = map_.find(key);
  const bool present = found != map_.end();

  if (!present) {
    if (!expectedValue.empty()) {
      // Key is missing but the caller expected some value: no swap.
      return expectedValue;
    }
  } else if (found->second != expectedValue) {
    // Key exists but holds something other than expected: report it.
    return found->second;
  }

  map_[key] = desiredValue;
  cv_.notify_all();
  return desiredValue;
}

std::vector<uint8_t> HashStore::get(const std::string& key) {
  std::unique_lock<std::mutex> lock(m_);
  auto it = map_.find(key);
  if (it != map_.end()) {
    return it->second;
  }
  // Slow path: wait up to any timeout_.
  auto pred = [&]() { return map_.find(key) != map_.end(); };
  if (timeout_ == kNoTimeout) {
    cv_.wait(lock, pred);
  } else {
    if (!cv_.wait_for(lock, timeout_, pred)) {
      C10_THROW_ERROR(DistStoreError, "Wait timeout");
    }
  }
  return map_[key];
}

void HashStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  const auto end = std::chrono::steady_clock::now() + timeout;
  auto pred = [&]() {
    auto done = true;
    for (const auto& key : keys) {
      if (map_.find(key) == map_.end()) {
        done = false;
        break;
      }
    }
    return done;
  };

  std::unique_lock<std::mutex> lock(m_);
  if (timeout == kNoTimeout) {
    cv_.wait(lock, pred);
  } else {
    if (!cv_.wait_until(lock, end, pred)) {
      C10_THROW_ERROR(DistStoreError, "Wait timeout");
    }
  }
}

int64_t HashStore::add(const std::string& key, int64_t i) {
  std::unique_lock<std::mutex> lock(m_);
  const auto& value = map_[key];
  int64_t ti = i;
  if (!value.empty()) {
    auto buf = reinterpret_cast<const char*>(value.data());
    auto len = value.size();
    ti += std::stoll(std::string(buf, len));
  }

  auto str = std::to_string(ti);
  const uint8_t* strB = reinterpret_cast<const uint8_t*>(str.c_str());
  map_[key] = std::vector<uint8_t>(strB, strB + str.size());
  return ti;
}

// Returns how many keys are currently stored, read under the store mutex.
int64_t HashStore::getNumKeys() {
  std::unique_lock<std::mutex> guard(m_);
  const auto count = map_.size();
  return static_cast<int64_t>(count);
}

// Removes `key` from the store. Returns true if an entry was removed and
// false if the key was not present. Waiters on cv_ are not notified.
bool HashStore::deleteKey(const std::string& key) {
  std::unique_lock<std::mutex> guard(m_);
  auto pos = map_.find(key);
  if (pos == map_.end()) {
    return false;
  }
  map_.erase(pos);
  return true;
}

// Non-blocking presence test. Returns true iff every key in `keys` is
// currently stored; this is vacuously true for an empty list.
bool HashStore::check(const std::vector<std::string>& keys) {
  std::unique_lock<std::mutex> guard(m_);
  bool allFound = true;
  for (auto it = keys.begin(); allFound && it != keys.end(); ++it) {
    allFound = map_.count(*it) != 0;
  }
  return allFound;
}

// Appends the bytes of `value` to whatever is stored under `key`. If the key
// does not exist yet, it is created with exactly `value`. Waiters on cv_ are
// then woken.
void HashStore::append(
    const std::string& key,
    const std::vector<uint8_t>& value) {
  std::unique_lock<std::mutex> guard(m_);
  // For a new key operator[] yields an empty buffer, so appending to it
  // produces exactly `value`.
  auto& buffer = map_[key];
  buffer.insert(buffer.end(), value.begin(), value.end());
  cv_.notify_all();
}

std::vector<std::vector<uint8_t>> HashStore::multiGet(
    const std::vector<std::string>& keys) {
  std::unique_lock<std::mutex> lock(m_);
  auto deadline = std::chrono::steady_clock::now() + timeout_;
  std::vector<std::vector<uint8_t>> res;
  res.reserve(keys.size());

  for (auto& key : keys) {
    auto it = map_.find(key);
    if (it != map_.end()) {
      res.emplace_back(it->second);
    } else {
      auto pred = [&]() { return map_.find(key) != map_.end(); };
      if (timeout_ == kNoTimeout) {
        cv_.wait(lock, pred);
      } else {
        if (!cv_.wait_until(lock, deadline, pred)) {
          C10_THROW_ERROR(DistStoreError, "Wait timeout");
        }
      }
      res.emplace_back(map_[key]);
    }
  }
  return res;
}

void HashStore::multiSet(
    const std::vector<std::string>& keys,
    const std::vector<std::vector<uint8_t>>& values) {
  std::unique_lock<std::mutex> lock(m_);

  for (auto i : ::c10::irange(keys.size())) {
    map_[keys[i]] = values[i];
  }
  cv_.notify_all();
}

// Unconditionally reports support for the Store "extended API".
// NOTE(review): presumably this refers to operations such as append(),
// multiGet() and multiSet() implemented in this file — confirm against the
// Store base class.
bool HashStore::hasExtendedApi() const {
  return true;
}

} // namespace c10d