1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <algorithm>
#include <array>
#include <cctype>
#include <exception>
#include <limits>
#include <ostream>
#include <string>
#include <tuple>
#include <vector>
namespace c10 {
namespace {
// Looks up the DeviceType whose lower-case name equals `device_string`.
//
// The lookup table is sized by COMPILE_TIME_MAX_DEVICE_TYPES, so any slots
// not listed below are value-initialized and carry a null name; such slots
// are skipped both during the search and when building the error message.
//
// Raises via TORCH_CHECK, listing every known name, when nothing matches.
DeviceType parse_type(const std::string& device_string) {
  using NamedType = std::pair<const char*, DeviceType>;
  static const std::array<
      NamedType,
      static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
      known_types = {{
          {"cpu", DeviceType::CPU},
          {"cuda", DeviceType::CUDA},
          {"ipu", DeviceType::IPU},
          {"xpu", DeviceType::XPU},
          {"mkldnn", DeviceType::MKLDNN},
          {"opengl", DeviceType::OPENGL},
          {"opencl", DeviceType::OPENCL},
          {"ideep", DeviceType::IDEEP},
          {"hip", DeviceType::HIP},
          {"ve", DeviceType::VE},
          {"fpga", DeviceType::FPGA},
          {"ort", DeviceType::ORT},
          {"xla", DeviceType::XLA},
          {"lazy", DeviceType::Lazy},
          {"vulkan", DeviceType::Vulkan},
          {"mps", DeviceType::MPS},
          {"meta", DeviceType::Meta},
          {"hpu", DeviceType::HPU},
          {"privateuseone", DeviceType::PrivateUse1},
      }};

  // First pass: return as soon as a populated slot matches the request.
  for (const NamedType& entry : known_types) {
    if (entry.first != nullptr && device_string == entry.first) {
      return entry.second;
    }
  }

  // No match: gather the populated names so the error can enumerate them.
  std::vector<const char*> device_names;
  for (const NamedType& entry : known_types) {
    if (entry.first != nullptr) {
      device_names.push_back(entry.first);
    }
  }
  TORCH_CHECK(
      false,
      "Expected one of ",
      c10::Join(", ", device_names),
      " device type at start of device string: ",
      device_string);
}
// States of the hand-written parser used by Device::Device(const std::string&).
// Declared as a scoped enum so that the generic enumerator names (START,
// ERROR, ...) do not leak into the enclosing namespace; all uses are already
// written in the qualified DeviceStringParsingState::X form.
enum class DeviceStringParsingState {
  START, // reading the device type name (before any ':')
  INDEX_START, // a ':' was just consumed; the first index digit is expected
  INDEX_REST, // at least one index digit was consumed; more digits may follow
  ERROR // the input did not match the expected format
};
} // namespace
// Constructs a Device from a string of the form "<type>" or "<type>:<index>".
//
// The string is matched against the regular expression:
//   ([a-zA-Z_]+)(?::([1-9]\\d*|0))?
// i.e. a non-empty name made of letters/underscores, optionally followed by
// ':' and a decimal index without leading zeros.
//
// Raises (via TORCH_CHECK) when the string is empty, does not match the
// grammar, carries an index that cannot be parsed or does not fit in the
// type of `index_`, or names an unknown device type. The object starts out as
// a CPU device and is only updated once parsing has succeeded.
Device::Device(const std::string& device_string) : Device(Type::CPU) {
  TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
  std::string device_name, device_index_str;
  DeviceStringParsingState pstate = DeviceStringParsingState::START;
  // Single pass over the input; stops early as soon as the state machine
  // enters ERROR.
  for (size_t i = 0;
       pstate != DeviceStringParsingState::ERROR && i < device_string.size();
       ++i) {
    const char ch = device_string.at(i);
    // The <cctype> classifiers have undefined behavior when given a negative
    // value other than EOF, which a plain `char` holding a non-ASCII byte can
    // be. Classify through unsigned char instead.
    const auto uch = static_cast<unsigned char>(ch);
    switch (pstate) {
      case DeviceStringParsingState::START:
        if (ch != ':') {
          if (std::isalpha(uch) || ch == '_') {
            device_name.push_back(ch);
          } else {
            pstate = DeviceStringParsingState::ERROR;
          }
        } else {
          pstate = DeviceStringParsingState::INDEX_START;
        }
        break;
      case DeviceStringParsingState::INDEX_START:
        if (std::isdigit(uch)) {
          device_index_str.push_back(ch);
          pstate = DeviceStringParsingState::INDEX_REST;
        } else {
          pstate = DeviceStringParsingState::ERROR;
        }
        break;
      case DeviceStringParsingState::INDEX_REST:
        // A leading '0' may only stand alone ("0"); anything after it is
        // rejected so that indices such as "01" are not accepted.
        if (device_index_str.at(0) == '0') {
          pstate = DeviceStringParsingState::ERROR;
          break;
        }
        if (std::isdigit(uch)) {
          device_index_str.push_back(ch);
        } else {
          pstate = DeviceStringParsingState::ERROR;
        }
        break;
      case DeviceStringParsingState::ERROR:
        // Execution won't reach here.
        break;
    }
  }
  // Besides an explicit ERROR state, reject a missing type name (e.g. ":0")
  // and a trailing ':' with no digits after it (e.g. "cuda:").
  const bool has_error = device_name.empty() ||
      pstate == DeviceStringParsingState::ERROR ||
      (pstate == DeviceStringParsingState::INDEX_START &&
       device_index_str.empty());
  TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
  if (!device_index_str.empty()) {
    int parsed_index = 0;
    try {
      parsed_index = c10::stoi(device_index_str);
    } catch (const std::exception&) {
      TORCH_CHECK(
          false,
          "Could not parse device index '",
          device_index_str,
          "' in device string '",
          device_string,
          "'");
    }
    // Assigning the int straight to index_ would silently wrap if the value
    // exceeds what the member's type can hold, so check the bound explicitly.
    // The grammar only admits digits, so the value is never negative here.
    using IndexType = decltype(index_);
    TORCH_CHECK(
        parsed_index <= std::numeric_limits<IndexType>::max(),
        "Device index '",
        device_index_str,
        "' in device string '",
        device_string,
        "' is out of range (maximum is ",
        static_cast<long long>(std::numeric_limits<IndexType>::max()),
        ")");
    index_ = static_cast<IndexType>(parsed_index);
  }
  type_ = parse_type(device_name);
  validate();
}
// Returns the textual form of this device: the device type name (requested
// in lower case from DeviceTypeName), followed by ":<index>" only when the
// device has an index.
std::string Device::str() const {
  std::string text = DeviceTypeName(type(), /* lower case */ true);
  if (!has_index()) {
    return text;
  }
  text += ':';
  text += to_string(index());
  return text;
}
// Streams the device using the same text produced by Device::str() and
// returns the stream so that insertions can be chained.
std::ostream& operator<<(std::ostream& stream, const Device& device) {
  return stream << device.str();
}
} // namespace c10
|