File: RankLocal.hpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (73 lines) | stat: -rw-r--r-- 2,287 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

#pragma once

#include <cstdint>
#include <shared_mutex>
#include <unordered_map>

#include <torch/csrc/autograd/function.h>

namespace c10d {

// `RankLocal` maintains a unique instance of T for each non-autograd thread.
// For non-autograd threads, `RankLocal<T>::get()` functions similar to
// thread_local. For autograd threads, `RankLocal<T>::get()` returns the
// instance of T corresponding to the enqueuing non-autograd thread. The
// mechanism allows for rank-specific context shared between forward and
// backward. It works for both the one-rank-per-process and one-rank-per-thread
// scenarios.
//
// NOTE: RankLocal doesn't make the underlying objects thread-safe.
template <typename T>
class RankLocal {
 public:
  // Non-copyable: there is exactly one per-rank registry per T.
  RankLocal(const RankLocal&) = delete;
  RankLocal& operator=(const RankLocal&) = delete;

  // Returns the instance of T for the calling rank, default-constructing it
  // on first access. For autograd threads, the owning rank is identified by
  // the thread id of the forward-pass thread that enqueued the current node.
  static T& get() {
    // Fast path: non-autograd threads can simply return
    // the object reference cached in TLS.
    if (cached_ != nullptr) {
      return *cached_;
    }
    const auto node = torch::autograd::get_current_node();
    // In an autograd thread, key by the forward thread that created the node;
    // otherwise key by this thread's own id.
    auto fwd_thread_id = node == nullptr ? at::RecordFunction::currentThreadId()
                                         : node->thread_id();
    // Optimistically acquire the read lock first, since most likely we are in
    // an autograd thread and the object has already been constructed.
    {
      std::shared_lock read_lock(lock_);
      auto it = thread_id_to_rank_local_.find(fwd_thread_id);
      if (it != thread_id_to_rank_local_.end()) {
        // Cache for non-autograd threads
        if (node == nullptr) {
          cached_ = &it->second;
        }
        return it->second;
      }
    }

    // Slow path: first access for this thread id. try_emplace re-checks under
    // the write lock, since another thread may have inserted the entry between
    // releasing the read lock and acquiring the write lock.
    std::unique_lock write_lock(lock_);
    auto [it, _] = thread_id_to_rank_local_.try_emplace(fwd_thread_id);
    // Cache for non-autograd threads. Caching a raw pointer into the map is
    // safe: std::unordered_map never invalidates references to elements on
    // rehash, and entries are never erased.
    if (node == nullptr) {
      cached_ = &it->second;
    }
    return it->second;
  }

 private:
  RankLocal() = default;
  // Per-thread fast-path cache; only populated for non-autograd threads.
  thread_local static T* cached_;
  // Maps a forward-pass (non-autograd) thread id to its instance of T.
  static std::unordered_map<uint64_t, T> thread_id_to_rank_local_;
  // Guards thread_id_to_rank_local_; reads (lookups) dominate writes.
  static std::shared_mutex lock_;
};

// Out-of-line definitions for RankLocal<T>'s static data members.
template <typename T>
thread_local T* RankLocal<T>::cached_ = nullptr;

template <typename T>
std::unordered_map<uint64_t, T> RankLocal<T>::thread_id_to_rank_local_;

template <typename T>
std::shared_mutex RankLocal<T>::lock_;

} // namespace c10d