1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
#pragma once
#include <c10/core/impl/InlineDeviceGuard.h>
namespace c10 {
/// RAII guard that sets a certain default device in its constructor, and
/// changes it back to the device that was originally active upon destruction.
///
/// The device is always reset to the one that was active at the time of
/// construction of the guard. Even if you call `reset_device` or `set_index`
/// after construction, the destructor will still reset the device to the one
/// that was active at construction time.
///
/// This device guard does NOT have an uninitialized state; it is guaranteed
/// to reset a device on exit. If you are in a situation where you *might*
/// want to setup a guard (i.e., are looking for the moral equivalent
/// of optional<DeviceGuard>), see OptionalDeviceGuard.
class DeviceGuard {
 public:
  // ---- Construction -------------------------------------------------------

  /// Default construction is intentionally unavailable; see
  /// Note [Omitted default constructor from RAII].
  explicit DeviceGuard() = delete;

  /// Sets the current device to `device`.
  explicit DeviceGuard(Device device) : guard_(device) {}

  /// Testing-only constructor: forwards both `device` and an explicit `impl`
  /// to the underlying guard.
  explicit DeviceGuard(
      Device device,
      const impl::DeviceGuardImplInterface* impl)
      : guard_(device, impl) {}

  // ---- Copy / move --------------------------------------------------------

  /// Copying is not permitted.
  DeviceGuard(const DeviceGuard&) = delete;
  DeviceGuard& operator=(const DeviceGuard&) = delete;

  /// Moving is not permitted either: a moved-from object of a type with a
  /// nontrivial destructor needs an uninitialized state, and DeviceGuard
  /// does not have one.
  DeviceGuard(DeviceGuard&&) = delete;
  DeviceGuard& operator=(DeviceGuard&&) = delete;

  // ---- Observers ----------------------------------------------------------

  /// Returns the device that was set at the time the guard was constructed.
  Device original_device() const {
    return guard_.original_device();
  }

  /// Returns the device most recently set through this guard, whether by the
  /// constructor or by a later call to reset_device.
  Device current_device() const {
    return guard_.current_device();
  }

  // ---- Mutators -----------------------------------------------------------

  /// Switches the guard to `device`. The device passed here must be
  /// consistent with the device type originally given at construction.
  ///
  /// TODO: The consistency check here is inconsistent with StreamGuard's
  /// behavior with set_stream, where a stream on a different device than
  /// the original one isn't an error; we just reset the stream and then
  /// switch devices.
  void reset_device(at::Device device) {
    guard_.reset_device(device);
  }

  /// Testing-only overload of reset_device that also forwards an explicit
  /// `impl` to the underlying guard.
  void reset_device(
      at::Device device,
      const impl::DeviceGuardImplInterface* impl) {
    guard_.reset_device(device, impl);
  }

  /// Changes only the device index. The device type is inferred from the
  /// device type the guard was originally constructed with.
  void set_index(DeviceIndex index) {
    guard_.set_index(index);
  }

 private:
  // All operations are forwarded to this inline guard.
  impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_;
};
/**
* An OptionalDeviceGuard is an RAII class that sets a device to some value on
* initialization, and resets the device to its original value on destruction.
* Morally, an OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but
* with extra constructors and methods as appropriate.
*
* Besides its obvious use (optionally applying a DeviceGuard),
* OptionalDeviceGuard is often also used for the following idiom:
*
*   OptionalDeviceGuard g;
*   for (const auto& t : tensors) {
*     g.reset_device(t.device());
*     do_something_with(t);
*   }
*
* This usage is marginally more efficient than constructing a DeviceGuard every
* iteration of the for loop, as it avoids an unnecessary device reset.
*
* Unlike DeviceGuard, an OptionalDeviceGuard may be uninitialized. This occurs
* when you use the nullary constructor, or pass a nullopt to the constructor.
* Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the
* original device was and they do not reset on destruction. This is why
* original_device() and current_device() return optional<Device> rather than
* Device (as they do in DeviceGuard), and is also why we didn't just
* provide OptionalDeviceGuard by default and hide DeviceGuard from users.
*
* The semantics of an OptionalDeviceGuard are exactly explained by thinking
* of it as an optional<DeviceGuard>. In particular, an initialized
* OptionalDeviceGuard doesn't restore device to its value at construction; it
* restores device to its value *at initialization*. So if you have the
* program:
*
* setDevice(1);
* OptionalDeviceGuard g;
* setDevice(2);
* g.reset_device(Device(DeviceType::CUDA, 3)); // initializes!
*
* On destruction, g will reset device to 2, rather than 1.
*
* An uninitialized OptionalDeviceGuard is distinct from an (initialized)
* DeviceGuard whose original_device_ and current_device_ match, since the
* DeviceGuard will still reset the device to original_device_.
*/
class OptionalDeviceGuard {
 public:
  // ---- Construction -------------------------------------------------------

  /// Creates a guard in the uninitialized state. It can be initialized
  /// afterwards through reset_device.
  explicit OptionalDeviceGuard() : guard_() {}

  /// Creates an initialized guard and sets the current device to `device`.
  explicit OptionalDeviceGuard(Device device) : guard_(device) {}

  /// Creates an initialized guard when `device` holds a Device; when it does
  /// not, the guard is left uninitialized.
  explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {}

  /// Testing-only constructor: forwards both `device` and an explicit `impl`
  /// to the underlying guard.
  explicit OptionalDeviceGuard(
      Device device,
      const impl::DeviceGuardImplInterface* impl)
      : guard_(device, impl) {}

  // ---- Copy / move --------------------------------------------------------

  /// Copying is not permitted.
  OptionalDeviceGuard(const OptionalDeviceGuard&) = delete;
  OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete;

  /// Moving is not permitted. For the rationale see
  /// Note [Explicit initialization of optional fields] and
  /// Note [Move construction for RAII guards is tricky].
  OptionalDeviceGuard(OptionalDeviceGuard&&) = delete;
  OptionalDeviceGuard& operator=(OptionalDeviceGuard&&) = delete;

  // ---- Observers ----------------------------------------------------------

  /// Returns the device that was set at the time the guard was initialized.
  /// The result is an optional because an uninitialized guard does not know
  /// what the original device was.
  optional<Device> original_device() const {
    return guard_.original_device();
  }

  /// Returns the device most recently set through this guard, whether by a
  /// constructor or by a later call to reset_device.
  optional<Device> current_device() const {
    return guard_.current_device();
  }

  // ---- Mutators -----------------------------------------------------------

  /// Switches the guard to `device`. The device passed here must be
  /// consistent with the device type originally specified during guard
  /// construction.
  void reset_device(at::Device device) {
    guard_.reset_device(device);
  }

  /// Testing-only overload of reset_device that also forwards an explicit
  /// `impl` to the underlying guard.
  void reset_device(
      at::Device device,
      const impl::DeviceGuardImplInterface* impl) {
    guard_.reset_device(device, impl);
  }

 private:
  // All operations are forwarded to this inline optional guard.
  impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_;
};
// Note [Whither the DeviceGuard boilerplate]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Design note: in principle, we could avoid these wrappers using:
//
// using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
// using OptionalDeviceGuard =
// impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on. Furthermore, for specializations
// like CUDAStreamGuard, it can be profitable to replace some interfaces with
// refined types (e.g., return CUDAStream instead of Stream). So, we eat
// the boilerplate and write out the API explicitly.
} // namespace c10
|