#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <c10/xpu/XPUFunctions.h>
#include <c10/xpu/XPUStream.h>

#include <vector>

namespace c10::xpu::impl {

struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr DeviceType static_type = kXPU;

  XPUGuardImpl() = default;

  explicit XPUGuardImpl(DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == kXPU);
  }

  DeviceType type() const override {
    return kXPU;
  }

  Device exchangeDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_xpu());
    const auto old_device_index = c10::xpu::exchange_device(d.index());
    return Device(kXPU, old_device_index);
  }

  Device getDevice() const override {
    const auto device = c10::xpu::current_device();
    return Device(kXPU, device);
  }

  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_xpu());
    c10::xpu::set_device(d.index());
  }

  void uncheckedSetDevice(Device d) const noexcept override {
    c10::xpu::set_device(d.index());
  }
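
  // Usage note (a minimal sketch; `guard` below is a hypothetical name): this
  // impl is normally reached through c10::DeviceGuard rather than called
  // directly. The guard saves the current device via exchangeDevice() on
  // construction and restores it via uncheckedSetDevice() on destruction:
  //
  //   {
  //     c10::DeviceGuard guard(c10::Device(c10::kXPU, 1));
  //     // ... run work on XPU device 1 ...
  //   } // the previous device is restored here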
  Stream getStream(Device d) const noexcept override {
    return getCurrentXPUStream(d.index()).unwrap();
  }

  Stream getNewStream(Device d, int priority = 0) const override {
    return getStreamFromPool(priority, d.index());
  }

  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return getStreamFromPool(isHighPriority, d.index());
  }

  // NB: These do NOT set the current device
  Stream exchangeStream(Stream s) const noexcept override {
    const XPUStream stream(s);
    const auto old_stream = getCurrentXPUStream(s.device().index());
    setCurrentXPUStream(stream);
    return old_stream.unwrap();
  }

  DeviceIndex deviceCount() const noexcept override {
    return c10::xpu::device_count();
  }
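
  // Usage note (a minimal sketch; `impl`, `s`, and `prev` are hypothetical
  // names): exchangeStream() swaps the current stream and returns the
  // previous one so a caller can restore it afterwards:
  //
  //   XPUGuardImpl impl;
  //   c10::Stream s = impl.getNewStream(c10::Device(c10::kXPU, 0));
  //   c10::Stream prev = impl.exchangeStream(s);
  //   // ... work launched via getCurrentXPUStream() now targets s ...
  //   impl.exchangeStream(prev); // restore the previous current stream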
  // Event-related functions
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    if (!event)
      return;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_deletion(
          c10::kXPU, reinterpret_cast<uintptr_t>(event));
    }

    delete reinterpret_cast<sycl::event*>(event);
  }
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    TORCH_CHECK(
        device_index == -1 || device_index == stream.device_index(),
        "Event device index ",
        device_index,
        " does not match recording stream's device index ",
        stream.device_index(),
        ".");

    auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
    const XPUStream xpu_stream{stream};

    // Delete the event previously recorded.
    if (xpu_event)
      delete xpu_event;
#if SYCL_COMPILER_VERSION >= 20250000
    if (flag == EventFlag::BACKEND_DEFAULT) {
      // Use a profiling tag to record the event, which enables the timing
      // feature required by elapsedTime().
      xpu_event = new sycl::event(
          sycl::ext::oneapi::experimental::submit_profiling_tag(
              xpu_stream.queue()));
    } else {
      xpu_event =
          new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
    }
#else
    xpu_event = new sycl::event(xpu_stream.queue().ext_oneapi_submit_barrier());
#endif
    *event = reinterpret_cast<void*>(xpu_event);

    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_record(
          c10::kXPU,
          reinterpret_cast<uintptr_t>(xpu_event),
          reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
    }
  }
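
  // Usage note (a minimal sketch; `ev` is a hypothetical name): record()
  // allocates and owns the sycl::event behind the opaque handle, so callers
  // pass the same handle back to destroyEvent() when they are done with it:
  //
  //   void* ev = nullptr;
  //   impl.record(&ev, stream, /*device_index=*/-1, EventFlag::BACKEND_DEFAULT);
  //   // ... ev now marks the point `stream` has reached ...
  //   impl.destroyEvent(ev, stream.device_index());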
  void block(void* event, const Stream& stream) const override {
    if (!event)
      return;
    auto* xpu_event = reinterpret_cast<sycl::event*>(event);
    std::vector<sycl::event> event_list{*xpu_event};
    const XPUStream xpu_stream(stream);
    xpu_stream.queue().ext_oneapi_submit_barrier(event_list);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_wait(
          c10::kXPU,
          reinterpret_cast<uintptr_t>(xpu_event),
          reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
    }
  }

  bool queryEvent(void* event) const override {
    using namespace sycl::info;
    if (!event)
      return true;
    auto* xpu_event = reinterpret_cast<sycl::event*>(event);
    return xpu_event->get_info<event::command_execution_status>() ==
        event_command_status::complete;
  }
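
  // Usage note (a minimal sketch; the stream names are hypothetical): block()
  // inserts a device-side wait, the usual cross-stream ordering pattern, so
  // the host thread is never stalled:
  //
  //   impl.record(&ev, producer_stream, -1, EventFlag::BACKEND_DEFAULT);
  //   impl.block(ev, consumer_stream); // consumer waits for producer's work
  //   // the host continues immediately; queryEvent(ev) polls for completion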
  double elapsedTime(
      void* start_event,
      void* end_event,
      const DeviceIndex device_index) const override {
#if SYCL_COMPILER_VERSION < 20250000
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "elapsedTime requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
    TORCH_CHECK(
        start_event && end_event,
        "Both events must be recorded before calculating elapsed time.");
    auto* xpu_start_event = reinterpret_cast<sycl::event*>(start_event);
    auto* xpu_end_event = reinterpret_cast<sycl::event*>(end_event);

    using namespace sycl::info::event_profiling;
    // get_profiling_info blocks until each recorded event has completed.
    uint64_t end_time_ns = xpu_end_event->get_profiling_info<command_end>();
    uint64_t start_time_ns = xpu_start_event->get_profiling_info<command_end>();
    // Return the elapsed time in milliseconds.
    return 1e-6 *
        (static_cast<double>(end_time_ns) - static_cast<double>(start_time_ns));
  }
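
  // Usage note (a minimal sketch; the event names are hypothetical): bracket
  // the work with two recorded events, then the nanosecond difference is
  // scaled to milliseconds by the 1e-6 factor above:
  //
  //   impl.record(&start_ev, stream, -1, EventFlag::BACKEND_DEFAULT);
  //   // ... enqueue the work being timed ...
  //   impl.record(&end_ev, stream, -1, EventFlag::BACKEND_DEFAULT);
  //   double ms = impl.elapsedTime(start_ev, end_ev, stream.device_index());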
  // Stream-related functions
  bool queryStream(const Stream& stream) const override {
    const XPUStream xpu_stream{stream};
    return xpu_stream.query();
  }

  void synchronizeStream(const Stream& stream) const override {
    const XPUStream xpu_stream{stream};
    xpu_stream.synchronize();
  }

  void synchronizeEvent(void* event) const override {
    if (!event)
      return;
    auto* xpu_event = reinterpret_cast<sycl::event*>(event);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_event_synchronization(
          c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event));
    }
    xpu_event->wait_and_throw();
  }

  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_device_synchronization(c10::kXPU);
    }
    c10::xpu::syncStreamsOnDevice(device_index);
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    const XPUStream xpu_stream{stream};
    XPUCachingAllocator::recordStream(data_ptr, xpu_stream);
  }
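
  // Usage note (a minimal sketch; `ptr` and `side_stream` are hypothetical
  // names): recordDataPtrOnStream() tells the caching allocator that an
  // allocation is used on a stream other than the one it was allocated on,
  // so its block is not handed out again until that stream's work finishes:
  //
  //   c10::DataPtr ptr = c10::GetAllocator(c10::kXPU)->allocate(nbytes);
  //   impl.recordDataPtrOnStream(ptr, side_stream); // defer reuse of ptr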
};
} // namespace c10::xpu::impl
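
// Registration note (a sketch): a guard impl like this is typically
// registered once in a translation unit so that generic code such as
// c10::DeviceGuard can dispatch to it for kXPU devices, e.g.:
//
//   C10_REGISTER_GUARD_IMPL(XPU, XPUGuardImpl);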