1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
|
#include <torch/csrc/inductor/aoti_torch/c/shim_xpu.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/DeviceType.h>
#include <c10/core/StreamGuard.h>
#include <c10/xpu/XPUStream.h>

#include <cstdint>
#include <stdexcept>
/// Allocates an at::DeviceGuard targeting the XPU device `device_index` and
/// returns it to the caller as an opaque XPUGuardHandle via `ret_guard`.
/// The caller owns the handle and must release it with
/// aoti_torch_delete_xpu_guard.
AOTITorchError aoti_torch_create_xpu_guard(
    int32_t device_index,
    XPUGuardHandle* ret_guard // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    const at::Device xpu_device(at::DeviceType::XPU, device_index);
    auto* device_guard = new at::DeviceGuard(xpu_device);
    *ret_guard = reinterpret_cast<XPUGuardHandle>(device_guard);
  });
}
/// Destroys a guard handle produced by aoti_torch_create_xpu_guard.
AOTITorchError aoti_torch_delete_xpu_guard(XPUGuardHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* device_guard = reinterpret_cast<at::DeviceGuard*>(guard);
    delete device_guard;
  });
}
/// Re-targets an existing XPU device guard at `device_index` by calling
/// at::DeviceGuard::set_index on the underlying object.
AOTITorchError aoti_torch_xpu_guard_set_index(
    XPUGuardHandle guard,
    int32_t device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* device_guard = reinterpret_cast<at::DeviceGuard*>(guard);
    device_guard->set_index(device_index);
  });
}
/// Wraps an externally owned sycl::queue (passed as `stream`) in an XPU
/// stream for `device_index`, allocates an at::StreamGuard for it, and returns
/// the guard as an opaque handle via `ret_guard`. The caller owns the handle
/// and must release it with aoti_torch_delete_xpu_stream_guard.
///
/// @param stream       Non-null pointer to a sycl::queue.
/// @param device_index XPU device the queue belongs to.
/// @param ret_guard    Receives the newly allocated guard handle.
AOTITorchError aoti_torch_create_xpu_stream_guard(
    void* stream,
    int32_t device_index,
    XPUStreamGuardHandle* ret_guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // `assert` is compiled out under NDEBUG, which previously let a null
    // pointer reach getStreamFromExternal in release builds. Throw instead so
    // the exception-to-error-code macro reports the failure to the caller.
    if (stream == nullptr) {
      throw std::invalid_argument(
          "aoti_torch_create_xpu_stream_guard: stream must not be null");
    }
    at::StreamGuard* guard =
        new at::StreamGuard(at::xpu::getStreamFromExternal(
                                static_cast<sycl::queue*>(stream), device_index)
                                .unwrap());
    *ret_guard = reinterpret_cast<XPUStreamGuardHandle>(guard);
  });
}
/// Destroys a stream-guard handle produced by
/// aoti_torch_create_xpu_stream_guard.
AOTITorchError aoti_torch_delete_xpu_stream_guard(XPUStreamGuardHandle guard) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto* stream_guard = reinterpret_cast<at::StreamGuard*>(guard);
    delete stream_guard;
  });
}
/// Stores in `*ret_stream` a type-erased pointer to the sycl::queue behind
/// the current XPU stream of device `device_index`.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_stream(int32_t device_index, void** ret_stream) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // NOTE(review): the returned pointer outlives the temporary stream
    // object; this assumes queue() refers to storage that is not owned by
    // that temporary. Confirm against c10/xpu/XPUStream.h.
    auto& current_queue = at::xpu::getCurrentXPUStream(device_index).queue();
    *ret_stream = &current_queue;
  });
}
/// Writes the index returned by c10::xpu::current_device() into
/// `*device_index`, widened to int32_t.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_xpu_device(int32_t* device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    const auto active_device = c10::xpu::current_device();
    *device_index = static_cast<int32_t>(active_device);
  });
}
/// Makes XPU device `device_index` the current device via
/// c10::xpu::set_device.
///
/// The index is narrowed to int8_t before the call. Values that do not fit
/// are rejected with an exception rather than silently wrapped (e.g. 256 would
/// otherwise become 0 and select the wrong device). The enclosing macro then
/// reports the exception to the caller.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_set_current_xpu_device(const int32_t& device_index) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    const auto narrowed = static_cast<int8_t>(device_index);
    // Round-trip check: only indices representable as int8_t survive intact.
    if (static_cast<int32_t>(narrowed) != device_index) {
      throw std::out_of_range(
          "aoti_torch_set_current_xpu_device: device_index out of range");
    }
    c10::xpu::set_device(narrowed);
  });
}
/// Stores in `*ret` a type-erased pointer to the sycl::queue behind the
/// current XPU stream of whichever XPU device is current. This mirrors
/// aoti_torch_get_current_xpu_stream with the index taken from
/// c10::xpu::current_device().
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    const auto active_index =
        static_cast<int32_t>(c10::xpu::current_device());
    auto& active_queue = at::xpu::getCurrentXPUStream(active_index).queue();
    *ret = &active_queue;
  });
}
|