#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/extension.h>
#include <ATen/EmptyTensor.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/view.h>
#include <unordered_map>
static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;
// A dummy storageImpl for our custom device that secretly uses the CPU
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(
    c10::StorageImpl::use_byte_size_t,
    c10::SymInt size_bytes,
    c10::DataPtr data_ptr,
    c10::Allocator* allocator,
    bool resizable) {
  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
  if (data_ptr == nullptr) {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
  } else {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
  }
  storageImpl_counter += 1;
  return custom_storage_impl;
}
// Register our dummy storageImpl create method.
void custom_storage_registry() {
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
}
bool custom_storageImpl_called() {
  if (storageImpl_counter > last_storageImpl_saved_value) {
    last_storageImpl_saved_value = storageImpl_counter;
    return true;
  }
  return false;
}
// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  add_counter += 1;
  // Since this custom device is just for testing, we don't bother implementing a real kernel.
  return at::empty(self.sizes(), self.options());
}
at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  return dst.copy_(self, false);
}
// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
  int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
  c10::IntArrayRef stride = {};
  result.unsafeGetTensorImpl()->set_storage_offset(0);
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr
      ? at::OptionalIntArrayRef(stride)
      : std::nullopt;
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
                               new_size, stride_opt,
                               /*resize_storage=*/!result.is_meta());
  return result;
}
// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
// into the kernels registered here.
//
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
  m.impl("set_.source_Storage", &custom_set_source_Storage);
}
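// For ops we don't want to write a dedicated kernel for, we can route through the
// generic boxed CPU fallback: it moves PrivateUse1 inputs to CPU, redispatches the op,
// and copies any results back to the original device.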
void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
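// Instead of listing ops one by one as above, a backend can also install the boxed
// CPU fallback for every op that has no PrivateUse1 kernel. A minimal sketch of that
// registration (left disabled here so this test extension keeps its narrow per-op list,
// and assuming CPU semantics are acceptable for everything not implemented):
#if 0
TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
#endif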
// This basic implementation doesn't bother dealing with different device indices
// (e.g. custom_device:0 vs. custom_device:1).
// We could support that by letting the user pass a device index into our exposed device function.
// Note that if you do that, you'll also need to register a device guard with core.
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL` and the disabled sketch below.
c10::Device get_custom_device() {
  return c10::Device(c10::DeviceType::PrivateUse1, 0);
}
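// Sketch of the device guard registration mentioned above (disabled: this file keeps the
// single-index assumption). Assuming the no-op guard that ships with c10 is enough for a
// test-only backend, the registration could look roughly like this; a real backend with
// multiple device indices would implement c10::impl::DeviceGuardImplInterface itself.
#if 0
C10_REGISTER_GUARD_IMPL(
    PrivateUse1,
    c10::impl::NoOpDeviceGuardImpl<c10::DeviceType::PrivateUse1>);
#endif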
bool custom_add_called() {
  bool called = false;
  if (add_counter > last_saved_value) {
    called = true;
    last_saved_value = add_counter;
  }
  return called;
}
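// Illustration of how the pieces above connect (disabled sketch; it assumes tensors can be
// materialized on PrivateUse1 via `.to()`, the same idiom fallback_with_undefined_tensor
// below relies on). Dispatching at::add on PrivateUse1 tensors lands in custom_add_Tensor
// and bumps add_counter, which custom_add_called() then reports.
#if 0
void example_add_dispatch() {
  at::Tensor a = at::empty({2, 2}).to(at::DeviceType::PrivateUse1);
  at::Tensor b = at::empty({2, 2}).to(at::DeviceType::PrivateUse1);
  at::Tensor c = at::add(a, b);  // routed through the PrivateUse1 key -> custom_add_Tensor
  TORCH_INTERNAL_ASSERT(custom_add_called());
}
#endif
// Call _fused_adamw_ with an undefined tensor inside one of its TensorList arguments,
// to check that the boxed CPU fallback handles undefined tensors correctly.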
void fallback_with_undefined_tensor() {
  at::Tensor first = at::empty({2, 3}).to(at::DeviceType::PrivateUse1);
  at::Tensor second = at::Tensor();
  at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
  at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
  at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
  at::TensorList tensors = {first, first};
  at::TensorList undefined_tensors = {first, second};
  at::TensorList steps = {step, step};
  return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
                           steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
                           grad_scale, found_inf);
}
// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind11: we expose an "extension_name.custom_device()" function in Python
// that's implemented in C++.
// The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("custom_device", &get_custom_device, "get custom device object");
  m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
  m.def("custom_storage_registry", &custom_storage_registry, "register our custom storageImpl create method");
  m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom storageImpl create method was called");
  m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
}