From 6aea4d90fb1b147e8e244abdfb93153bc06ff6c7 Mon Sep 17 00:00:00 2001
From: Tristan Rice <tristanr@meta.com>
Date: Tue, 1 Apr 2025 23:37:25 +0000
Subject: [PATCH] gloo: use shared Stores (#150230)
Summary:
X-link: https://github.com/facebookincubator/gloo/pull/423
This modifies `connectFullMesh` to take a `shared_ptr<IStore>` instead of a reference. This is an API-breaking change, but it is fairly easy to work around.
To keep PyTorch backwards compatible during the commit phase, we add a new ifdef, `GLOO_SHARED_STORE`, which lets the code build against either API until we update the pinned Gloo version in the PyTorch OSS repo.
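As a rough illustration of the compatibility pattern, the sketch below shows caller code that compiles against either signature. `IStore`, `BaseStore`, `connectFullMesh`, and `connect` here are hypothetical stand-ins written for this example, not the real Gloo declarations; only the `GLOO_SHARED_STORE` macro name comes from this change.
```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for the Gloo types; not the real Gloo API.
struct IStore {
  virtual ~IStore() = default;
  virtual std::string name() const = 0;
};

struct BaseStore : IStore {
  std::string name() const override { return "base"; }
};

// Defined here to model a Gloo build that ships the shared-store API.
// Remove this line to model an older pinned Gloo.
#define GLOO_SHARED_STORE

#ifdef GLOO_SHARED_STORE
// New style: the callee shares ownership of the store.
std::string connectFullMesh(std::shared_ptr<IStore> store) {
  return "shared:" + store->name();
}
#else
// Old style: the callee borrows the store by reference.
std::string connectFullMesh(IStore& store) {
  return "ref:" + store.name();
}
#endif

// Caller code that builds against either signature.
std::string connect(const std::shared_ptr<IStore>& owned) {
  assert(owned != nullptr);
#ifdef GLOO_SHARED_STORE
  auto connectStore = owned;    // copy the shared_ptr
#else
  auto& connectStore = *owned;  // bind a reference to the pointee
#endif
  return connectFullMesh(connectStore);
}
```
With the macro defined, `connect(std::make_shared<BaseStore>())` takes the shared-ownership path. Without it, the same call site falls back to passing a reference.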
This also adds a new `wait_get` method to `IStore`, which allows a more efficient operation in PyTorch's TCPStore. PyTorch's `Store::get` already waits for the key, so a combined call avoids waiting twice and reduces network traffic.
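To make the saving concrete, here is a minimal sketch using a hypothetical in-memory `CountingStore` that counts simulated round trips. It is not the real Gloo `IStore` or PyTorch TCPStore, and the `wait_get` signature shown is an assumption made for illustration.
```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical store that counts simulated network round trips.
class CountingStore {
 public:
  void set(const std::string& key, const std::string& value) {
    data_[key] = value;
  }

  // Simulated wait: costs one round trip. A real store would block here
  // until the key appears; this sketch only checks that it is present.
  void wait(const std::string& key) {
    ++roundTrips_;
    assert(data_.count(key) == 1);
  }

  // Simulated get: costs one round trip.
  std::string get(const std::string& key) {
    ++roundTrips_;
    return data_.at(key);
  }

  // Combined wait and get (assumed signature): costs one round trip.
  std::string wait_get(const std::string& key) {
    ++roundTrips_;
    return data_.at(key);
  }

  int roundTrips() const { return roundTrips_; }

 private:
  std::map<std::string, std::string> data_;
  int roundTrips_ = 0;
};
```
In this model, calling `wait` followed by `get` costs two round trips, while a single `wait_get` costs one.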
This change will land simultaneously in PyTorch and Gloo repos.
Test Plan:
```
buck2 test //gloo/... //caffe2/caffe2/contrib/gloo:
```
Differential Revision: D72084111
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150230
Approved by: https://github.com/fduwjj
---
.../distributed/c10d/ProcessGroupGloo.cpp | 19 +++++++++++++++++--
.../distributed/c10d/ProcessGroupGloo.hpp | 4 ++--
2 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 345b2741dc97..3c5644eeab68 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -785,10 +785,25 @@ ProcessGroupGloo::ProcessGroupGloo(
contexts_.reserve(options_->devices.size());
for (const auto i : c10::irange(options_->devices.size())) {
auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
- auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
+
+#ifdef GLOO_SHARED_STORE
+ auto underlyingStore = store_;
+#else
+ auto& underlyingStore = *store_;
+#endif
+
+ auto store = std::make_shared<::gloo::rendezvous::PrefixStore>(
+ std::to_string(i), underlyingStore);
+
+#ifdef GLOO_SHARED_STORE
+ auto connectStore = store;
+#else
+ auto& connectStore = *store;
+#endif
+
context->setTimeout(options_->timeout);
try {
- context->connectFullMesh(store, options_->devices[i]);
+ context->connectFullMesh(connectStore, options_->devices[i]);
} catch (const std::runtime_error& e) {
auto err = e.what();
// TORCH_CHECK to print the cpp stacktrace.
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
index b44cba9f35a4..059ba8a4ee3f 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
@@ -367,7 +367,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
void enableCollectivesTiming() override;
- const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
+ const std::shared_ptr<::gloo::rendezvous::Store>& _getStore() const {
return store_;
}
@@ -393,7 +393,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
}
protected:
- std::unique_ptr<::gloo::rendezvous::Store> store_;
+ std::shared_ptr<::gloo::rendezvous::Store> store_;
const c10::intrusive_ptr<Options> options_;
// Every Gloo context represents a set of connections to its peers.