File: CudaIPCTypes.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (265 lines) | stat: -rw-r--r-- 8,794 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#include <ATen/MapAllocator.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <map>
#include <mutex>
#include <random>
#include <string>

namespace torch {

namespace {

// Logs a warning that the producer process is going away while shared CUDA
// tensors are still outstanding. A function-local static flag suppresses every
// call after the first one that logged.
void warnProducerTerminatedBeforeSharedTensorsReleased() {
  static bool already_warned = false;
  if (already_warned) {
    return;
  }
  LOG(WARNING)
      << "Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]";
  already_warned = true;
}

// Bookkeeping shared by all CUDA IPC helpers in this file. It is used as a
// singleton through the file-level `cuda_ipc_global_entities` instance.
struct CudaIPCGlobalEntities {
  // Lifetime marker for the singleton: set to true by the constructor and
  // back to false at the end of the destructor body. Callers check it before
  // touching the instance so that they never access it after destruction
  // (which would lead to segmentation faults). A trivial type is used on
  // purpose because it does not suffer from construction and destruction
  // order issues.
  static bool alive;

  // Guards ref_counters_files_ and next_available_ref_counters_file_ in
  // safe_clean_current_file().
  std::mutex ref_counters_mutex_;
  // Number of interprocess sync events currently accounted for.
  std::atomic<int64_t> sync_events_used_{0};
  // Ref-counter files keyed by their handle.
  std::map<std::string, std::shared_ptr<CudaIPCRefCountersFile>>
      ref_counters_files_;
  // File from which the next ref-counter slot is handed out; may be empty.
  std::shared_ptr<CudaIPCRefCountersFile> next_available_ref_counters_file_;
  // Sent blocks whose counter was still positive when they were deleted.
  CudaIPCSentDataLimbo CudaIPCSentDataLimbo_;

  CudaIPCGlobalEntities() {
    alive = true;
  }

  // Collects the limbo, drops the current ref-counter file if none of its
  // offsets are in use, warns if a current file is still set afterwards, and
  // finally clears `alive` before the members themselves are destroyed.
  ~CudaIPCGlobalEntities() {
    CudaIPCSentDataLimbo_.collect();
    safe_clean_current_file();
    if (next_available_ref_counters_file_) {
      warnProducerTerminatedBeforeSharedTensorsReleased();
    }
    alive = false;
  }

