1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
|
#pragma once
#include <array>
#include <cstdlib>
#include <ctime>
#include <memory>
#include <unordered_map>
#include <c10/macros/Macros.h>
#include <c10/core/Allocator.h>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <c10/core/CopyBytes.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
// Forward declaration only: this header refers to caffe2::Event solely through
// a reference (WaitEvent) and a pointer (Record), so the complete type is not
// required here.
namespace caffe2 {
class Event;
} // namespace caffe2
namespace at {
// Forward declaration; the class itself is defined further down in this header.
class BaseContext;
/**
 * Abstract interface for the Context class in Caffe2.
 *
 * A Context bundles everything required to run an operator on a specific
 * device. Concrete Context classes need to implement every pure virtual
 * function declared in BaseContext.
 * TODO: add docs after this is finalized.
 */
class TORCH_API BaseContext {
 public:
  virtual ~BaseContext() noexcept = default;

  // Device this context is bound to.
  virtual Device device() const = 0;

  /* Sorry for the naming, will get rid of this in a future diff. */
  virtual DeviceType device_type() const = 0;

  // Switches the context to the given stream id on its device; the exact
  // meaning of the id is defined by the concrete context.
  virtual void SwitchToDevice(int64_t /*stream_id*/) = 0;

  // Convenience overload: identical to SwitchToDevice(0).
  inline void SwitchToDevice() {
    SwitchToDevice(0);
  }

  // Event hooks; their semantics are supplied by each concrete context.
  virtual void WaitEvent(const caffe2::Event& ev) = 0;

  virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr)
      const = 0;

  virtual void FinishDeviceComputation() = 0;

  // Copying used to be an arbitrary cross-device operation, but in practice
  // every caller performed a direct CPU <-> X copy, so there are simply three
  // functions for it (which avoids double dispatch). This will be made
  // obsolete by C10, where copies will be proper operators (and can rely on
  // multiple dispatch there).
  virtual void CopyBytesSameDevice(
      size_t nbytes,
      const void* src,
      void* dst) = 0;

  virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0;

  virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0;

  // Typed wrapper over CopyBytesSameDevice: copies n elements of a
  // fundamental type T, i.e. n * sizeof(T) bytes.
  template <typename T>
  inline void CopySameDevice(size_t n, const T* src, T* dst) {
    static_assert(
        c10::guts::is_fundamental<T>::value,
        "CopySameDevice requires fundamental types");
    const size_t nbytes = n * sizeof(T);
    CopyBytesSameDevice(
        nbytes, static_cast<const void*>(src), static_cast<void*>(dst));
  }

  // Typed wrapper over CopyBytesFromCPU: copies n elements of a
  // fundamental type T, i.e. n * sizeof(T) bytes.
  template <typename T>
  inline void CopyFromCPU(size_t n, const T* src, T* dst) {
    static_assert(
        c10::guts::is_fundamental<T>::value,
        "CopyFromCPU requires fundamental types");
    const size_t nbytes = n * sizeof(T);
    CopyBytesFromCPU(
        nbytes, static_cast<const void*>(src), static_cast<void*>(dst));
  }

  // Typed wrapper over CopyBytesToCPU: copies n elements of a
  // fundamental type T, i.e. n * sizeof(T) bytes.
  template <typename T>
  inline void CopyToCPU(size_t n, const T* src, T* dst) {
    static_assert(
        c10::guts::is_fundamental<T>::value, "CopyToCPU requires fundamental types");
    const size_t nbytes = n * sizeof(T);
    CopyBytesToCPU(
        nbytes, static_cast<const void*>(src), static_cast<void*>(dst));
  }

  // Whether this context accepts types that carry their own copy function.
  // The base implementation answers false; subclasses may override.
  virtual bool SupportsNonFundamentalTypes() const {
    return false;
  }

  // Asserts that SupportsNonFundamentalTypes() holds. Called by the
  // CopyItems* helpers before they invoke a TypeMeta-provided copy function.
  inline void EnforceMetaCopyOK() {
    AT_ASSERTM(
        SupportsNonFundamentalTypes(), "Context requires fundamental types");
  }

  // Copies n items described by `meta` within the same device. Without a
  // type-specific copy function this is a raw copy of n * meta.itemsize()
  // bytes; otherwise the type's copy function is used, after checking that
  // the context permits it.
  void CopyItemsSameDevice(
      const caffe2::TypeMeta meta,
      size_t n,
      const void* src,
      void* dst) {
    if (!meta.copy()) {
      CopyBytesSameDevice(n * meta.itemsize(), src, dst);
      return;
    }
    EnforceMetaCopyOK();
    meta.copy()(src, dst, n);
  }

  // Same as CopyItemsSameDevice, but the raw-byte path goes through
  // CopyBytesFromCPU.
  void CopyItemsFromCPU(
      const caffe2::TypeMeta meta,
      size_t n,
      const void* src,
      void* dst) {
    if (!meta.copy()) {
      CopyBytesFromCPU(n * meta.itemsize(), src, dst);
      return;
    }
    EnforceMetaCopyOK();
    meta.copy()(src, dst, n);
  }

  // Same as CopyItemsSameDevice, but the raw-byte path goes through
  // CopyBytesToCPU.
  void CopyItemsToCPU(
      const caffe2::TypeMeta meta,
      size_t n,
      const void* src,
      void* dst) {
    if (!meta.copy()) {
      CopyBytesToCPU(n * meta.itemsize(), src, dst);
      return;
    }
    EnforceMetaCopyOK();
    meta.copy()(src, dst, n);
  }
};
// Context constructor registry.
// Declares at::ContextRegistry(), which CreateContext() below queries with an
// at::DeviceType key and an at::Device argument to obtain a
// std::unique_ptr<at::BaseContext>.
C10_DECLARE_TYPED_REGISTRY(
ContextRegistry,
at::DeviceType,
at::BaseContext,
std::unique_ptr,
at::Device);
// Convenience wrapper that forwards `type` and the remaining arguments to
// C10_REGISTER_TYPED_CLASS for ContextRegistry. Presumably `type` is the
// at::DeviceType key and the variadic part names the context class to
// register -- confirm against c10/util/Registry.h.
#define REGISTER_CONTEXT(type, ...) \
C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)
// Looks up the creator registered in ContextRegistry under the device's type
// and invokes it with the device itself, returning whatever the registry
// produces.
inline std::unique_ptr<at::BaseContext> CreateContext(
    const at::Device& device) {
  const auto registry_key = device.type();
  return at::ContextRegistry()->Create(registry_key, device);
}
} // namespace at
// Make the at:: context API visible under the caffe2 namespace as well, so
// caffe2::BaseContext and caffe2::CreateContext name the same entities as
// their at:: counterparts.
namespace caffe2 {
using at::BaseContext;
using at::CreateContext;
} // namespace caffe2
|