1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <caffe2/proto/caffe2.pb.h>
namespace caffe2 {
// The c10 device-type enum, re-exported into the caffe2 namespace, plus a
// named constant for each enumerator that this header converts to and from
// the protobuf DeviceTypeProto enum.
using DeviceType = at::DeviceType;

constexpr DeviceType CPU{DeviceType::CPU};
constexpr DeviceType CUDA{DeviceType::CUDA};
constexpr DeviceType OPENGL{DeviceType::OPENGL};
constexpr DeviceType OPENCL{DeviceType::OPENCL};
constexpr DeviceType MKLDNN{DeviceType::MKLDNN};
constexpr DeviceType IDEEP{DeviceType::IDEEP};
constexpr DeviceType HIP{DeviceType::HIP};
// Sentinel enumerator; it is mapped like the others by the conversion
// functions below.
constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES{
    DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES};
// Converts a protobuf DeviceTypeProto enumerator to the matching
// at::DeviceType. This is the inverse of TypeToProto().
//
// @param p  protobuf device type to convert.
// @return   the corresponding at::DeviceType.
// Raises an error through AT_ERROR when `p` is not one of the enumerators
// handled below (e.g. a value newly added to caffe2.proto but not mapped here).
inline TORCH_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) {
  switch (p) {
    case caffe2::PROTO_CPU:
      return DeviceType::CPU;
    case caffe2::PROTO_CUDA:
      return DeviceType::CUDA;
    case caffe2::PROTO_OPENGL:
      return DeviceType::OPENGL;
    case caffe2::PROTO_OPENCL:
      return DeviceType::OPENCL;
    case caffe2::PROTO_MKLDNN:
      return DeviceType::MKLDNN;
    case caffe2::PROTO_IDEEP:
      return DeviceType::IDEEP;
    case caffe2::PROTO_HIP:
      return DeviceType::HIP;
    case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES:
      return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
    default:
      // Adjacent string literals are joined with no separator, so every
      // fragment must carry its own trailing space (previously this produced
      // "TypeToProtofunction" and "Unknown device:<n>").
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(p),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}
// Integer overload of ProtoToType(): reinterprets a raw integer (for example
// a device-type value read straight out of a proto field) as a
// DeviceTypeProto and delegates to the enum overload, which raises on
// unmapped values.
inline TORCH_API DeviceType ProtoToType(int p) {
  const auto as_proto_enum = static_cast<caffe2::DeviceTypeProto>(p);
  return ProtoToType(as_proto_enum);
}
// Converts an at::DeviceType to the matching protobuf DeviceTypeProto
// enumerator. This is the inverse of ProtoToType().
//
// @param t  device type to convert.
// @return   the corresponding DeviceTypeProto.
// Raises an error through AT_ERROR for any DeviceType that has no proto
// counterpart listed below.
inline TORCH_API DeviceTypeProto TypeToProto(const DeviceType& t) {
  switch (t) {
    case DeviceType::CPU:
      return caffe2::PROTO_CPU;
    case DeviceType::CUDA:
      return caffe2::PROTO_CUDA;
    case DeviceType::OPENGL:
      return caffe2::PROTO_OPENGL;
    case DeviceType::OPENCL:
      return caffe2::PROTO_OPENCL;
    case DeviceType::MKLDNN:
      return caffe2::PROTO_MKLDNN;
    case DeviceType::IDEEP:
      return caffe2::PROTO_IDEEP;
    case DeviceType::HIP:
      return caffe2::PROTO_HIP;
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
      return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
    default:
      // Adjacent string literals are joined with no separator, so every
      // fragment must carry its own trailing space (previously this produced
      // "TypeToProtofunction" and "Unknown device:<n>").
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(t),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}
// Builds a caffe2::DeviceOption describing `device`.
//
// The proto device type is obtained from TypeToProto(), which raises for
// device types it does not map. The device index is stored per type:
//   - CPU:      written to numa_node_id, but only when device.index() != -1;
//   - CUDA/HIP: written to device_id exactly as returned by device.index()
//               (so it may be -1 when the device carries no index);
//   - others:   no index is stored.
inline TORCH_API caffe2::DeviceOption DeviceToOption(const at::Device& device) {
  caffe2::DeviceOption option;
  const auto type = device.type();
  option.set_device_type(TypeToProto(type));
  switch (type) {
    case DeviceType::CPU:
      if (device.index() != -1) {
        option.set_numa_node_id(device.index());
      }
      break;
    case DeviceType::CUDA:
    case DeviceType::HIP:
      option.set_device_id(device.index());
      break;
    case DeviceType::OPENGL:
    case DeviceType::OPENCL:
    case DeviceType::MKLDNN:
    case DeviceType::IDEEP:
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
      break;
    default:
      // Defensive only: TypeToProto() above already raised for every type
      // not listed in this switch. Adjacent string literals are joined with
      // no separator, so each fragment carries its own trailing space.
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(type),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
  return option;
}
// Reconstructs an at::Device from a caffe2::DeviceOption; the counterpart of
// DeviceToOption().
//
// The index defaults to -1 and is only populated for:
//   - CPU:      from numa_node_id, when that field is present;
//   - CUDA/HIP: from device_id.
// The device type itself is converted with ProtoToType(), which raises for
// unmapped proto values.
inline TORCH_API at::Device OptionToDevice(const caffe2::DeviceOption& option) {
  auto proto_type = option.device_type();
  c10::DeviceIndex index = -1;

  if (proto_type == caffe2::PROTO_CPU) {
    if (option.has_numa_node_id()) {
      index = static_cast<c10::DeviceIndex>(option.numa_node_id());
    }
  } else if (
      proto_type == caffe2::PROTO_CUDA || proto_type == caffe2::PROTO_HIP) {
    index = static_cast<c10::DeviceIndex>(option.device_id());
  }

  return at::Device(ProtoToType(proto_type), index);
}
// Replaces the contents of *device_option with the DeviceOption built from
// `device` by DeviceToOption() (protobuf CopyFrom semantics).
//
// @param device_option  destination message; must not be null.
// @param device         device to describe.
inline void ExtractDeviceOption(
    DeviceOption* device_option,
    const at::Device& device) {
  // A null destination is a caller bug, i.e. an internal invariant violation.
  // TORCH_INTERNAL_ASSERT is the non-deprecated spelling of AT_ASSERT.
  TORCH_INTERNAL_ASSERT(
      device_option != nullptr,
      "ExtractDeviceOption: device_option must not be null");
  device_option->CopyFrom(DeviceToOption(device));
}
} // namespace caffe2
|