1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431
|
#pragma once
// This file provides implementations of InlineDeviceGuard and
// InlineOptionalDeviceGuard.
#include <c10/core/Device.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>
namespace c10 {
namespace impl {
/**
* A DeviceGuard is an RAII class that sets a device to some value
* on construction, and resets the device to its original value on
* destruction.
*
* InlineDeviceGuard is a helper class for implementing DeviceGuards.
* It is templated over a DeviceGuardImpl (anything that implements
* DeviceGuardImplInterface). There are two primary ways to instantiate
* InlineDeviceGuard:
*
* - With a concrete implementation of DeviceGuardImpl, e.g., CUDAGuardImpl.
* This is the best way to use InlineDeviceGuard, as all calls are
* devirtualized, giving you code as efficient as straight line
* calls to cudaGetDevice/cudaSetDevice.
*
* - With VirtualGuardImpl, which does a virtual dispatch to a DeviceGuardImpl
* retrieved from a DeviceType registry. We have explicitly instantiated
* InlineDeviceGuard this way as c10::DeviceGuard.
*
* If you are in a hurry, you can use InlineDeviceGuard directly:
*
* using CUDAGuard = impl::InlineDeviceGuard<CUDAGuardImpl>;
*
* However, you can provide a better user experience if you explicitly write a
* wrapper class that itself contains the template instantiation:
*
* class CUDAGuard {
* public:
* // ... the API ...
* private:
* impl::InlineDeviceGuard<CUDAGuardImpl> guard_;
* }
*
* The wrapper class provides a good place to write documentation, and helps
* avoid weird template instantiation errors when a user incorrectly uses the
* class.
*
* If you need to test this class, consider instantiating it with FakeGuardImpl.
*/
template <typename T>
class InlineDeviceGuard {
 public:
  // Note [Omitted default constructor from RAII]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // A default constructor could, in principle, read the current device and
  // promise to restore it on exit.  In most places where one would write
  // that, OptionalDeviceGuard is what was really meant (no restore is
  // needed if the device is never actually set).  The constructor is
  // deleted so that callers have to decide which behavior they want.
  explicit InlineDeviceGuard() = delete;

  /// Make `device` the current device.  If the device index is -1 the
  /// current device is left alone and the device reported by the impl is
  /// recorded as both the original and the current device.
  explicit InlineDeviceGuard(Device device)
      : impl_(device.type()),
        original_device_(
            device.index() != -1 ? impl_.exchangeDevice(device)
                                 : impl_.getDevice()),
        current_device_(device.index() != -1 ? device : original_device_) {}

  /// Make the device with index `device_index` current.  Only available
  /// when T is not VirtualGuardImpl; the device type comes from
  /// U::static_type.
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(DeviceIndex device_index)
      : InlineDeviceGuard(Device(U::static_type, device_index)) {}

  /// VirtualGuardImpl-only constructor taking an explicit
  /// DeviceGuardImplInterface pointer.  A null `impl` falls back to
  /// getDeviceGuardImpl(device.type()).
  template <
      typename U = T,
      typename = typename std::enable_if<
          std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineDeviceGuard(
      Device device,
      const DeviceGuardImplInterface* impl)
      : impl_(VirtualGuardImpl(
            impl != nullptr ? impl : getDeviceGuardImpl(device.type()))),
        original_device_(
            device.index() != -1 ? impl_.exchangeDevice(device)
                                 : impl_.getDevice()),
        current_device_(device.index() != -1 ? device : original_device_) {}

  /// Copying a guard is not allowed.
  InlineDeviceGuard(const InlineDeviceGuard<T>&) = delete;
  InlineDeviceGuard<T>& operator=(const InlineDeviceGuard<T>&) = delete;

  /// Moving is not allowed either: DeviceGuard has no uninitialized state,
  /// which moves on types with nontrivial destructors require.
  InlineDeviceGuard(InlineDeviceGuard<T>&& other) = delete;
  InlineDeviceGuard& operator=(InlineDeviceGuard<T>&& other) = delete;

  /// Puts the original device back using the impl's unchecked setter.
  ~InlineDeviceGuard() {
    impl_.uncheckedSetDevice(original_device_);
  }

  /// Switch the current device to `device` (non-virtual impls only).
  /// The device type must equal U::static_type, except that a CUDA device
  /// is also accepted when U::static_type is HIP.  A device index of -1 is
  /// a no-op.
  template <
      typename U = T,
      typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value, int>::
          type = 0>
  void set_device(at::Device device) {
    AT_ASSERT(
        (U::static_type == DeviceType::HIP && device.is_cuda()) ||
        device.type() == U::static_type);
    if (device.index() == -1) {
      return;
    }
    impl_.setDevice(device);
    current_device_ = device;
  }