  // Under ref_counters_mutex_: if a current ref-counter file is set and it
  // reports zero offsets in use, removes it from the map and clears the
  // current-file pointer. Otherwise does nothing.
  void safe_clean_current_file() {
    std::lock_guard<std::mutex> lock(ref_counters_mutex_);
    const auto& current_file = next_available_ref_counters_file_;
    if (!current_file || current_file->offsets_in_use() != 0) {
      return;
    }
    ref_counters_files_.erase(current_file->handle());
    next_available_ref_counters_file_.reset();
  }
};

// Definition of the singleton's lifetime flag. It starts out false and is set
// to true only by the CudaIPCGlobalEntities constructor.
bool CudaIPCGlobalEntities::alive = false;
// The single instance of the global bookkeeping used by the functions in this
// file. Check CudaIPCGlobalEntities::alive before using it from code that may
// run during static destruction.
CudaIPCGlobalEntities cuda_ipc_global_entities;

// Runs one last collection pass; if any blocks are still held afterwards,
// emits the "producer terminated" warning.
CudaIPCSentDataLimbo::~CudaIPCSentDataLimbo() {
  collect();
  const bool blocks_remaining = size() > 0;
  if (blocks_remaining) {
    warnProducerTerminatedBeforeSharedTensorsReleased();
  }
}

bool CudaIPCSentDataLimbo::collect() {
  bool freed_memory = false;
  std::vector<std::unique_ptr<CudaIPCSentData>> reset_blocks;
  { // Begin critical section to modify shared blocks
    std::lock_guard<std::mutex> lock(limbo_mutex_);
    std::vector<std::unique_ptr<CudaIPCSentData>> kept_blocks;
    for (auto& sd : shared_blocks_) {
      if (sd->counter_value() > 0) {
        kept_blocks.push_back(std::move(sd));
      } else {
        freed_memory = true;
        reset_blocks.push_back(std::move(sd));
      }
    }
    shared_blocks_ = std::move(kept_blocks);
  }
  // Need to reset blocks out of the critical section here, otherwise it
  // deadlocks.
  for (auto& sd : reset_blocks) {
    sd.reset();
  }
  return freed_memory;
}

// Takes ownership of `shared_block` and appends it to the limbo list under
// limbo_mutex_.
//
// The first time the list is found to hold more than
// CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO blocks, a warning is logged; a
// function-local static flag suppresses it afterwards. The size check is made
// before the new block is appended.
void CudaIPCSentDataLimbo::add(std::unique_ptr<CudaIPCSentData> shared_block) {
  std::lock_guard<std::mutex> lock(limbo_mutex_);
  static bool warned = false;
  // Test the cheap flag first so the size comparison is skipped once warned.
  if (!warned &&
      shared_blocks_.size() > CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO) {
    LOG(WARNING)
        << "Producer process tried to deallocate over "
        << CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO
        << " memory blocks referred by consumer processes. Deallocation might be significantly slowed down. "
        << "We assume it is never going to be the case, but if it is, please file a bug at https://github.com/pytorch/pytorch";
    warned = true;
  }
  shared_blocks_.push_back(std::move(shared_block));
}

uint64_t CudaIPCSentDataLimbo::size() {
  std::lock_guard<std::mutex> lock(limbo_mutex_);
  return shared_blocks_.size();
}

void CudaIPCSentDataDelete(void* ptr) {
  std::unique_ptr<CudaIPCSentData> sent_data(
      static_cast<CudaIPCSentData*>(ptr));
  if (!CudaIPCGlobalEntities::alive) {
    return;
  }
  if (sent_data->counter_value() > 0) {
    cuda_ipc_global_entities.CudaIPCSentDataLimbo_.add(std::move(sent_data));
  }
  cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
}

// Hands a ref-counter slot back to the file registered under `handle`.
//
// Does nothing if the global entities are no longer alive or if no file is
// registered under `handle`. `offset` is forwarded to return_offset(). After
// the slot is returned, the file is dropped from the map when it reports no
// offsets in use and no offsets left to hand out.
void ReturnRefCounter(const std::string& handle, uint64_t offset /* unused */) {
  if (!CudaIPCGlobalEntities::alive) {
    return;
  }
  std::lock_guard<std::mutex> guard(
      cuda_ipc_global_entities.ref_counters_mutex_);
  auto& files = cuda_ipc_global_entities.ref_counters_files_;
  const auto entry = files.find(handle);
  if (entry == files.end()) {
    return;
  }
  const auto& file = entry->second;
  file->return_offset(offset);
  const bool fully_drained =
      file->offsets_in_use() == 0 && !file->have_offsets();
  if (fully_drained) {
    // `file` refers into the erased node and must not be used after this.
    files.erase(entry);
  }
}

} // namespace

// Records the ref-counter location (handle, offset, counter pointer) and the
// device for a block being shared, then makes the current stream of that
// device safe to hand over: either by recording an interprocess event on it
// or by synchronizing it.
CudaIPCSentData::CudaIPCSentData(
    std::string handle,
    int64_t offset,
    int64_t* counter_ptr,
    at::Device device)
    : handle_(std::move(handle)),
      offset_(offset),
      counter_ptr_(counter_ptr),
      original_ptr_(),
      device_(device) {
#if !defined(USE_ROCM)
  // CUDA has an unofficial limit on the number of recorded blocking
  // interprocess events. To avoid using up all of them, we switch to a stream
  // synchronization before the limit is reached.
  //
  //  ```python
  //  import torch
  //  a = [ torch.cuda.Event(
  //      enable_timing=False, blocking=True, interprocess=True) for i in
  //      range(30000) ]
  //  [i.record() for i in a]
  //  ```
  //
  const bool event_budget_left =
      cuda_ipc_global_entities.sync_events_used_.load() <
      CUDA_IPC_MAXIMUM_EVENTS_TO_USE;
  if (!event_budget_left) {
    // Event budget exhausted: block on the stream instead of recording.
    auto stream = c10::cuda::getCurrentCUDAStream(device.index());
    at::cuda::stream_synchronize(stream);
    event_ = nullptr;
    event_sync_required_ = false;
  } else {
    // TODO: It would be more efficient to create the event inside the main
    // thread (at the moment of the queue.put), because the main thread may
    // have queued extra work on the stream, which this event will
    // consequently wait for (uselessly).
    cuda_ipc_global_entities.sync_events_used_++;
    C10_CUDA_CHECK(cudaEventCreateWithFlags(
        &event_,
        cudaEventDisableTiming | cudaEventInterprocess |
            cudaEventBlockingSync));
    C10_CUDA_CHECK(cudaEventRecord(
        event_, c10::cuda::getCurrentCUDAStream(device.index())));
    event_sync_required_ = true;
  }
#else
  // cuIpcGetEventHandle is not supported with HIP, so the stream is
  // synchronized here instead of passing an event.
  auto stream = c10::cuda::getCurrentCUDAStream(device.index());
  at::cuda::stream_synchronize(stream);
  event_sync_required_ = false;
#endif
}

// Returns this block's ref-counter slot and, on non-ROCm builds, destroys the
// sync event if one was created. The event counter in the global entities is
// decremented only while those entities are still alive.
CudaIPCSentData::~CudaIPCSentData() {
  ReturnRefCounter(handle_, offset_);
#if !defined(USE_ROCM)
  try {
    if (event_sync_required_) {
      // Switch to the owning device for the duration of the event teardown.
      at::cuda::CUDAGuard device_guard(device_.index());
      cudaEventDestroy(event_);
      if (CudaIPCGlobalEntities::alive) {
        cuda_ipc_global_entities.sync_events_used_--;
      }
    }
  } catch (...) {
    // Destructors must not throw: anything raised above is dropped.
  }
#endif
}

// Returns the value currently stored at counter_ptr_.
int64_t CudaIPCSentData::counter_value() {
  const int64_t current = *counter_ptr_;
  return current;
}

// Wraps `data` in a DataPtr whose context is a freshly allocated
// CudaIPCSentData and whose deleter is CudaIPCSentDataDelete.
//
// Steps:
//  1. Under ref_counters_mutex_, make sure a current ref-counter file exists,
//     creating and registering a new shared-memory backed one if needed.
//  2. Set the current slot's counter to 1 and build the CudaIPCSentData from
//     the file's handle, current offset and counter pointer.
//  3. Advance the file to its next offset; if it has no offsets left, stop
//     using it as the current file.
at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) {
  // Alias for the repeatedly used "current file" slot of the global entities.
  auto& next_file = cuda_ipc_global_entities.next_available_ref_counters_file_;
  {
    std::lock_guard<std::mutex> guard(
        cuda_ipc_global_entities.ref_counters_mutex_);
    if (!next_file) {
      std::string ref_counter_handle = at::NewProcessWideShmHandle();

      const int flags =
          at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
      at::DataPtr sptr = at::RefcountedMapAllocator::makeDataPtr(
          ref_counter_handle.c_str(),
          flags,
          sizeof(int64_t) * CUDA_IPC_REF_COUNTER_FILE_SIZE,
          nullptr);
      auto rc = std::make_shared<CudaIPCRefCountersFile>(
          ref_counter_handle, CUDA_IPC_REF_COUNTER_FILE_SIZE, std::move(sptr));
      cuda_ipc_global_entities.ref_counters_files_[ref_counter_handle] = rc;
      next_file = rc;
    }
  }

  next_file->set_counter(1);
  // Raw pointer on purpose: ownership is handed to the returned DataPtr,
  // which releases it through CudaIPCSentDataDelete.
  auto sent_data = new CudaIPCSentData(
      next_file->handle(),
      next_file->get_offset(),
      next_file->counter_ptr(),
      device);

  next_file->rotate_offset();
  if (!next_file->have_offsets()) {
    next_file.reset();
  }
  return at::DataPtr(data, sent_data, CudaIPCSentDataDelete, device);
}

// Runs a collection pass over the limbo and, if the limbo is empty
// afterwards, tries to drop the current ref-counter file.
// Returns whether the pass released any block; returns true without doing
// anything when the global entities are no longer alive.
bool CudaIPCCollect() {
  if (!CudaIPCGlobalEntities::alive) {
    return true;
  }
  auto& limbo = cuda_ipc_global_entities.CudaIPCSentDataLimbo_;
  const bool freed_memory = limbo.collect();
  const bool limbo_is_empty = limbo.size() == 0;
  if (limbo_is_empty) {
    cuda_ipc_global_entities.safe_clean_current_file();
  }
  return freed_memory;
}

} // namespace torch

namespace c10 {
namespace {
// Registers CudaIPCCollectCallback under the name "cuda_ipc_collect" via
// REGISTER_FREE_MEMORY_CALLBACK.
// NOTE(review): CudaIPCCollectCallback is not defined in this file; presumably
// it is declared in torch/csrc/CudaIPCTypes.h and forwards to
// torch::CudaIPCCollect() -- confirm against the header.
REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
}
} // namespace c10