File: CUDAGuardImpl.h

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (262 lines) | stat: -rw-r--r-- 9,183 bytes | parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
#pragma once

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/PyInterpreter.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#include <optional>

namespace c10::cuda::impl {

struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    auto old_device_index = c10::cuda::ExchangeDevice(d.index());
    return Device(DeviceType::CUDA, old_device_index);
  }
  Device getDevice() const override {
    DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  std::optional<Device> uncheckedGetDevice() const noexcept {
    DeviceIndex device{-1};
    const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return std::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }

    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
  }

  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
  }

  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }

  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }

  // May be called from any device
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    // Note: cudaEventQuery can be safely called from any device
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }

  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }

  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    // Note: cudaEventSynchronize can be safely called from any device
    C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
  }

  // Note: synchronizeDevice can be safely called from any device
  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }

  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    // Even though cudaEventElapsedTime can be safely called from any device, if
    // the current device is not initialized, it will create a new cuda context,
    // which will consume a lot of memory.
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
    cudaEvent_t cuda_event1 = static_cast<cudaEvent_t>(event1);
    cudaEvent_t cuda_event2 = static_cast<cudaEvent_t>(event2);
    float time_ms = 0;
    // raise cudaErrorNotReady if either event is recorded but not yet completed
    C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2));
    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};

} // namespace c10::cuda::impl