  /// Resets the currently set device to its original device, and then sets
  /// the current device to the passed device.  For a guard bound to a
  /// single device type this simply forwards to set_device.
  template <typename U = T>
  typename std::enable_if<!std::is_same<U, VirtualGuardImpl>::value>::type
  reset_device(at::Device device) {
    set_device(device);
  }

  /// Resets the currently set device to its original device, and then sets
  /// the current device to the passed device, which may be of a different
  /// device type.
  ///
  /// The name reset_device stresses that earlier device settings made
  /// through this guard are NOT preserved, even when the device type
  /// changes.  For example:
  ///
  ///   // CUDA device is 0
  ///   DeviceGuard g(Device(kCUDA, 1));
  ///   g.reset_device(Device(kHIP, 2));
  ///   // CUDA device is 0 (!!)
  ///
  /// NOTE: this implementation may skip some device setting if it can prove
  /// that it is unnecessary.
  ///
  /// The optional `impl` argument is for testing only.
  template <typename U = T>
  typename std::enable_if<std::is_same<U, VirtualGuardImpl>::value>::type
  reset_device(
      at::Device device,
      const impl::DeviceGuardImplInterface* impl = nullptr) {
    if (device.index() == -1) {
      return;
    }
    if (device.type() != original_device_.type()) {
      // Device type changes: restore the old type's original device, swap
      // in the impl for the new type, and start guarding from scratch
      // (effectively destructing and reconstructing the guard in place).
      impl_.setDevice(original_device_);
      impl_ = impl != nullptr ? VirtualGuardImpl(impl)
                              : VirtualGuardImpl(device.type());
      original_device_ = impl_.exchangeDevice(device);
      current_device_ = device;
      return;
    }
    // Same device type: keep the recorded original device and just switch.
    AT_ASSERT(impl == nullptr || impl->type() == device.type());
    impl_.setDevice(device);
    current_device_ = device;
  }

  /// Sets the device index to the given one.  The device type is taken
  /// from the original device.
  void set_index(DeviceIndex index) {
    const Device target(original_device_.type(), index);
    reset_device(target);
  }

  /// Returns the device that was current at the time of the most recent
  /// reset_device() that changed the device type, or otherwise the device
  /// that was current at construction time.
  Device original_device() const {
    return original_device_;
  }

  /// Returns the device most recently set through this guard, whether by
  /// construction or by set_device/reset_device/set_index.
  Device current_device() const {
    return current_device_;
  }

 protected:
  T impl_;

 private:
  Device original_device_;
  Device current_device_;
};
/**
* An OptionalDeviceGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
*
* InlineOptionalDeviceGuard is a helper class for implementing
* OptionalDeviceGuards. See guidance in InlineDeviceGuard on how to
* use this. See OptionalDeviceGuard for user-oriented usage notes.
*/
template <typename T>
class InlineOptionalDeviceGuard {
 public:
  // Note [Explicit initialization of optional fields]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // The optional member is initialized explicitly in every constructor to
  // work around an nvcc bug; see
  // https://github.com/pytorch/pytorch/issues/12117

  /// Creates an uninitialized OptionalDeviceGuard.
  explicit InlineOptionalDeviceGuard()
      : guard_() // See Note [Explicit initialization of optional fields]
  {}

  /// Set the current device to the passed Device, if it is not nullopt.
  /// With nullopt the guard stays uninitialized.
  explicit InlineOptionalDeviceGuard(optional<Device> device_opt)
      : guard_() { // See Note [Explicit initialization of optional fields]
    if (!device_opt.has_value()) {
      return;
    }
    guard_.emplace(device_opt.value());
  }

  /// Set the current device to the passed DeviceIndex, if it is not
  /// nullopt.  With nullopt the guard stays uninitialized.  Not available
  /// for VirtualGuardImpl.
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  explicit InlineOptionalDeviceGuard(optional<DeviceIndex> device_index_opt)
      : guard_() { // See Note [Explicit initialization of optional fields]
    if (!device_index_opt.has_value()) {
      return;
    }
    guard_.emplace(device_index_opt.value());
  }

  /// Every constructor of DeviceGuard is also valid here; the arguments are
  /// forwarded to it and the result is an initialized OptionalDeviceGuard.
  template <typename... Args>
  explicit InlineOptionalDeviceGuard(Args&&... args)
      : guard_(in_place, std::forward<Args>(args)...) {}

  // TODO: Consider readding Tensor and TensorList constructors here, when
  // Tensor moves to c10.  (These are only valid on OptionalDeviceGuard,
  // because a Tensor may be undefined, in which case we need an
  // uninitialized tensor guard.)

