// Control-plane collectives (barrier, broadcast, gather, scatter, all-gather,
// all-sum) implemented on top of a c10d::Store.
#include <c10/util/Exception.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
#include <chrono>
#include <exception>
#include <vector>
namespace {
// Builds the per-rank store key "<key>/<rank>".
std::string getRankKey(const std::string& key, int rank) {
  std::string rankKey = key;
  rankKey += "/";
  rankKey += std::to_string(rank);
  return rankKey;
}
} // namespace
namespace c10d {
// Constructs a collectives helper bound to `store`.
//
// store:     backing key-value store; ownership of the intrusive_ptr is moved
//            into store_.
// rank:      this participant's rank, kept in rank_ and used to build the
//            per-rank keys.
// worldSize: total number of participating ranks, kept in worldSize_.
StoreCollectives::StoreCollectives(
c10::intrusive_ptr<::c10d::Store> store,
int rank,
int worldSize)
: store_(std::move(store)), rank_(rank), worldSize_(worldSize) {}
// Store-backed barrier.
//
// Every rank increments "<key>/num_members" and then writes "joined" under its
// own rank key "<key>/<rank>". The rank whose increment brings the counter to
// worldSize_ sets "<key>/last_members"; every other rank, when `blocking` is
// true, waits on that key. A non-blocking caller that is not the last arrival
// returns right after registering itself.
//
// key:      unique name of this barrier (checked by enforceUnique).
// timeout:  handed to StoreTimeoutGuard for the duration of the call
//           (presumably the store's operation timeout -- confirm against
//           StoreTimeoutGuard).
// blocking: whether to wait for the remaining ranks.
//
// Throws std::runtime_error if the wait fails; the message lists every other
// rank whose rank key is not present in the store, followed by the original
// exception text.
void StoreCollectives::barrier(
    const std::string& key,
    std::chrono::milliseconds timeout,
    bool blocking) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};

  auto arrivalCountKey = fmt::format("{}/num_members", key);
  auto releaseKey = fmt::format("{}/last_members", key);

  // Register arrival first, then publish the per-rank marker that the failure
  // path below inspects.
  auto arrivals = store_->add(arrivalCountKey, 1);
  store_->set(getRankKey(key, rank_), "joined");

  // Last rank to arrive releases everybody else.
  if (arrivals == worldSize_) {
    store_->set(releaseKey, "<val_ignored>");
    return;
  }

  if (!blocking) {
    return;
  }

  try {
    store_->wait({releaseKey});
  } catch (const std::exception& err) {
    // Report which peers never wrote their "joined" marker.
    std::string report = "barrier failed -- missing ranks: ";
    for (int peer = 0; peer < worldSize_; ++peer) {
      if (peer != rank_ && !store_->check({getRankKey(key, peer)})) {
        report += fmt::format("{}, ", peer);
      }
    }
    throw std::runtime_error(report + err.what());
  }
}
// Sender side of a broadcast: writes `data` to the store under `key`.
//
// key:     unique name of this broadcast (checked by enforceUnique).
// data:    payload to publish.
// timeout: handed to StoreTimeoutGuard for the duration of the call.
void StoreCollectives::broadcastSend(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);

  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  store_->set(key, data);
}
// Receiver side of a broadcast: reads the payload stored under `key`.
//
// key:     unique name of this broadcast (checked by enforceUnique).
// timeout: handed to StoreTimeoutGuard for the duration of the call.
//
// Returns whatever store_->get(key) yields.
// NOTE(review): presumably get() blocks until the key exists or the timeout
// elapses -- confirm against the Store implementation in use.
std::vector<uint8_t> StoreCollectives::broadcastRecv(
    const std::string& key,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);

  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  std::vector<uint8_t> payload = store_->get(key);
  return payload;
}
// Non-root side of a gather: writes `data` under this rank's key
// "<key>/<rank_>".
//
// key:     unique name of this gather (checked by enforceUnique).
// data:    this rank's contribution.
// timeout: handed to StoreTimeoutGuard for the duration of the call.
void StoreCollectives::gatherSend(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);

  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  store_->set(getRankKey(key, rank_), data);
}
// Root side of a gather: fetches the payload of every other rank from
// "<key>/<rank>" and returns them in rank order, with the caller's own `data`
// placed at index rank_.
//
// key:     unique name of this gather (checked by enforceUnique).
// data:    the calling rank's own contribution; it is not written to the
//          store, only spliced into the result.
// timeout: handed to StoreTimeoutGuard for the duration of the call.
//
// Throws std::runtime_error if the multiGet fails; the message lists every
// other rank whose key is not present in the store, followed by the original
// exception text.
std::vector<std::vector<uint8_t>> StoreCollectives::gatherRecv(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard g{*store_, timeout};

  // Keys of all peers, in ascending rank order, skipping ourselves.
  std::vector<std::string> keys;
  keys.reserve(worldSize_);
  for (int i = 0; i < worldSize_; i++) {
    if (i == rank_) {
      continue;
    }
    keys.emplace_back(getRankKey(key, i));
  }

  // No reserve() here: the assignment below replaces the vector's storage
  // with the one returned by multiGet, so any capacity reserved beforehand
  // would simply be thrown away.
  std::vector<std::vector<uint8_t>> results;
  try {
    results = store_->multiGet(keys);
  } catch (const std::exception& e) {
    std::string msg = "gather failed -- missing ranks: ";
    for (int i = 0; i < worldSize_; i++) {
      if (i == rank_) {
        continue;
      }
      auto rank_key = getRankKey(key, i);
      if (!store_->check({rank_key})) {
        msg += fmt::format("{}, ", i);
      }
    }
    throw std::runtime_error(msg + e.what());
  }

  // Splice the local contribution in at our own rank so that results[i]
  // corresponds to rank i.
  results.insert(results.begin() + rank_, data);
  return results;
}
std::vector<uint8_t> StoreCollectives::scatterSend(
const std::string& key,
const std::vector<std::vector<uint8_t>>& data,
std::chrono::milliseconds timeout) {
enforceUnique(key);
StoreTimeoutGuard g{*store_, timeout};
std::vector<std::string> keys;
keys.reserve(worldSize_);
for (int i = 0; i < worldSize_; i++) {
if (i == rank_) {
continue;
}
auto rank_key = getRankKey(key, i);
keys.emplace_back(rank_key);
}
auto local = data.at(rank_);
std::vector<std::vector<uint8_t>> toSend{data};
toSend.erase(toSend.begin() + rank_);
store_->multiSet(keys, toSend);
return local;
}
// Non-root side of a scatter: reads the payload stored under this rank's key
// "<key>/<rank_>".
//
// key:     unique name of this scatter (checked by enforceUnique).
// timeout: handed to StoreTimeoutGuard for the duration of the call.
//
// Returns whatever store_->get yields for the rank key.
std::vector<uint8_t> StoreCollectives::scatterRecv(
    const std::string& key,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);

  StoreTimeoutGuard timeoutGuard{*store_, timeout};
  return store_->get(getRankKey(key, rank_));
}
// All-gather: publishes `data` under this rank's key "<key>/<rank_>" and then
// fetches the entries of all worldSize_ ranks (its own included) in rank
// order.
//
// key:     unique name of this all-gather (checked by enforceUnique).
// data:    this rank's contribution.
// timeout: handed to StoreTimeoutGuard for the duration of the call.
//
// Throws std::runtime_error if the multiGet fails; the message lists every
// other rank whose key is not present in the store, followed by the original
// exception text.
std::vector<std::vector<uint8_t>> StoreCollectives::allGather(
    const std::string& key,
    const std::vector<uint8_t>& data,
    std::chrono::milliseconds timeout) {
  enforceUnique(key);
  StoreTimeoutGuard timeoutGuard{*store_, timeout};

  // Publish our own contribution before reading anybody's.
  store_->set(getRankKey(key, rank_), data);

  std::vector<std::string> allKeys;
  allKeys.reserve(worldSize_);
  for (int peer = 0; peer < worldSize_; ++peer) {
    allKeys.emplace_back(getRankKey(key, peer));
  }

  try {
    return store_->multiGet(allKeys);
  } catch (const std::exception& err) {
    // Our own key was written above, so only the peers are probed.
    std::string report = "all_gather failed -- missing ranks: ";
    for (int peer = 0; peer < worldSize_; ++peer) {
      if (peer != rank_ && !store_->check({getRankKey(key, peer)})) {
        report += fmt::format("{}, ", peer);
      }
    }
    throw std::runtime_error(report + err.what());
  }
}
int64_t StoreCollectives::allSum(
const std::string& key,
int64_t value,
std::chrono::milliseconds timeout) {
enforceUnique(key);
StoreTimeoutGuard g{*store_, timeout};
store_->add(key, value);
barrier(key + "/barrier", timeout);
return store_->add(key, 0);
}
// Asserts that `key` has not been passed to this function before on this
// object, then records it in seenKeys_.
//
// Every collective in this file calls this first, so reusing a key on the
// same StoreCollectives object trips the assertion.
void StoreCollectives::enforceUnique(const std::string& key) {
auto it = seenKeys_.find(key);
// NOTE: the condition expression is left untouched on purpose -- the assert
// macro presumably stringifies it into the failure message, so rewording the
// condition would change the emitted text.
TORCH_INTERNAL_ASSERT(
it == seenKeys_.end(), "Key ", key, " has already been used.");
seenKeys_.emplace(key);
}
} // namespace c10d