1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
#pragma once
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <cuda_runtime_api.h>
namespace c10 {
namespace cuda {
namespace impl {
struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = DeviceType::CUDA;
CUDAGuardImpl() {}
explicit CUDAGuardImpl(DeviceType t) {
TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA);
}
DeviceType type() const override {
return DeviceType::CUDA;
}
Device exchangeDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA);
Device old_device = getDevice();
if (old_device.index() != d.index()) {
C10_CUDA_CHECK(cudaSetDevice(d.index()));
}
return old_device;
}
Device getDevice() const override {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
return Device(DeviceType::CUDA, device);
}
c10::optional<Device> uncheckedGetDevice() const noexcept {
int device;
auto err = cudaGetDevice(&device);
C10_CUDA_CHECK_WARN(err);
if (err != cudaSuccess) {
return c10::nullopt;
}
return Device(DeviceType::CUDA, device);
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA);
Device current_device = getDevice();
if (current_device != d) {
C10_CUDA_CHECK(cudaSetDevice(d.index()));
}
}
void uncheckedSetDevice(Device d) const noexcept override {
auto current_device = uncheckedGetDevice();
if (!current_device.has_value() || current_device.value() != d) {
C10_CUDA_CHECK_WARN(cudaSetDevice(d.index()));
}
}
Stream getStream(Device d) const noexcept override {
return getCurrentCUDAStream(d.index()).unwrap();
}
Stream getDefaultStream(Device d) const override {
return getDefaultCUDAStream(d.index());
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream s) const noexcept override {
CUDAStream cs(s);
auto old_stream = getCurrentCUDAStream(s.device().index());
setCurrentCUDAStream(cs);
return old_stream.unwrap();
}
DeviceIndex deviceCount() const noexcept override {
return device_count();
}
// Event-related functions
void createEvent(
cudaEvent_t* cuda_event,
const EventFlag flag) const {
// Maps PyTorch's Event::Flag to CUDA flag
auto cuda_flag = cudaEventDefault;
switch (flag) {
case EventFlag::PYTORCH_DEFAULT:
case EventFlag::CUDA_EVENT_DISABLE_TIMING:
cuda_flag = cudaEventDisableTiming;
break;
case EventFlag::BACKEND_DEFAULT:
case EventFlag::CUDA_EVENT_DEFAULT:
cuda_flag = cudaEventDefault;
break;
default:
TORCH_CHECK(false, "CUDA event received unknown flag");
}
C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag));
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
if (!event) return;
auto cuda_event = static_cast<cudaEvent_t>(event);
int orig_device;
C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device));
C10_CUDA_CHECK_WARN(cudaSetDevice(device_index));
C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event));
C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device));
}
void record(
void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event);
CUDAStream cuda_stream{stream};
// Moves to stream's device to record
const auto orig_device = getDevice();
setDevice(stream.device());
// Creates the event (lazily)
if (!cuda_event) createEvent(&cuda_event, flag);
C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream));
// Makes the void* point to the (possibly just allocated) CUDA event
*event = cuda_event;
// Resets device
setDevice(orig_device);
}
void block(
void* event,
const Stream& stream) const override {
if (!event) return;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
CUDAStream cuda_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_CUDA_CHECK(cudaStreamWaitEvent(
cuda_stream,
cuda_event,
/*flags (must be zero)=*/ 0));
setDevice(orig_device);
}
// May be called from any device
bool queryEvent(void* event) const override {
if (!event) return true;
cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event);
const cudaError_t err = cudaEventQuery(cuda_event);
if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
}
return (err == cudaSuccess);
}
};
}}} // namespace c10::cuda::impl
|