  // Note [Move construction for RAII guards is tricky]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Move construction would, in principle, be handy for ending the
  // lifetime of an `OptionalDeviceGuard` early; for example:
  //
  //     // current device is d0
  //     OptionalDeviceGuard g1(d1);
  //     // current device is d1
  //     {
  //       OptionalDeviceGuard g2(std::move(g1));
  //     }
  //     // current device is d0!!
  //
  // The trouble is writing a move constructor that behaves sensibly in
  // every situation.  Consider:
  //
  //     OptionalDeviceGuard g1(d1);
  //     {
  //       OptionalDeviceGuard g2(d2);
  //       {
  //         OptionalDeviceGuard g3(std::move(g1)); // !!!
  //       }
  //     }
  //
  // What should the current device be while g3 is in scope, and what
  // should it be once g3 goes out of scope?  What about g2?  There do not
  // seem to be satisfactory answers to these questions.
  //
  // An error could in principle be raised in this case with some extra
  // thread-local bookkeeping.  But why bother?  Just don't provide the
  // constructor.
  InlineOptionalDeviceGuard(InlineOptionalDeviceGuard<T>&& other) = delete;

  // Note [Move assignment for RAII guards is tricky]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Move assignment is deleted because it requires knowing which guard was
  // defined "first": that guard's original_device_ wins, and the current
  // representation cannot tell which one it is.  (Move construction does
  // not have this problem, as one guard is always uninitialized.)
  //
  // A pair of examples makes this concrete.
  //
  // Example 1:
  //
  //     // initial device is n0
  //     {
  //       CUDAGuard g1(n1);
  //       {
  //         CUDAGuard g2(n2);
  //         // current device should be n2
  //         g1 = std::move(g2);
  //         // current device should still be n2
  //       }
  //       // current device should still be n2
  //     }
  //     // current device should be n0
  //
  // Example 2 (flip the order of the two guards):
  //
  //     // initial device is n0
  //     {
  //       CUDAGuard g2(n2);
  //       {
  //         CUDAGuard g1(n1);
  //         // current device should be n1
  //         g1 = std::move(g2);
  //         // current device should be n2
  //       }
  //       // current device should be n0 (since g2 has been vacated)
  //     }
  //
  // In both examples g1 has to restore to n0 after the move assignment.
  // In example 1 that value is g1's own restore value (prior to the move);
  // in example 2 it is the restore value of g2(!!).  Without a way of
  // telling which guard was allocated first, we cannot know which one
  // should win.
  //
  // An extra thread-local variable could solve this.  But no one is
  // actually using move-assignment, so it is simply removed.
  InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) =
      delete;

  /// Sets the device to the given one, initializing the guard with it
  /// first if the guard is still uninitialized.  Not available for
  /// VirtualGuardImpl.
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  void set_device(at::Device device) {
    if (guard_.has_value()) {
      guard_->set_device(device);
    } else {
      guard_.emplace(device);
    }
  }

  /// Resets the currently set device to its original device, and then sets
  /// the current device to the passed device (for a possibly different
  /// device type).  Initializes the guard if it is not already initialized.
  ///
  /// See the notes on InlineDeviceGuard for why this is called
  /// reset_device.
  ///
  /// The optional `impl` argument is for testing only.
  template <
      typename U = T,
      typename = typename std::enable_if<
          std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(
      at::Device device,
      const DeviceGuardImplInterface* impl = nullptr) {
    if (guard_.has_value()) {
      guard_->reset_device(device, impl);
    } else {
      guard_.emplace(device, impl);
    }
  }

  /// Resets the currently set device to its original device, and then sets
  /// the current device to the passed device.  Initializes the guard if it
  /// is not already initialized.  This is effectively equivalent to
  /// set_device when a guard supports only a single device type.
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  void reset_device(at::Device device) {
    if (guard_.has_value()) {
      guard_->reset_device(device);
    } else {
      guard_.emplace(device);
    }
  }

  /// Sets the device index to the given one, initializing the guard with
  /// it first if the guard is still uninitialized.  The device type is
  /// statically known.
  template <
      typename U = T,
      typename = typename std::enable_if<
          !std::is_same<U, VirtualGuardImpl>::value>::type>
  void set_index(DeviceIndex index) {
    if (guard_.has_value()) {
      guard_->set_index(index);
    } else {
      guard_.emplace(index);
    }
  }

  /// Returns the original device recorded by the underlying guard, or
  /// nullopt if the guard is uninitialized.
  optional<Device> original_device() const {
    if (!guard_.has_value()) {
      return nullopt;
    }
    return make_optional(guard_->original_device());
  }

  /// Returns the most recent device that was set using this device guard,
  /// either from construction or via set_device, if the guard is
  /// initialized; returns nullopt if the guard is uninitialized.
  optional<Device> current_device() const {
    if (!guard_.has_value()) {
      return nullopt;
    }
    return make_optional(guard_->current_device());
  }

  /// Restore the original device, resetting this guard to the
  /// uninitialized state.
  void reset() {
    guard_.reset();
  }

 private:
  optional<InlineDeviceGuard<T>> guard_;
};
} // namespace impl
} // namespace c10
|