1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
|
#pragma once
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
namespace c10 {
/// An index representing a specific device; e.g., the 1 in GPU 1.
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated with; try to use Device rather than
/// DeviceIndex directly.
/// It is a signed 8-bit integer; Device stores -1 in it as its default index.
using DeviceIndex = int8_t;
/// Represents a compute device on which a tensor is located. A device is
/// uniquely identified by a type, which specifies the type of machine it is
/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the
/// specific compute device when there is more than one of a certain type. The
/// device index is optional, and in its defaulted state represents (abstractly)
/// "the current device". Further, there are two constraints on the value of the
/// device index, if one is explicitly stored:
/// 1. An index of -1 represents the current device, a non-negative index
///    represents a specific, concrete device; no other negative value is
///    accepted,
/// 2. When the device type is CPU, the device index must be -1 or zero.
///
/// These constraints are checked with debug-only assertions when a `Device` is
/// constructed from a (type, index) pair; `set_index()` does not re-check them.
struct C10_API Device final {
  using Type = DeviceType;

  /// Constructs a new `Device` from a `DeviceType` and an optional device
  /// index. When `index` is omitted it defaults to -1.
  /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
      : type_(type), index_(index) {
    validate();
  }

  /// Constructs a `Device` from a string description, for convenience.
  /// The string supplied must follow the following schema:
  /// `(cpu|cuda)[:<device-index>]`
  /// where `cpu` or `cuda` specifies the device type, and
  /// `:<device-index>` optionally specifies a device index.
  /* implicit */ Device(const std::string& device_string);

  /// Returns true if the type and index of this `Device` matches that of
  /// `other`.
  bool operator==(const Device& other) const noexcept {
    return this->type_ == other.type_ && this->index_ == other.index_;
  }

  /// Returns true if the type or index of this `Device` differs from that of
  /// `other`.
  bool operator!=(const Device& other) const noexcept {
    return !(*this == other);
  }

  /// Sets the device index. The new value is stored as-is; the constraints
  /// asserted at construction time are not re-checked here.
  void set_index(DeviceIndex index) {
    index_ = index;
  }

  /// Returns the type of device this is.
  DeviceType type() const noexcept {
    return type_;
  }

  /// Returns the optional index (-1 when no index has been set).
  DeviceIndex index() const noexcept {
    return index_;
  }

  /// Returns true if the device has a non-default index, i.e. one other
  /// than -1.
  bool has_index() const noexcept {
    return index_ != -1;
  }

  /// Return true if the device is of CUDA type.
  bool is_cuda() const noexcept {
    return type_ == DeviceType::CUDA;
  }

  /// Return true if the device is of MPS type.
  bool is_mps() const noexcept {
    return type_ == DeviceType::MPS;
  }

  /// Return true if the device is of HIP type.
  bool is_hip() const noexcept {
    return type_ == DeviceType::HIP;
  }

  /// Return true if the device is of VE type.
  bool is_ve() const noexcept {
    return type_ == DeviceType::VE;
  }

  /// Return true if the device is of XPU type.
  bool is_xpu() const noexcept {
    return type_ == DeviceType::XPU;
  }

  /// Return true if the device is of IPU type.
  bool is_ipu() const noexcept {
    return type_ == DeviceType::IPU;
  }

  /// Return true if the device is of XLA type.
  bool is_xla() const noexcept {
    return type_ == DeviceType::XLA;
  }

  /// Return true if the device is of HPU type.
  bool is_hpu() const noexcept {
    return type_ == DeviceType::HPU;
  }

  /// Return true if the device is of Lazy type.
  bool is_lazy() const noexcept {
    return type_ == DeviceType::Lazy;
  }

  /// Return true if the device is of Vulkan type.
  bool is_vulkan() const noexcept {
    return type_ == DeviceType::Vulkan;
  }

  /// Return true if the device is of Metal type.
  bool is_metal() const noexcept {
    return type_ == DeviceType::Metal;
  }

  /// Return true if the device is of ORT type.
  bool is_ort() const noexcept {
    return type_ == DeviceType::ORT;
  }

  /// Return true if the device is of Meta type.
  bool is_meta() const noexcept {
    return type_ == DeviceType::Meta;
  }

  /// Return true if the device is of CPU type.
  bool is_cpu() const noexcept {
    return type_ == DeviceType::CPU;
  }

  /// Return true if the device supports arbitrary strides. This is reported
  /// as false for the IPU, XLA and Lazy device types and true for all others.
  bool supports_as_strided() const noexcept {
    return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
        type_ != DeviceType::Lazy;
  }

  /// Same string as returned from operator<<.
  std::string str() const;

 private:
  DeviceType type_;
  DeviceIndex index_ = -1;

  /// Asserts (in debug builds only) that the stored index satisfies the
  /// constraints documented on the class.
  void validate() const {
    // Removing these checks in release builds noticeably improves
    // performance in micro-benchmarks.
    // This is safe to do, because backends that use the DeviceIndex
    // have a later check when we actually try to switch to that device.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        index_ >= -1,
        "Device index must be -1 or non-negative, got ",
        static_cast<int>(index_));
    // Together with the check above, this limits a CPU device to -1 or 0.
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        !is_cpu() || index_ <= 0,
        "CPU device index must be -1 or zero, got ",
        static_cast<int>(index_));
  }
};
/// Stream-insertion operator for `Device`. `Device::str()` is documented as
/// returning the same string this operator writes. Only declared here; the
/// definition is provided out of line.
C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
} // namespace c10
namespace std {
template <>
struct hash<c10::Device> {
  /// Hashes a device by packing its type and index into one 32-bit word and
  /// delegating to `std::hash<uint32_t>`.
  size_t operator()(c10::Device d) const noexcept {
    // If either assertion fires, the packing below assumes one byte per field
    // and must be revisited together with the new sizes.
    static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
    static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");

    // Note [Hazard when concatenating signed integers]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // Each field is first reinterpreted as an unsigned value of its own
    // width and only then widened to the 32-bit result type. Widening a
    // signed -1 directly would sign-extend it and overwrite every bit in the
    // upper half of the combined word.
    const uint32_t type_bits =
        static_cast<uint32_t>(static_cast<uint8_t>(d.type()));
    const uint32_t index_bits =
        static_cast<uint32_t>(static_cast<uint8_t>(d.index()));

    // Type occupies bits 16..23, index occupies bits 0..7.
    const uint32_t packed = (type_bits << 16) | index_bits;
    return std::hash<uint32_t>{}(packed);
  }
};
} // namespace std
|