#pragma once

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/PyInterpreter.h>

#include <cuda_runtime_api.h>

#include <cstdint>
#include <optional>

namespace c10::cuda::impl {
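// DeviceGuardImplInterface implementation for CUDA. Device and stream state
// changes go through the c10::cuda wrappers (GetDevice/SetDevice,
// getCurrentCUDAStream, ...), and events are raw cudaEvent_t handles passed
// around as void*.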
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = DeviceType::CUDA;

  CUDAGuardImpl() = default;
  explicit CUDAGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
  }
  DeviceType type() const override {
    return DeviceType::CUDA;
  }
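  // Sets the current device to d and returns the device that was previously
  // current.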
  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    auto old_device_index = c10::cuda::ExchangeDevice(d.index());
    return Device(DeviceType::CUDA, old_device_index);
  }
  Device getDevice() const override {
    DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    return Device(DeviceType::CUDA, device);
  }
  std::optional<Device> uncheckedGetDevice() const noexcept {
    DeviceIndex device{-1};
    const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device));
    C10_CUDA_CHECK_WARN(err);
    if (err != cudaSuccess) {
      return std::nullopt;
    }
    return Device(DeviceType::CUDA, device);
  }
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_cuda());
    C10_CUDA_CHECK(c10::cuda::SetDevice(d.index()));
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index()));
  }
  Stream getStream(Device d) const noexcept override {
    return getCurrentCUDAStream(d.index()).unwrap();
  }
  Stream getDefaultStream(Device d) const override {
    return getDefaultCUDAStream(d.index());
  }
  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }
  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    CUDAStream cs(s);
    auto old_stream = getCurrentCUDAStream(s.device().index());
    setCurrentCUDAStream(cs);
    return old_stream.unwrap();
  }
  DeviceIndex deviceCount() const noexcept override {
    return device_count();
  }

  // Event-related functions
  void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const {
    // Maps PyTorch's Event::Flag to CUDA flag
    auto cuda_flag = cudaEventDefault;
    switch (flag) {
      case EventFlag::PYTORCH_DEFAULT:
        cuda_flag = cudaEventDisableTiming;
        break;
      case EventFlag::BACKEND_DEFAULT:
        cuda_flag = cudaEventDefault;
        break;
      default:
        TORCH_CHECK(false, "CUDA event received unknown flag");
    }
    C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_creation(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
  }
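  // Destroys the event on the device it was created on. This function is
  // noexcept, so CUDA failures only warn instead of throwing.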
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    auto cuda_event = static_cast<cudaEvent_t>(event);
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
    C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device));
  }
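  // Records the event on the given stream. The cudaEvent_t is created lazily
  // on the first record, on the stream's device.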
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
    CUDAStream cuda_stream{stream};

    // Moves to stream's device to record
    const auto orig_device = getDevice();
    setDevice(stream.device());

    // Creates the event (lazily)
    if (!cuda_event)
      createEvent(&cuda_event, flag);
    C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
    // Makes the void* point to the (possibly just allocated) CUDA event
    *event = cuda_event;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }

    // Resets device
    setDevice(orig_device);
  }
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    CUDAStream cuda_stream{stream};
    const auto orig_device = getDevice();
    setDevice(stream.device());
    C10_CUDA_CHECK(cudaStreamWaitEvent(
        cuda_stream,
        cuda_event,
        /*flags (must be zero)=*/0));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kCUDA,
          reinterpret_cast<uintptr_t>(cuda_event),
          reinterpret_cast<uintptr_t>(cuda_stream.stream()));
    }
    setDevice(orig_device);
  }
  // May be called from any device
  bool queryEvent(void* event) const override {
    if (!event)
      return true;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    // Note: cudaEventQuery can be safely called from any device
    const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event));
    if (err != cudaErrorNotReady) {
      C10_CUDA_CHECK(err);
    } else {
      // ignore and clear the error if not ready
      (void)cudaGetLastError();
    }
    return (err == cudaSuccess);
  }
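
  // An illustrative event lifecycle (a sketch, not part of this header;
  // `impl`, `stream_a`, and `stream_b` are hypothetical):
  //
  //   void* ev = nullptr;
  //   // Lazily creates the event and records it on stream_a.
  //   impl.record(&ev, stream_a, stream_a.device_index(),
  //               c10::EventFlag::PYTORCH_DEFAULT);
  //   // Makes stream_b wait (on-device) for the work captured by the record.
  //   impl.block(ev, stream_b);
  //   // Polls for completion without blocking the host.
  //   bool done = impl.queryEvent(ev);
  //   impl.destroyEvent(ev, stream_a.device_index());
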
  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    return cuda_stream.query();
  }
  void synchronizeStream(const Stream& stream) const override {
    CUDAStream cuda_stream{stream};
    cuda_stream.synchronize();
  }
  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kCUDA, reinterpret_cast<uintptr_t>(cuda_event));
    }
    // Note: cudaEventSynchronize can be safely called from any device
    C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
  }
  // Note: synchronizeDevice can be safely called from any device
  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
  }
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    CUDAStream cuda_stream{stream};
    CUDACachingAllocator::recordStream(data_ptr, cuda_stream);
  }
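  // Returns the time between two recorded events, in milliseconds.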
  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    TORCH_CHECK(
        event1 && event2,
        "Both events must be recorded before calculating elapsed time.");
    // Even though cudaEventElapsedTime can be safely called from any device,
    // if the current device is not initialized, the call will create a new
    // CUDA context, which consumes a lot of memory.
    DeviceIndex orig_device{-1};
    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
    cudaEvent_t cuda_event1 = static_cast<cudaEvent_t>(event1);
    cudaEvent_t cuda_event2 = static_cast<cudaEvent_t>(event2);
    float time_ms = 0;
    // Raises cudaErrorNotReady if either event has been recorded but not yet
    // completed.
    C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, cuda_event1, cuda_event2));
    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
    return static_cast<double>(time_ms);
  }
};
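
// A minimal usage sketch (illustrative, assuming only public c10 APIs):
// callers normally reach this implementation indirectly through RAII guards
// such as c10::DeviceGuard or c10::cuda::CUDAGuard, which dispatch to the
// impl registered for the device type and restore the previous device on
// scope exit.
//
//   #include <c10/core/DeviceGuard.h>
//
//   void launch_on_device_one() {
//     c10::DeviceGuard guard(c10::Device(c10::DeviceType::CUDA, 1));
//     // ... enqueue work on device 1; CUDAGuardImpl::exchangeDevice ran
//     // under the hood, and ~DeviceGuard restores the original device ...
//   }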
} // namespace c10::cuda::impl