1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
AOTITorchError aoti_torch_create_cuda_guard(
int32_t device_index,
CUDAGuardHandle* ret_guard // returns new reference
) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::cuda::CUDAGuard* guard = new at::cuda::CUDAGuard(device_index);
*ret_guard = reinterpret_cast<CUDAGuardHandle>(guard);
});
}
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<at::cuda::CUDAGuard*>(guard); });
}
AOTITorchError aoti_torch_cuda_guard_set_index(
CUDAGuardHandle guard,
int32_t device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
reinterpret_cast<at::cuda::CUDAGuard*>(guard)->set_index(device_index);
});
}
AOTITorchError aoti_torch_create_cuda_stream_guard(
void* stream,
int32_t device_index,
CUDAStreamGuardHandle* ret_guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
at::cuda::CUDAStreamGuard* guard =
new at::cuda::CUDAStreamGuard(at::cuda::getStreamFromExternal(
static_cast<cudaStream_t>(stream), device_index));
*ret_guard = reinterpret_cast<CUDAStreamGuardHandle>(guard);
});
}
AOTITorchError aoti_torch_delete_cuda_stream_guard(
CUDAStreamGuardHandle guard) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ delete reinterpret_cast<at::cuda::CUDAStreamGuard*>(guard); });
}
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
});
}
|