1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/driver_api.h>
#include <cuda_runtime.h>
#include <nvml.h>
namespace {
// Upper bound on the NVLink link index probed for each device. The probe loop
// in NVLinkDetector::detect() stops earlier for a device as soon as NVML
// reports an error for a link index.
constexpr int max_nvlinks = 64;
std::string get_bus_id(int device_idx) {
// NOLINTNEXTLINE(*array*)
char bus_id[80];
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx));
snprintf(
bus_id,
sizeof(bus_id),
NVML_DEVICE_PCI_BUS_ID_FMT,
prop.pciDomainID,
prop.pciBusID,
prop.pciDeviceID);
return std::string(bus_id);
}
// DMAConnectivityDetector implementation that reports NVLink connectivity
// between the CUDA devices returned by cudaGetDeviceCount.
struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector {
  // Builds a num_devices x num_devices matrix in which matrix[i][j] is the
  // number of NVLinks found from device i to device j (direct GPU-to-GPU
  // links plus links assumed to be reachable through NVSwitches), and wraps
  // it in a DMAConnectivity of type "nvlink".
  //
  // If NVML cannot be initialized or queried, a warning is logged and the
  // matrix is returned in its current state (all zeros unless the failure
  // happens in the middle of the link scan).
  c10::intrusive_ptr<c10d::DMAConnectivity> detect() override {
    int num_devices = 0;
    C10_CUDA_CHECK(cudaGetDeviceCount(&num_devices));

    std::vector<std::vector<int>> matrix;
    matrix.reserve(num_devices);
    for (int i = 0; i < num_devices; ++i) {
      matrix.emplace_back(num_devices, 0);
    }

    // Single place that turns the matrix into the returned object; used by
    // the success path and by every early-exit path below.
    auto make_result = [&matrix]() {
      return c10::make_intrusive<c10d::DMAConnectivity>(
          c10::DeviceType::CUDA, "nvlink", std::move(matrix));
    };

    // Obtain the bus_id for all visible devices
    std::unordered_map<std::string, int> bus_id_to_device_idx;
    std::vector<std::string> bus_ids;
    bus_ids.reserve(num_devices);
    for (int i = 0; i < num_devices; ++i) {
      auto bus_id = get_bus_id(i);
      bus_id_to_device_idx.emplace(bus_id, i);
      bus_ids.push_back(std::move(bus_id));
    }

    static const char* warning_msg =
        "PyTorch features that use NVLinkDetector may assume no NVLink presence.";

    // NOTE(review): no matching NVML shutdown call is made in this function.
    auto driver_api = c10::cuda::DriverAPI::get();
    if (driver_api->nvmlInit_v2_() != NVML_SUCCESS) {
      LOG(WARNING)
          << "NVLinkDetector: Failed to initialize NVML via nvmlInit_v2. "
          << warning_msg;
      return make_result();
    }

    // Obtain the nvml device for all bus_ids
    std::vector<nvmlDevice_t> nvml_devices(num_devices, nullptr);
    for (int i = 0; i < num_devices; ++i) {
      auto res = driver_api->nvmlDeviceGetHandleByPciBusId_v2_(
          bus_ids[i].c_str(), &nvml_devices[i]);
      if (res != NVML_SUCCESS) {
        LOG(WARNING) << "NVLinkDetector: Failed to obtain NVML device via "
                     << "nvmlDeviceGetHandleByPciBusId_v2. " << warning_msg;
        return make_result();
      }
    }

    // Per-device count of links whose remote end is an NVSwitch.
    std::vector<int> switch_link_count(num_devices, 0);
    for (int i = 0; i < num_devices; ++i) {
      for (int link = 0; link < max_nvlinks; ++link) {
        nvmlIntNvLinkDeviceType_t deviceType{};
        auto ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_(
            nvml_devices[i], link, &deviceType);
        if (ret != NVML_SUCCESS) {
          // We've exhausted the NVLinks connected to this device. This error
          // is benign. There doesn't seem to be a reliable way to obtain the
          // maximum link value that can be passed to the API. Therefore, we
          // simply increment the link value until the API fails or we reach a
          // predefined maximum value.
          break;
        }
        if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) {
          // Remote device is GPU
          nvmlPciInfo_t pciInfo{};
          auto res = driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_(
              nvml_devices[i], link, &pciInfo);
          if (res != NVML_SUCCESS) {
            // NOTE(review): the matrix may already hold counts for earlier
            // devices/links at this point, so the returned connectivity can
            // be partially filled -- confirm this is what callers expect.
            LOG(WARNING)
                << "NVLinkDetector: Failed to obtain remote PCI info via "
                << "nvmlDeviceGetNvLinkRemotePciInfo_v2. " << warning_msg;
            return make_result();
          }
          // Only count links whose remote GPU is one of the visible devices,
          // and never count a device as linked to itself.
          auto it = bus_id_to_device_idx.find(pciInfo.busId);
          if (it != bus_id_to_device_idx.end()) {
            if (i != it->second) {
              matrix[i][it->second] += 1;
            }
          }
        } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) {
          // Remote device is NVSwitch
          switch_link_count[i] += 1;
        }
      }
    }

    // Process NVSwitch connections.
    // For simplicity, we assume that all NVSwitches are interconnected.
    for (int i = 0; i < num_devices; ++i) {
      for (int j = 0; j < num_devices; ++j) {
        if (i == j) {
          continue;
        }
        matrix[i][j] += std::min(switch_link_count[i], switch_link_count[j]);
      }
    }

    return make_result();
  }
};
// Helper whose constructor registers an NVLinkDetector as the "nvlink"
// DMA connectivity detector for CUDA devices.
struct RegisterDetector {
  RegisterDetector() {
    auto detector = c10::make_intrusive<NVLinkDetector>();
    register_dma_connectivity_detector(
        c10::DeviceType::CUDA, "nvlink", std::move(detector));
  }
};

// File-local instance; constructing it during static initialization performs
// the registration above.
static RegisterDetector register_detector_;
} // namespace
#endif
|