File: torch/csrc/inductor/aoti_runtime/device_utils.h

Package: pytorch 2.6.0+dfsg-8 (Debian suites: forky, sid)

File content (67 lines, 2,320 bytes):
#pragma once

// WARNING: Be careful when adding new includes here. This header will be used
// in model.so, and should not refer to any aten/c10 headers except the stable
// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
// applies to other files under torch/csrc/inductor/aoti_runtime/.
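// For example, C++ standard-library headers and the stable shim header named
// above are fine here; <ATen/ATen.h> or any c10 header is not.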

#include <stdexcept>
#include <string>

#ifdef USE_CUDA

// FIXME: Currently, the CPU and CUDA backends are mutually exclusive.
// This is a temporary workaround; we need a better way to support
// multiple devices.

#include <cuda.h>
#include <cuda_runtime_api.h>

// Map a failing CUDA runtime call to a C++ exception carrying the
// runtime-provided error string.
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR)                    \
  do {                                                     \
    const cudaError_t code = EXPR;                         \
    if (code != cudaSuccess) {                             \
      const char* msg = cudaGetErrorString(code);          \
      throw std::runtime_error(                            \
          std::string("CUDA error: ") + std::string(msg)); \
    }                                                      \
  } while (0)
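// Illustrative use (buf, nbytes, and stream are assumed names, not part of
// this header):
//   AOTI_RUNTIME_DEVICE_CHECK(cudaMemsetAsync(buf, 0, nbytes, stream));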

namespace torch::aot_inductor {

// Stream handle type that generated model code passes to kernel launches.
using DeviceStreamType = cudaStream_t;

} // namespace torch::aot_inductor

#elif defined(USE_XPU)
#include <level_zero/ze_api.h>
#include <sycl/sycl.hpp>
#include <sstream>
// Map a failing Level Zero call to a C++ exception, reporting the
// ze_result_t error code in uppercase hex.
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR)                                   \
  do {                                                                    \
    const ze_result_t status = EXPR;                                      \
    if (status != ZE_RESULT_SUCCESS) {                                    \
      std::stringstream ss;                                               \
      ss << "L0 runtime error: " << std::hex << std::uppercase << status; \
      throw std::runtime_error(ss.str());                                 \
    }                                                                     \
  } while (0)
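// Illustrative use (cmd_list is an assumed ze_command_list_handle_t, not
// part of this header):
//   AOTI_RUNTIME_DEVICE_CHECK(zeCommandListClose(cmd_list));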

namespace torch::aot_inductor {

// On XPU, streams are SYCL queues, passed by pointer.
using DeviceStreamType = sycl::queue*;

} // namespace torch::aot_inductor

#else

// On CPU there is no device runtime to query, so EXPR is expected to
// evaluate to a bool.
#define AOTI_RUNTIME_DEVICE_CHECK(EXPR)              \
  do {                                               \
    if (!(EXPR)) {                                   \
      throw std::runtime_error("CPU runtime error"); \
    }                                                \
  } while (0)
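// Illustrative use (buf is an assumed name, not part of this header):
//   AOTI_RUNTIME_DEVICE_CHECK(buf != nullptr);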

namespace torch::aot_inductor {

// CPU builds have no device stream; void* keeps the generated code's
// signatures uniform across backends.
using DeviceStreamType = void*;

} // namespace torch::aot_inductor

#endif // USE_CUDA
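
Taken together, AOTI_RUNTIME_DEVICE_CHECK and DeviceStreamType are what let
generated AOTInductor code stay backend-agnostic. A minimal caller sketch,
assuming the CUDA build (USE_CUDA defined); the helper name fill_zero and its
arguments are hypothetical, not part of PyTorch:

#include <cstddef>

#include <torch/csrc/inductor/aoti_runtime/device_utils.h>

namespace torch::aot_inductor {

// Hypothetical helper: zero out a device buffer on the given stream,
// throwing if the CUDA runtime reports an error.
inline void fill_zero(void* buf, std::size_t nbytes, DeviceStreamType stream) {
  AOTI_RUNTIME_DEVICE_CHECK(cudaMemsetAsync(buf, 0, nbytes, stream));
}

} // namespace torch::aot_inductor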