1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
|
#pragma once
#include <c10/core/Device.h>
namespace c10 {
/// An index representing a specific stream. A StreamId is not independently
/// meaningful without knowing the Device it is associated with; try to
/// use Stream rather than StreamId directly.
///
/// StreamIds are opaque; they are assigned by some DeviceType-specific
/// numbering system which is not visible to the user. HOWEVER, we
/// guarantee that StreamId 0 is always a valid stream, and corresponds
/// to some sort of "default" stream.
using StreamId = int64_t;
// NB: I decided not to call the above StreamIndex to avoid confusion with
// DeviceIndex. This way, you access device index with index(), and stream id
// with id()
/**
* A stream is a software mechanism used to synchronize launched kernels
* without requiring explicit synchronizations between kernels. The basic
* model is that every kernel launch is associated with a stream: every
* kernel on the same stream is implicitly synchronized so that if I launch
* kernels A and B on the same stream, A is guaranteed to finish before B
* launches. If I want B to run concurrently with A, I must schedule
* it on a different stream.
*
* The Stream class is a backend agnostic value class representing a stream
* which I may schedule a kernel on. Every stream is associated with a device,
* which is recorded in stream, which is used to avoid confusion about which
* device a stream refers to.
*
* Streams are explicitly thread-safe, in the sense that it is OK to pass
* a Stream from one thread to another, and kernels queued from two different
* threads will still get serialized appropriately. (Of course, the
* time when the kernels get queued is undetermined unless you synchronize
* host side ;)
*
* Stream does NOT have a default constructor. Streams are for expert
* users; if you want to use Streams, we're going to assume you know
* how to deal with C++ template error messages if you try to
* resize() a vector of Streams.
*
* Known instances of streams in backends:
*
* - cudaStream_t (CUDA)
* - hipStream_t (HIP)
* - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration
* does NOT support command queues.)
*
* Because this class is device agnostic, it cannot provide backend-specific
* functionality (e.g., get the cudaStream_t of a CUDA stream.) There are
* wrapper classes which provide this functionality, e.g., CUDAStream.
*/
class C10_API Stream final {
private:
Device device_;
StreamId id_;
public:
enum Unsafe { UNSAFE };
enum Default { DEFAULT };
/// Unsafely construct a stream from a Device and a StreamId. In
/// general, only specific implementations of streams for a
/// backend should manufacture Stream directly in this way; other users
/// should use the provided APIs to get a stream. In particular,
/// we don't require backends to give any guarantees about non-zero
/// StreamIds; they are welcome to allocate in whatever way they like.
explicit Stream(Unsafe, Device device, StreamId id)
: device_(device), id_(id) {}
/// Construct the default stream of a Device. The default stream is
/// NOT the same as the current stream; default stream is a fixed stream
/// that never changes, whereas the current stream may be changed by
/// StreamGuard.
explicit Stream(Default, Device device) : device_(device), id_(0) {}
bool operator==(const Stream& other) const noexcept {
return this->device_ == other.device_ && this->id_ == other.id_;
}
bool operator!=(const Stream& other) const noexcept {
return !(*this == other);
}
Device device() const noexcept {
return device_;
}
DeviceType device_type() const noexcept {
return device_.type();
}
DeviceIndex device_index() const noexcept {
return device_.index();
}
StreamId id() const noexcept {
return id_;
}
// Enqueues a wait instruction in the stream's work queue.
// This instruction is a no-op unless the event is marked
// for recording. In that case the stream stops processing
// until the event is recorded.
template <typename T>
void wait(const T& event) const {
event.block(*this);
}
// Return whether all asynchronous work previously enqueued on this stream
// has completed running on the device.
bool query() const;
// Wait (by blocking the calling thread) until all asynchronous work enqueued
// on this stream has completed running on the device.
void synchronize() const;
// The purpose of this function is to more conveniently permit binding
// of Stream to and from Python. Without packing, I have to setup a whole
// class with two fields (device and stream id); with packing I can just
// store a single uint64_t.
//
// The particular way we pack streams into a uint64_t is considered an
// implementation detail and should not be relied upon.
uint64_t pack() const noexcept {
// Are you here because this static assert failed? Make sure you ensure
// that the bitmasking code below is updated accordingly!
static_assert(sizeof(DeviceType) == 1, "DeviceType is not 8-bit");
static_assert(sizeof(DeviceIndex) == 1, "DeviceIndex is not 8-bit");
static_assert(sizeof(StreamId) == 8, "StreamId is not 64-bit");
// Concat these together into a 64-bit integer
// See Note [Hazard when concatenating signed integers]
uint64_t bits = static_cast<uint64_t>(static_cast<uint8_t>(device_type()))
<< 56 |
static_cast<uint64_t>(static_cast<uint8_t>(device_index())) << 48 |
// Remove the sign extension part of the 64-bit address because
// the id might be used to hold a pointer.
(static_cast<uint64_t>(id()) & ((1ull << 48) - 1));
TORCH_INTERNAL_ASSERT(
static_cast<DeviceIndex>((bits >> 48) & 0xFFull) == device_index(),
"DeviceIndex is not correctly packed");
TORCH_INTERNAL_ASSERT(
static_cast<DeviceType>((bits >> 56)) == device_type(),
"DeviceType is not correctly packed");
// Re-extend the sign of stream_id for checking
uint64_t mask = (1ull << 47);
TORCH_INTERNAL_ASSERT(
static_cast<StreamId>(((bits & 0xFFFFFFFFFFFFull) ^ mask) - mask) ==
id(),
"DeviceType is not correctly packed");
return bits;
}
static Stream unpack(uint64_t bits) {
// Re-extend the sign of stream_id
uint64_t mask = (1ull << 47);
const auto stream_id =
(static_cast<StreamId>(bits & 0xFFFFFFFFFFFFull) ^ mask) - mask;
bits >>= 48;
const auto device_index = static_cast<DeviceIndex>(bits & 0xFFull);
bits >>= 8;
const auto device_type = static_cast<DeviceType>(bits);
TORCH_CHECK(isValidDeviceType(device_type));
// Unfortunately, we can't check if the StreamId is valid here; it
// will be checked upon first use.
return Stream(UNSAFE, Device(device_type, device_index), stream_id);
}
// I decided NOT to provide setters on this class, because really,
// why would you change the device of a stream? Just construct
// it correctly from the beginning dude.
};
C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
} // namespace c10
namespace std {
template <>
struct hash<c10::Stream> {
size_t operator()(c10::Stream s) const noexcept {
return std::hash<uint64_t>{}(s.pack());
}
};
} // namespace std
|