#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <caffe2/proto/caffe2.pb.h>
namespace caffe2 {
// Aliases that let caffe2 code keep spelling device types as caffe2::CPU,
// caffe2::CUDA, etc., while the enum itself is at::DeviceType from c10.
using DeviceType = at::DeviceType;
constexpr DeviceType CPU = DeviceType::CPU;
constexpr DeviceType CUDA = DeviceType::CUDA;
constexpr DeviceType OPENGL = DeviceType::OPENGL;
constexpr DeviceType OPENCL = DeviceType::OPENCL;
constexpr DeviceType MKLDNN = DeviceType::MKLDNN;
constexpr DeviceType IDEEP = DeviceType::IDEEP;
constexpr DeviceType HIP = DeviceType::HIP;
constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES =
    DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
constexpr DeviceType ONLY_FOR_TEST = DeviceType::ONLY_FOR_TEST;
// Converts a serialized caffe2::DeviceTypeProto value into the matching c10
// DeviceType. Throws (via AT_ERROR) for values that have no mapping.
inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) {
  switch (p) {
    case caffe2::PROTO_CPU:
      return DeviceType::CPU;
    case caffe2::PROTO_CUDA:
      return DeviceType::CUDA;
    case caffe2::PROTO_OPENGL:
      return DeviceType::OPENGL;
    case caffe2::PROTO_OPENCL:
      return DeviceType::OPENCL;
    case caffe2::PROTO_MKLDNN:
      return DeviceType::MKLDNN;
    case caffe2::PROTO_IDEEP:
      return DeviceType::IDEEP;
    case caffe2::PROTO_HIP:
      return DeviceType::HIP;
    case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES:
      return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
    case caffe2::PROTO_ONLY_FOR_TEST:
      return DeviceType::ONLY_FOR_TEST;
    default:
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(p),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}

// Overload for callers that hold the device type as a plain integer.
inline CAFFE2_API DeviceType ProtoToType(int p) {
  return ProtoToType(static_cast<caffe2::DeviceTypeProto>(p));
}
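
// Usage sketch (illustrative only; `raw` is a hypothetical variable name):
// the int overload lets callers pass an integer holding a DeviceTypeProto
// value without casting it back to the enum themselves.
//
//   int raw = static_cast<int>(caffe2::PROTO_HIP);
//   caffe2::DeviceType t = caffe2::ProtoToType(raw);  // t == caffe2::HIP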
// Converts a c10 DeviceType into the caffe2::DeviceTypeProto value used for
// serialization. This is the inverse of ProtoToType().
inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) {
  switch (t) {
    case DeviceType::CPU:
      return caffe2::PROTO_CPU;
    case DeviceType::CUDA:
      return caffe2::PROTO_CUDA;
    case DeviceType::OPENGL:
      return caffe2::PROTO_OPENGL;
    case DeviceType::OPENCL:
      return caffe2::PROTO_OPENCL;
    case DeviceType::MKLDNN:
      return caffe2::PROTO_MKLDNN;
    case DeviceType::IDEEP:
      return caffe2::PROTO_IDEEP;
    case DeviceType::HIP:
      return caffe2::PROTO_HIP;
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
      return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
    case DeviceType::ONLY_FOR_TEST:
      return caffe2::PROTO_ONLY_FOR_TEST;
    default:
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(t),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}
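
// Usage sketch (illustrative only): for every enumerator handled above,
// TypeToProto() and ProtoToType() are inverses, so a round trip returns the
// starting value.
//
//   caffe2::DeviceTypeProto p = caffe2::TypeToProto(caffe2::CUDA);
//   // p == caffe2::PROTO_CUDA
//   caffe2::DeviceType t = caffe2::ProtoToType(p);  // t == caffe2::CUDA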
// Serializes an at::Device into a caffe2::DeviceOption. The device index is
// stored as numa_node_id for CPU (only when an index is set) and as device_id
// for CUDA and HIP; the remaining device types carry no index.
inline CAFFE2_API caffe2::DeviceOption DeviceToOption(
    const at::Device& device) {
  caffe2::DeviceOption option;
  auto type = device.type();
  option.set_device_type(TypeToProto(type));
  switch (type) {
    case DeviceType::CPU:
      if (device.index() != -1) {
        option.set_numa_node_id(device.index());
      }
      break;
    case DeviceType::CUDA:
    case DeviceType::HIP:
      option.set_device_id(device.index());
      break;
    case DeviceType::OPENGL:
    case DeviceType::OPENCL:
    case DeviceType::MKLDNN:
    case DeviceType::IDEEP:
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
    case DeviceType::ONLY_FOR_TEST:
      break;
    default:
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(type),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType(), "
          "TypeToProto() and DeviceToOption() functions to reflect such "
          "recent changes?");
  }
  return option;
}
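
// Usage sketch (illustrative only): a CPU device without an index leaves
// numa_node_id unset, while an explicit index is stored as the NUMA node.
//
//   caffe2::DeviceOption a = caffe2::DeviceToOption(at::Device(caffe2::CPU));
//   // a.device_type() == caffe2::PROTO_CPU, a.has_numa_node_id() == false
//   caffe2::DeviceOption b =
//       caffe2::DeviceToOption(at::Device(caffe2::CPU, 0));
//   // b.has_numa_node_id() == true, b.numa_node_id() == 0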
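// Deserializes a caffe2::DeviceOption into an at::Device. The option is taken
// by const reference to avoid copying the protobuf message.
inline CAFFE2_API at::Device OptionToDevice(
    const caffe2::DeviceOption& option) {
  auto type = option.device_type();
  int32_t id = -1;
  switch (type) {
    case caffe2::PROTO_CPU:
      if (option.has_numa_node_id()) {
        id = option.numa_node_id();
      }
      break;
    case caffe2::PROTO_CUDA:
    case caffe2::PROTO_HIP:
      id = option.device_id();
      break;
    default:
      // Other device types carry no index; ProtoToType() below still rejects
      // values it does not recognize.
      break;
  }
  return at::Device(ProtoToType(type), id);
}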
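
// Usage sketch (illustrative only): for CPU, CUDA and HIP devices,
// OptionToDevice() undoes DeviceToOption(), so the index survives a round
// trip. For the other device types the index is not serialized.
//
//   at::Device dev(caffe2::CUDA, 1);
//   at::Device back = caffe2::OptionToDevice(caffe2::DeviceToOption(dev));
//   // back == dev (type CUDA, index 1)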
// Overwrites *device_option with the serialized form of device.
inline void ExtractDeviceOption(
    DeviceOption* device_option,
    const at::Device& device) {
  AT_ASSERT(device_option);
  device_option->CopyFrom(DeviceToOption(device));
}
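
// Usage sketch (illustrative only): ExtractDeviceOption() uses CopyFrom(), so
// any fields already set on the target DeviceOption are replaced, not merged.
//
//   caffe2::DeviceOption opt;
//   caffe2::ExtractDeviceOption(&opt, at::Device(caffe2::HIP, 0));
//   // opt.device_type() == caffe2::PROTO_HIP, opt.device_id() == 0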
} // namespace caffe2