File: DMAConnectivity.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (94 lines) | stat: -rw-r--r-- 2,528 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
#include <utility>

namespace {

// Builds the registry key "<device_type>/<connection_type>". The device type
// is rendered through its stream-insertion operator.
std::string get_detector_key(
    c10::DeviceType device_type,
    const std::string& connection_type) {
  std::ostringstream key_stream;
  key_stream << device_type;
  key_stream << '/';
  key_stream << connection_type;
  return key_stream.str();
}

class DetectorMap {
 public:
  DetectorMap(const DetectorMap&) = delete;
  DetectorMap& operator=(const DetectorMap&) = delete;
  static DetectorMap& get() {
    static DetectorMap instance;
    return instance;
  }

  void register_detector(
      c10::DeviceType device_type,
      const std::string& connection_type,
      c10::intrusive_ptr<c10d::DMAConnectivityDetector> detector) {
    auto key = get_detector_key(device_type, connection_type);
    detector_map_[key] = std::move(detector);
  }

  c10::intrusive_ptr<c10d::DMAConnectivity> detect(
      c10::DeviceType device_type,
      const std::string& connection_type) {
    auto key = get_detector_key(device_type, connection_type);
    {
      auto it = cached_.find(key);
      if (it != cached_.end()) {
        return it->second;
      }
    }

    auto it = detector_map_.find(key);
    TORCH_CHECK(
        it != detector_map_.end(),
        "DMA connectivity detector for ",
        device_type,
        " over ",
        connection_type,
        " is not available");
    auto detector = it->second;
    auto connectivity = detector->detect();
    cached_[key] = connectivity;
    return connectivity;
  }

 private:
  DetectorMap() = default;

  std::unordered_map<
      std::string,
      c10::intrusive_ptr<c10d::DMAConnectivityDetector>>
      detector_map_;

  std::unordered_map<std::string, c10::intrusive_ptr<c10d::DMAConnectivity>>
      cached_;
};

} // namespace

namespace c10d {

// Records which device type and connection type this connectivity describes,
// along with its connectivity matrix. The string and matrix arguments are
// taken by value and moved into the members.
DMAConnectivity::DMAConnectivity(
    c10::DeviceType type,
    std::string conn_type,
    std::vector<std::vector<int>> connectivity_matrix)
    : device_type(type),
      connection_type(std::move(conn_type)),
      matrix(std::move(connectivity_matrix)) {}

// Registers `detector` in the process-wide detector registry under
// (device_type, connection_type). Any detector previously registered for the
// same pair is replaced.
void register_dma_connectivity_detector(
    c10::DeviceType device_type,
    const std::string& connection_type,
    c10::intrusive_ptr<DMAConnectivityDetector> detector) {
  auto& registry = DetectorMap::get();
  registry.register_detector(device_type, connection_type, std::move(detector));
}

// Returns the connectivity reported by the detector registered for
// (device_type, connection_type). The result may come from the registry's
// cache. Fails via TORCH_CHECK inside the registry when no such detector has
// been registered.
c10::intrusive_ptr<DMAConnectivity> detect_dma_connectivity(
    c10::DeviceType device_type,
    const std::string& connection_type) {
  auto& registry = DetectorMap::get();
  auto connectivity = registry.detect(device_type, connection_type);
  return connectivity;
}

} // namespace c10d