#pragma once
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
// Just for C10_ANONYMOUS_VARIABLE
#include <c10/util/Registry.h>
#include <atomic>
namespace c10 {
// Forward declaration
class DataPtr;
/**
* Flags defining the behavior of events.
*
* PYTORCH_DEFAULT and BACKEND_DEFAULT are valid for all backends. The
* BACKEND_DEFAULT is what a particular backend would select if no
* flags were given. PYTORCH_DEFAULT is PyTorch's framework-default
* choice for events on that backend, which may not be the same. For example,
* when PyTorch creates a CUDA event it sets the flag
* CUDA_EVENT_DISABLE_TIMING by default to improve performance.
*
* The mapping of PYTORCH_DEFAULT and BACKEND_DEFAULT is done by each
* backend implementation. Backend-specific flags, like CUDA_EVENT_DEFAULT,
* should map one-to-one with actual event flags for those backends.
*/
enum class EventFlag {
PYTORCH_DEFAULT,
BACKEND_DEFAULT,
// CUDA flags
CUDA_EVENT_DEFAULT,
CUDA_EVENT_DISABLE_TIMING, // PyTorch-default for CUDA
// HIP flags
HIP_EVENT_DEFAULT,
HIP_EVENT_DISABLE_TIMING, // PyTorch-default for HIP
// FOR TESTING ONLY
INVALID
};
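// For illustration: a backend's implementation typically maps these generic
// flags onto its native event flags. A minimal sketch, using hypothetical
// native constants kNativeDefault and kNativeDisableTiming:
//
//   unsigned int toNativeFlag(EventFlag flag) {
//     switch (flag) {
//       case EventFlag::PYTORCH_DEFAULT:
//       case EventFlag::CUDA_EVENT_DISABLE_TIMING:
//         return kNativeDisableTiming; // timing disabled for performance
//       case EventFlag::BACKEND_DEFAULT:
//       case EventFlag::CUDA_EVENT_DEFAULT:
//         return kNativeDefault;
//       default:
//         TORCH_CHECK(false, "Unsupported event flag");
//     }
//   }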
namespace impl {
/**
* DeviceGuardImplInterface represents the virtual interface which provides
* the functionality needed to implement an RAII class for device and stream
* switching, via DeviceGuard. Every distinct device type, e.g., CUDA and HIP, is
* expected to implement and register an implementation of this interface.
* All classes which inherit from DeviceGuardImplInterface should be declared
* 'final'.
*
* This class exists because we provide a unified interface for performing
* device guards via DeviceGuard, but we cannot assume that we have actually
* compiled against, e.g., the CUDA library, which actually implements
* this guard functionality. In this case, dynamic dispatch is required
* to cross the library boundary.
*
* If possible, you should directly use implementations of this interface;
* those uses will be devirtualized.
*/
struct C10_API DeviceGuardImplInterface {
/**
* Return the type of device managed by this guard implementation.
*/
virtual DeviceType type() const = 0;
/**
* Set the current device to Device, and return the previous Device.
*/
virtual Device exchangeDevice(Device) const = 0;
// NB: Implementations of exchangeDevice can be a bit boilerplatey. You might
// consider replacing exchangeDevice with a non-virtual function with a baked
// in implementation; however, note that this will triple the number of
// virtual calls (when you implement exchangeDevice in a final subclass,
// the compiler gets to devirtualize everything; it won't do that if you don't
// define it in the subclass!) A common way to solve this problem is to use
// some sort of CRTP; however, we can't template DeviceGuardImplInterface since
// we really *do* need it to be virtual. A little boilerplate seems easiest
// to explain. (Another way around this problem is to provide inline
// functions that provide the default implementations, but this seems a little
// hard to explain. In any case, we're only going to have on the order of ten
// implementations of this anyway.)
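// For illustration, the boilerplate in a final subclass is usually just:
//
//   Device exchangeDevice(Device d) const override {
//     Device old_device = getDevice();
//     if (old_device.index() != d.index()) {
//       setDevice(d);
//     }
//     return old_device;
//   }
//
// (a sketch only; a real backend would call its native set-device API).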
/**
* Get the current device.
*/
virtual Device getDevice() const = 0;
/**
* Set the current device to Device.
*/
virtual void setDevice(Device) const = 0;
/**
* Set the current device to Device, without checking for errors
* (so, e.g., this can be called from a destructor).
*/
virtual void uncheckedSetDevice(Device) const noexcept = 0;
/**
* Get the current stream for a given device.
*/
virtual Stream getStream(Device) const noexcept = 0;
/**
* Get the default stream for a given device.
*/
virtual Stream getDefaultStream(Device) const {
TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.")
}
/**
* Get a stream from the global pool for a given device.
*/
virtual Stream getStreamFromGlobalPool(Device, bool isHighPriority = false)
const {
(void)isHighPriority; // Suppress unused variable warning
TORCH_CHECK(false, "Backend doesn't support acquiring a stream from pool.");
}
/**
* Set a stream to be the thread local current stream for its device.
* Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream.
*/
virtual Stream exchangeStream(Stream) const noexcept = 0;
/**
* Destroys the given event.
*/
virtual void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
const noexcept {}
/**
* Increments the event's version and enqueues a job with this version
* in the stream's work queue. When the stream processes that job,
* it notifies all streams waiting on / blocked by that version of the
* event to continue and marks that version as recorded.
*/
virtual void record(
void** /*event*/,
const Stream& /*stream*/,
const DeviceIndex /*device_index*/,
const c10::EventFlag /*flag*/) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Does nothing if the event has not been scheduled to be recorded.
* If the event was previously enqueued to be recorded, a command
* to wait for the version of the event that exists at the time of this call
* is inserted in the stream's work queue.
* When the stream reaches this command it will stop processing
* additional commands until that version of the event is marked as recorded.
*/
virtual void block(void* /*event*/, const Stream& /*stream*/) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
/**
* Returns true if (and only if)
* (1) the event has never been scheduled to be recorded, or
* (2) the current version is marked as recorded.
* Returns false otherwise.
*/
virtual bool queryEvent(void* /*event*/) const {
TORCH_CHECK(false, "Backend doesn't support events.");
}
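// For illustration, a caller typically drives the event hooks above in
// roughly this order (a sketch, assuming an implementation pointer `impl`
// and two streams s1 and s2 on the same backend):
//
//   void* event = nullptr;
//   impl->record(&event, s1, s1.device_index(), EventFlag::PYTORCH_DEFAULT);
//   impl->block(event, s2);                 // s2 waits for this recording
//   bool done = impl->queryEvent(event);    // non-blocking completion check
//   impl->destroyEvent(event, s1.device_index());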
/**
* Get the number of devices. WARNING: This is REQUIRED to not raise
* an exception. If there is some sort of problem, e.g., driver error,
* you should report that there are zero available devices.
*/
virtual DeviceIndex deviceCount() const noexcept = 0;
/**
* Return true if all the work previously enqueued on the stream for
* asynchronous execution has completed running on the device.
*/
virtual bool queryStream(const Stream& /*stream*/) const {
TORCH_CHECK(false, "Backend doesn't support querying streams.");
}
/**
* Wait (by blocking the calling thread) until all the work previously
* enqueued on the stream has completed running on the device.
*/
virtual void synchronizeStream(const Stream& /*stream*/) const {
TORCH_CHECK(false, "Backend doesn't support synchronizing streams.");
}
/**
* Ensure the caching allocator (if any) is aware that the given DataPtr is
* being used on the given stream, and that it should thus avoid recycling the
* DataPtr until all work on that stream is done.
*/
virtual void recordDataPtrOnStream(const c10::DataPtr&, const Stream&) const {
}
/**
* Intended use of this class is to leak the DeviceGuardImpl at program end.
* So you better not call the destructor, buster!
*/
virtual ~DeviceGuardImplInterface() = default;
};
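// For illustration, a real backend implements this interface with a final
// subclass that overrides at least the pure-virtual methods (bodies elided
// here; see NoOpDeviceGuardImpl below for a complete trivial example):
//
//   struct MyBackendGuardImpl final : public DeviceGuardImplInterface {
//     DeviceType type() const override { return DeviceType::PrivateUse1; }
//     Device exchangeDevice(Device d) const override { /* native calls */ }
//     Device getDevice() const override { /* ... */ }
//     void setDevice(Device d) const override { /* ... */ }
//     void uncheckedSetDevice(Device d) const noexcept override { /* ... */ }
//     Stream getStream(Device d) const noexcept override { /* ... */ }
//     Stream exchangeStream(Stream s) const noexcept override { /* ... */ }
//     DeviceIndex deviceCount() const noexcept override { /* ... */ }
//   };
//
// MyBackendGuardImpl and the use of PrivateUse1 here are hypothetical.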
// A no-op device guard impl that doesn't do anything interesting. Useful
// for devices that don't actually have a concept of device index. Prominent
// examples are CPU and Meta.
template <DeviceType D>
struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
NoOpDeviceGuardImpl() = default;
DeviceType type() const override {
return D;
}
Device exchangeDevice(Device) const override {
return Device(D, -1); // no-op
}
Device getDevice() const override {
return Device(D, -1);
}
void setDevice(Device) const override {
// no-op
}
void uncheckedSetDevice(Device) const noexcept override {
// no-op
}
Stream getStream(Device) const noexcept override {
// no-op
return Stream(Stream::DEFAULT, Device(D, -1));
}
// NB: These do NOT set the current device
Stream exchangeStream(Stream) const noexcept override {
// no-op
return Stream(Stream::DEFAULT, Device(D, -1));
}
DeviceIndex deviceCount() const noexcept override {
return 1;
}
// Event-related functions
void record(
void** /*event*/,
const Stream& /*stream*/,
const DeviceIndex /*device_index*/,
const EventFlag /*flag*/) const override {
TORCH_CHECK(false, D, " backend doesn't support events.");
}
void block(void* /*event*/, const Stream& /*stream*/) const override {
TORCH_CHECK(false, D, " backend doesn't support events.");
}
bool queryEvent(void* /*event*/) const override {
TORCH_CHECK(false, D, " backend doesn't support events.");
}
void destroyEvent(void* /*event*/, const DeviceIndex /*device_index*/)
const noexcept override {}
// Stream-related functions
bool queryStream(const Stream& /*stream*/) const override {
return true;
}
void synchronizeStream(const Stream& /*stream*/) const override {
// Don't wait for anything.
}
};
// The registry is NON-owning. Each stored pointer is std::atomic so
// that under all interleavings of registry calls the structure is
// race-free. This doesn't cost us anything on reads on x86. (An
// unsynchronized implementation probably is OK too, but I didn't want
// to prove that we never read from device_guard_impl_registry at the
// same time some registration is occurring. Shiver.)
//
// I'd like this registry to be valid even at program destruction time
// (in case someone uses a DeviceGuard in a destructor to do some cleanup
// in the CUDA API.) Since there are no direct accesses of the underlying
// owning objects which I can use to enforce initialization order (unlike
// in a Meyer singleton), it implies that you must *leak* objects when
// putting them in the registry: once registered, an implementation object
// is never deallocated.
extern C10_API std::atomic<const DeviceGuardImplInterface*>
device_guard_impl_registry[static_cast<size_t>(
DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
// I can't conveniently use c10/util/Registry.h for the following reason:
// c10/util/Registry.h gives me a slow way of Create'ing an object of some
// interface from the registry, but no way of quickly accessing an already
// created object. I'll be banging on getDeviceGuardImpl every time we do a
// DeviceGuard, so I really don't want to be doing an unordered_map lookup.
// Better if the registration mechanism directly drops its implementation
// into device_guard_impl_registry.
class C10_API DeviceGuardImplRegistrar {
public:
DeviceGuardImplRegistrar(DeviceType, const DeviceGuardImplInterface*);
};
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl) \
static ::c10::impl::DeviceGuardImplRegistrar C10_ANONYMOUS_VARIABLE( \
g_##DeviceType)(::c10::DeviceType::DevType, new DeviceGuardImpl());
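// Illustrative usage, typically in the backend's own .cpp file (the impl name
// here, MyBackendGuardImpl, is hypothetical):
//
//   C10_REGISTER_GUARD_IMPL(PrivateUse1, MyBackendGuardImpl);
//
// At static-initialization time this constructs a DeviceGuardImplRegistrar,
// which drops the (intentionally leaked) implementation pointer into
// device_guard_impl_registry at the index of the given DeviceType.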
inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
// Two adjacent int16_t fields DeviceType and DeviceIndex have field access
// miscompiled on NVCC. To work around this issue, we apply a mask to the
// DeviceType. First check that DeviceType is 8-bit so the 0xFF mask is lossless.
// FB employees can see
// https://fb.workplace.com/groups/llvm.gcc/permalink/4053565044692080/
// for more details
static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit");
auto p = device_guard_impl_registry[static_cast<size_t>(type) & 0xFF].load();
// This seems to be the first place where you make use of a device
// when you pass devices to factory functions. Give a nicer error
// message in this case.
TORCH_CHECK(p, "PyTorch is not linked with support for ", type, " devices");
return p;
}
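// For illustration, this is the fast path exercised on every DeviceGuard
// construction; typical use looks like:
//
//   const DeviceGuardImplInterface* impl =
//       getDeviceGuardImpl(DeviceType::CUDA);
//   Stream s = impl->getStream(Device(DeviceType::CUDA, 0));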
inline bool hasDeviceGuardImpl(DeviceType type) {
return device_guard_impl_registry[static_cast<size_t>(type)].load();
}
} // namespace impl
} // namespace c10